Sudoku solver
Question: Given a partially filled Sudoku board, write an efficient solver to complete the board according to the rules of Sudoku.
The rules of Sudoku
The rules are very simple:
- every empty cell must be filled with a digit from 1 to 9
- each row must contain unique numbers
- each column must contain unique numbers
- each 3x3 square section must contain unique numbers
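The three rules translate directly into a check on a completed board. The sketch below is a hypothetical helper, `is_valid_solution`, not part of the solver in this post. It assumes cells hold the strings "1" to "9", as in the solver later on. A unit (row, column, or section) is valid exactly when its nine values are the nine distinct digits.

```python
def is_valid_solution(board):
    """Return True if a completely filled 9x9 board obeys all three rules.

    Hypothetical helper: assumes cells are the strings "1".."9", so any "."
    (an unfilled cell) makes the board invalid.
    """
    digits = set("123456789")
    rows = [set(row) for row in board]
    cols = [{board[r][c] for r in range(9)} for c in range(9)]
    sections = [
        {board[r][c] for r in range(br, br + 3) for c in range(bc, bc + 3)}
        for br in (0, 3, 6)
        for bc in (0, 3, 6)
    ]
    # Each unit has 9 cells, so its set equals {"1".."9"} only if all 9 differ.
    return all(unit == digits for unit in rows + cols + sections)


# A known-valid grid built from a simple shifting pattern (illustration only).
solved = [[str((3 * (r % 3) + r // 3 + c) % 9 + 1) for c in range(9)]
          for r in range(9)]
print(is_valid_solution(solved))  # True
```

Using sets keeps each unit check to a single comparison rather than a pairwise duplicate search.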
Solution
You may think of solving a Sudoku board as a highly complex problem, but as you'll see it's actually very straightforward, and only involves:
- a DFS
- set membership checking
- backtracking
Backtracking isn't a separate mechanism, either: it happens as a natural consequence of the DFS combined with set membership checking.
Constraints checking
Each row must contain unique numbers
Checking whether a number is in a set is an $O(1)$ operation (on average, since Python sets are hash-based), so filling a row with unique numbers can be done in $O(n)$, where $n$ is the row length.
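To make the cost argument concrete, here is a minimal sketch of filling a single row. The helper `fill_row` is hypothetical and enforces only the row rule. It builds the set of used digits once, and then each candidate test is a single set lookup.

```python
def fill_row(row):
    """Fill the "." cells of a single row with the digits it is missing.

    Hypothetical helper: it enforces only the row rule, not the column or
    section rules, so it illustrates the cost argument rather than solving.
    """
    used = {v for v in row if v != "."}  # one pass to build the set: O(n)
    # Each `d not in used` test is an average-case O(1) set lookup.
    missing = iter([d for d in "123456789" if d not in used])
    return [next(missing) if v == "." else v for v in row]


print(fill_row([".", "3", ".", "8", ".", ".", "4", ".", "."]))
# ['1', '3', '2', '8', '5', '6', '4', '7', '9']
```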
Each column must contain unique numbers
The same reasoning applies to columns: with one set per column, filling up a column with unique numbers is just as easy.
Each section must contain unique numbers
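Likewise, filling a square section with unique numbers is a trivial task. To be as explicit as possible: cell (r, c) belongs to the section keyed by (r // 3, c // 3), so the board splits into nine sections, from (0, 0) in the top-left corner to (2, 2) in the bottom-right.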
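The solver identifies a cell's section with the key (r // 3, c // 3). Integer division groups rows 0-2, 3-5 and 6-8, and likewise columns. Below is a small sketch of that mapping in both directions, using two hypothetical helpers, `section_of` and `cells_in_section`.

```python
def section_of(r, c):
    """Key of the 3x3 section containing cell (r, c), as in the solver."""
    return (r // 3, c // 3)


def cells_in_section(key):
    """The nine (row, col) cells that belong to section `key`."""
    br, bc = key
    return [(3 * br + i, 3 * bc + j) for i in range(3) for j in range(3)]


print(section_of(4, 7))  # (1, 2): middle band of rows, right-hand stack of columns
print(cells_in_section((0, 1)))
# [(0, 3), (0, 4), (0, 5), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)]
```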
Linear scan
We'll scan the board in a linear manner, one cell at a time.
At each cell, if it's empty, we'll:
- guess a number
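- check the guess against the row, column, and section constraints
- if the guess passes all constraints, recurse into the next cell; otherwise try the next number, and if no number fits, return up the stack to the previous cell (backtrack).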
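The scan order can be captured in one small step function. This sketch uses a hypothetical `next_cell` helper; the solver later in this post inlines the same logic. Stepping from (0, 0) visits all 81 cells and then lands on row 9, the position the solver treats as "done".

```python
def next_cell(r, c):
    """Return the cell after (r, c) in a row-by-row (row-major) scan."""
    return (r + 1, 0) if c == 8 else (r, c + 1)


# Walking from (0, 0) visits all 81 cells and then reaches row 9,
# which is the "past the last cell" position the solver treats as done.
r, c, visited = 0, 0, []
while r < 9:
    visited.append((r, c))
    r, c = next_cell(r, c)
print(len(visited), (r, c))  # 81 (9, 0)
```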
Solving with all constraints
This solver already places in the fastest 80% of LeetCode submissions. It is fast and, because it is so simple, easy to remember.
Below we can see the backtracking in action.
When the solver can no longer progress, that is, when it reaches a dead end in the solution space, it has no choice but to return up the call stack and explore new paths.
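One way to make backtracking visible in text is to count how often a guess has to be undone. This is a hedged sketch: `solve_counting` is a hypothetical instrumented copy of this post's solver. It keeps the same sets and row-by-row scan and adds a `backtracks` counter. The test puzzle is an illustration only. It is built by blanking cells of a pattern-generated solved grid, not taken from the post.

```python
from collections import defaultdict
from itertools import product


def solve_counting(board):
    """Solve `board` in place and count how many guesses had to be undone.

    Hypothetical instrumented copy of this post's solver: same sets, same
    row-by-row scan, plus a `backtracks` counter. Returns (solved, backtracks).
    """
    rows, cols, sections = defaultdict(set), defaultdict(set), defaultdict(set)
    for r, c in product(range(9), repeat=2):
        if board[r][c] != ".":
            rows[r].add(board[r][c])
            cols[c].add(board[r][c])
            sections[(r // 3, c // 3)].add(board[r][c])
    backtracks = 0

    def dfs(r, c):
        nonlocal backtracks
        if r == 9:
            return True
        nr, nc = (r + 1, 0) if c == 8 else (r, c + 1)
        if board[r][c] != ".":
            return dfs(nr, nc)
        t = (r // 3, c // 3)
        for dig in "123456789":
            if dig in rows[r] or dig in cols[c] or dig in sections[t]:
                continue  # fails a constraint check
            board[r][c] = dig
            rows[r].add(dig)
            cols[c].add(dig)
            sections[t].add(dig)
            if dfs(nr, nc):
                return True
            # The guess led to a dead end further down: undo it.
            backtracks += 1
            board[r][c] = "."
            rows[r].discard(dig)
            cols[c].discard(dig)
            sections[t].discard(dig)
        return False

    # dfs runs first, so the counter is read after the search finishes.
    return dfs(0, 0), backtracks


# Test puzzle (illustration only): a pattern-generated solved grid with
# every cell where r + c is even blanked out.
solved = [[str((3 * (r % 3) + r // 3 + c) % 9 + 1) for c in range(9)]
          for r in range(9)]
puzzle = [[v if (r + c) % 2 else "." for c, v in enumerate(row)]
          for r, row in enumerate(solved)]
ok, backtracks = solve_counting(puzzle)
print(ok, backtracks)
```

A larger counter on harder puzzles means the search hit more dead ends before finding a consistent assignment.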
Let's break down the solution.
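```python
from collections import defaultdict as ddict
from itertools import product

# One set of already-placed digits per row, per column, and per 3x3 section.
rows, cols, triples = ddict(set), ddict(set), ddict(set)
for r, c in product(range(9), repeat=2):
    if board[r][c] != ".":
        rows[r].add(board[r][c])
        cols[c].add(board[r][c])
        triples[(r // 3, c // 3)].add(board[r][c])
```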
Above, we begin by populating the row, column, and section sets with the pre-existing digits. These sets let us check each constraint in $O(1)$ time.
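To see what these sets contain, here is the same population loop run on the sample puzzle from the full listing at the end of this post. The board is copied in so the snippet runs on its own.

```python
from collections import defaultdict as ddict
from itertools import product

# The sample puzzle from the full listing at the end of this post.
board = [[".", ".", "3", "8", ".", ".", "4", ".", "."],
         [".", ".", ".", ".", "1", ".", ".", "7", "."],
         [".", "6", ".", ".", ".", "5", ".", ".", "9"],
         [".", ".", ".", "9", ".", ".", "6", ".", "."],
         [".", "2", ".", ".", ".", ".", ".", "1", "."],
         [".", ".", "4", ".", ".", "3", ".", ".", "2"],
         [".", ".", "2", ".", ".", ".", "8", ".", "."],
         [".", "1", ".", ".", ".", ".", ".", "5", "."],
         ["9", ".", ".", ".", ".", "7", ".", ".", "3"]]

rows, cols, triples = ddict(set), ddict(set), ddict(set)
for r, c in product(range(9), repeat=2):
    if board[r][c] != ".":
        rows[r].add(board[r][c])
        cols[c].add(board[r][c])
        triples[(r // 3, c // 3)].add(board[r][c])

print(sorted(rows[0]))          # ['3', '4', '8']
print(sorted(cols[0]))          # ['9']
print(sorted(triples[(0, 0)]))  # ['3', '6']
print("5" in rows[0])           # False: one average-case O(1) lookup
```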
```python
def dfs(r, c):
    if r == 9:  # moved past the last row
        return True
```
This is the DFS base case: once r reaches 9, we have moved past the last row, so the whole board is filled.
```python
    if board[r][c] != ".":
        # Already filled: move to (r, c + 1), or to (r + 1, 0) at the end of a row.
        return dfs((r, r + 1)[c == 8], (c + 1, 0)[c == 8])
```
If the current cell already holds a digit, we skip it and move on to the next one.
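The expression `(r, r+1)[c==8]` indexes a two-element tuple with a boolean. Python's `bool` is a subclass of `int`, so `False` selects element 0 and `True` selects element 1. The sketch below, with hypothetical function names, checks that this trick matches an ordinary conditional expression for every cell.

```python
def next_cell_tuple(r, c):
    # Indexing a 2-tuple with a bool: False -> element 0, True -> element 1.
    return (r, r + 1)[c == 8], (c + 1, 0)[c == 8]


def next_cell_plain(r, c):
    # The same step written as a conditional expression.
    return (r + 1, 0) if c == 8 else (r, c + 1)


assert all(next_cell_tuple(r, c) == next_cell_plain(r, c)
           for r in range(9) for c in range(9))
print(next_cell_tuple(2, 5), next_cell_tuple(2, 8))  # (2, 6) (3, 0)
```

Both tuple elements are evaluated before indexing, which is harmless here because they are plain arithmetic.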
```python
    t = (r // 3, c // 3)  # key of the 3x3 section containing (r, c)
    for dig in "123456789":
        if dig not in rows[r] and dig not in cols[c] and dig not in triples[t]:
            board[r][c] = dig
            rows[r].add(dig)
            cols[c].add(dig)
            triples[t].add(dig)
```
We loop over the digits 1 through 9 and try each one in turn. A digit passes the test only if it is not already in the current row, column, or section set.
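Put differently, the digits worth trying at an empty cell are those that survive all three membership tests. Here is a sketch with a hypothetical `candidates` helper. It uses the sets for cell (0, 0) of the sample puzzle in the full listing below: row 0 holds 3, 8 and 4, column 0 holds 9, and section (0, 0) holds 3 and 6.

```python
def candidates(rows, cols, triples, r, c):
    """Digits that pass all three constraint checks for empty cell (r, c)."""
    t = (r // 3, c // 3)
    return [d for d in "123456789"
            if d not in rows[r] and d not in cols[c] and d not in triples[t]]


# Sets for cell (0, 0) of the sample puzzle in the full listing below.
rows = {0: {"3", "8", "4"}}
cols = {0: {"9"}}
triples = {(0, 0): {"3", "6"}}
print(candidates(rows, cols, triples, 0, 0))  # ['1', '2', '5', '7']
```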
```python
            if dfs((r, r + 1)[c == 8], (c + 1, 0)[c == 8]):
                return True
```
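If there are no clashes, we place the digit and recurse into the following cell. If that recursive call fails, we undo the placement (reset the cell to "." and discard the digit from the three sets) and try the next digit. If no digit fits, dfs returns False and the search falls back to the previous cell: this is the backtracking step.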
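Backtracking is only safe because the undo step exactly reverses the placement. A digit is placed only when it is absent from all three sets, so discarding it afterwards restores them to their previous state. Below is a sketch with hypothetical `place` and `unplace` helpers. It uses plain lists and a dict pre-filled with all nine section keys, because a `defaultdict` would add empty-set keys on lookup and spoil the equality check.

```python
import copy


def place(board, rows, cols, triples, r, c, dig):
    """Write `dig` into empty cell (r, c) and record it in the three sets."""
    board[r][c] = dig
    rows[r].add(dig)
    cols[c].add(dig)
    triples[(r // 3, c // 3)].add(dig)


def unplace(board, rows, cols, triples, r, c, dig):
    """Exactly reverse `place`: empty the cell and drop `dig` from the sets."""
    board[r][c] = "."
    rows[r].discard(dig)
    cols[c].discard(dig)
    triples[(r // 3, c // 3)].discard(dig)


board = [["."] * 9 for _ in range(9)]
rows = [set() for _ in range(9)]
cols = [set() for _ in range(9)]
triples = {(i, j): set() for i in range(3) for j in range(3)}
place(board, rows, cols, triples, 0, 2, "3")  # a given digit

before = copy.deepcopy((board, rows, cols, triples))
place(board, rows, cols, triples, 0, 0, "5")    # a guess
unplace(board, rows, cols, triples, 0, 0, "5")  # the guess hit a dead end
print(before == (board, rows, cols, triples))   # True
```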
```python
from time import perf_counter
from collections import defaultdict as ddict
from itertools import product


def solveSudoku(board):
    """Fill `board` (9x9 lists of "1".."9" or ".") in place."""

    def dfs(r, c):
        # Base case: we moved past the last row, so the board is complete.
        if r == 9:
            return True
        # Pre-filled cell: skip to (r, c + 1), or (r + 1, 0) at a row's end.
        if board[r][c] != ".":
            return dfs((r, r + 1)[c == 8], (c + 1, 0)[c == 8])
        t = (r // 3, c // 3)  # key of the 3x3 section containing (r, c)
        for dig in "123456789":
            if dig not in rows[r] and dig not in cols[c] and dig not in triples[t]:
                # Tentatively place the digit and record it in all three sets.
                board[r][c] = dig
                rows[r].add(dig)
                cols[c].add(dig)
                triples[t].add(dig)
                if dfs((r, r + 1)[c == 8], (c + 1, 0)[c == 8]):
                    return True
                # Dead end further down: undo the placement, try the next digit.
                board[r][c] = "."
                rows[r].discard(dig)
                cols[c].discard(dig)
                triples[t].discard(dig)
        # No digit fits here: backtrack to the previous cell.
        return False

    # Record the digits already on the board before searching.
    rows, cols, triples = ddict(set), ddict(set), ddict(set)
    for r, c in product(range(9), repeat=2):
        if board[r][c] != ".":
            rows[r].add(board[r][c])
            cols[c].add(board[r][c])
            triples[(r // 3, c // 3)].add(board[r][c])
    dfs(0, 0)


def print_sudoku(A):
    for r in A:
        print(r)
    print()


A = [[".", ".", "3", "8", ".", ".", "4", ".", "."],
     [".", ".", ".", ".", "1", ".", ".", "7", "."],
     [".", "6", ".", ".", ".", "5", ".", ".", "9"],
     [".", ".", ".", "9", ".", ".", "6", ".", "."],
     [".", "2", ".", ".", ".", ".", ".", "1", "."],
     [".", ".", "4", ".", ".", "3", ".", ".", "2"],
     [".", ".", "2", ".", ".", ".", "8", ".", "."],
     [".", "1", ".", ".", ".", ".", ".", "5", "."],
     ["9", ".", ".", ".", ".", "7", ".", ".", "3"]]

s = perf_counter()
solveSudoku(A)  # solves A in place; the function returns None
print_sudoku(A)
print(perf_counter() - s)
```