In a previous post we looked at how to write a finite domain constraint solver in Haskell. Now we're going to look at how to use this solver to solve Sudoku puzzles. First we give the module header and define a useful type:
    module Sudoku (Puzzle, printSudoku, displayPuzzle, sudoku) where

    import Control.Monad
    import Data.List (transpose)
    import FD

    type Puzzle = [Int]
We represent both unsolved and solved puzzles as a list of 81 Ints. In an unsolved puzzle, we use 0 to represent a blank square. The numbers 1 to 9 represent squares with known values. Here is an example of an unsolved puzzle that I copied from the local newspaper:
    test :: Puzzle
    test = [
        0, 0, 0, 0, 8, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 6, 5, 0, 7,
        4, 0, 2, 7, 0, 0, 0, 0, 0,
        0, 8, 0, 3, 0, 0, 1, 0, 0,
        0, 0, 3, 0, 0, 0, 8, 0, 0,
        0, 0, 5, 0, 0, 9, 0, 7, 0,
        0, 5, 0, 0, 0, 8, 0, 0, 6,
        3, 0, 1, 2, 0, 4, 0, 0, 0,
        0, 0, 6, 0, 1, 0, 0, 0, 0 ]
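Since a puzzle is just a list of Ints, nothing in the type stops us from passing a list of the wrong length or with out-of-range values. As a rough sketch, a check like the one below could catch that before solving. The names wellFormed and givens are hypothetical helpers for illustration; they are not part of the Sudoku module.

```haskell
-- Hypothetical helpers; not part of the original Sudoku module.
type Puzzle = [Int]

-- A puzzle is well formed if it has exactly 81 cells, each in 0..9
-- (0 marks a blank square, 1..9 a known value).
wellFormed :: Puzzle -> Bool
wellFormed p = length p == 81 && all (\n -> n >= 0 && n <= 9) p

-- Number of squares whose value is already known.
givens :: Puzzle -> Int
givens = length . filter (> 0)
```

For a grid of 81 zeros, wellFormed returns True and givens returns 0.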
The function displayPuzzle displays a puzzle for us in rows and columns. The function printSudoku will solve a puzzle by calling sudoku, which we will define below. It then prints each solution.
    displayPuzzle :: Puzzle -> String
    displayPuzzle = unlines . map show . chunk 9

    printSudoku :: Puzzle -> IO ()
    printSudoku = putStr . unlines . map displayPuzzle . sudoku

    chunk :: Int -> [a] -> [[a]]
    chunk _ [] = []
    chunk n xs = ys : chunk n zs
        where (ys, zs) = splitAt n xs
We now present the code to actually solve the puzzle:
    sudoku :: Puzzle -> [Puzzle]
    sudoku puzzle = runFD $ do
        vars <- newVars 81 [1..9]
        zipWithM_ (\x n -> when (n > 0) (x `hasValue` n)) vars puzzle
        mapM_ allDifferent (rows vars)
        mapM_ allDifferent (columns vars)
        mapM_ allDifferent (boxes vars)
        labelling vars

    rows, columns, boxes :: [a] -> [[a]]
    rows = chunk 9
    columns = transpose . rows
    boxes = concat . map (map concat . transpose) . chunk 3 . chunk 3 . chunk 3
We start by initialising 81 new solver variables, one for each square in the puzzle. Next, we constrain each known square in the puzzle to its given value. The next three lines create the Sudoku constraints: each number 1 to 9 may occur only once in each row, column and 3x3 box. The functions for grouping the variables into rows, columns and boxes are given below the main function. Finally, we call labelling to search for solutions.
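The definition of boxes is terse, so it may help to run the grouping functions on plain cell indices instead of solver variables. The sketch below copies chunk, rows, columns and boxes from the code above so that it stands alone without the FD module. The isSolved function is a hypothetical checker added for illustration; it is not part of the solver.

```haskell
import Data.List (nub, transpose)

-- Copied from the post so this example is self-contained.
chunk :: Int -> [a] -> [[a]]
chunk _ [] = []
chunk n xs = ys : chunk n zs
  where (ys, zs) = splitAt n xs

rows, columns, boxes :: [a] -> [[a]]
rows = chunk 9
columns = transpose . rows
boxes = concat . map (map concat . transpose) . chunk 3 . chunk 3 . chunk 3

-- Hypothetical checker: a filled grid is solved if every row, column
-- and box contains nine distinct values, all in the range 1..9.
isSolved :: [Int] -> Bool
isSolved g = length g == 81 && all ok (rows g ++ columns g ++ boxes g)
  where ok grp = all (`elem` [1 .. 9]) grp && length (nub grp) == 9
```

Applied to [0..80], the first element of boxes is [0,1,2,9,10,11,18,19,20], which is the top-left 3x3 box, and the first element of columns is [0,9,18,27,36,45,54,63,72].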
Let's see how this performs with our test puzzle, running on a 2.4GHz Intel Core 2 Duo with 4GB RAM:
    > ghc --make -O2 test
    [1 of 3] Compiling FD ( FD.hs, FD.o )
    [2 of 3] Compiling Sudoku ( Sudoku.hs, Sudoku.o )
    [3 of 3] Compiling Main ( test.hs, test.o )
    Linking test ...
    > time ./test
    [5,6,7,4,8,3,2,9,1]
    [9,3,8,1,2,6,5,4,7]
    [4,1,2,7,9,5,3,6,8]
    [6,8,9,3,7,2,1,5,4]
    [7,4,3,6,5,1,8,2,9]
    [1,2,5,8,4,9,6,7,3]
    [2,5,4,9,3,8,7,1,6]
    [3,7,1,2,6,4,9,8,5]
    [8,9,6,5,1,7,4,3,2]

    real 0m6.518s
    user 0m6.480s
    sys 0m0.032s
So it takes around 6.5 seconds to find all solutions to this puzzle (of which there is exactly one, as expected). Can we do any better than that? Yes, we can. Recall our earlier definition of the function different, which constrains two variables to have different values and is used by the allDifferent function that features prominently in our Sudoku solving code:
    different :: FDVar s -> FDVar s -> FD s ()
    different = addBinaryConstraint $ \x y -> do
        xv <- lookup x
        yv <- lookup y
        guard $ IntSet.size xv > 1 || IntSet.size yv > 1 || xv /= yv
Notice how this function doesn't try to constrain the domains of the variables or do any constraint propagation. That's because, in the general case, we don't have enough information to do this; we just have to store the constraint and retest it each time the domains change. This means that in our Sudoku solver, the labelling step will effectively try each value 1 to 9 in turn for each blank square and then test whether the allDifferent constraints still hold. This is almost a brute-force algorithm.
However, we have overlooked one case where we can further constrain the variables: if the domain of one variable is a singleton set, then we can remove that value from the domain of the other variable. This should drastically reduce the number of tests we need to do during labelling. Here's the modified function:
    different = addBinaryConstraint $ \x y -> do
        xv <- lookup x
        yv <- lookup y
        guard $ IntSet.size xv > 1 || IntSet.size yv > 1 || xv /= yv
        when (IntSet.size xv == 1 && xv `IntSet.isProperSubsetOf` yv) $
            update y (yv `IntSet.difference` xv)
        when (IntSet.size yv == 1 && yv `IntSet.isProperSubsetOf` xv) $
            update x (xv `IntSet.difference` yv)
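To see the pruning rule in isolation, here is a rough pure sketch that works directly on two domains rather than on solver variables. The name pruneDifferent is hypothetical: the real different runs inside the FD monad and uses lookup and update, whereas this version simply returns the narrowed domains, or Nothing when the constraint cannot hold.

```haskell
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet

-- Hypothetical pure counterpart of the pruning rule in `different`.
-- Nothing means the constraint is violated; otherwise the result holds
-- the (possibly narrowed) domains of the two variables.
pruneDifferent :: IntSet -> IntSet -> Maybe (IntSet, IntSet)
pruneDifferent xv yv
  | not stillPossible = Nothing
  -- x is decided, so its value can be removed from y's domain.
  | isSingleton xv = Just (xv, yv `IntSet.difference` xv)
  -- y is decided, so its value can be removed from x's domain.
  | isSingleton yv = Just (xv `IntSet.difference` yv, yv)
  -- Neither is decided yet: nothing to prune.
  | otherwise = Just (xv, yv)
  where
    isSingleton s = IntSet.size s == 1
    -- Same test as the guard in `different`.
    stillPossible = IntSet.size xv > 1 || IntSet.size yv > 1 || xv /= yv
```

For example, with x fixed to 3 and y ranging over 1 to 9, the result keeps x as it is and drops 3 from y's domain. If both domains are the singleton set containing 3, the result is Nothing.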
Here are the times with the modified function:
    > time ./test
    [5,6,7,4,8,3,2,9,1]
    [9,3,8,1,2,6,5,4,7]
    [4,1,2,7,9,5,3,6,8]
    [6,8,9,3,7,2,1,5,4]
    [7,4,3,6,5,1,8,2,9]
    [1,2,5,8,4,9,6,7,3]
    [2,5,4,9,3,8,7,1,6]
    [3,7,1,2,6,4,9,8,5]
    [8,9,6,5,1,7,4,3,2]

    real 0m0.180s
    user 0m0.160s
    sys 0m0.020s

That's a 36-fold improvement. Not bad for adding four lines of code!