Sudoku solver


Question: Given a partially filled Sudoku board, write an efficient solver to complete the board according to the rules of Sudoku.

The rules of Sudoku

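The rules are very simple: each row, each column, and each of the nine 3x3 sections must contain every digit from 1 to 9 exactly once.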
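As a quick sanity check on the rules of Sudoku (every row, column and 3x3 section holds the digits 1 to 9 exactly once), here is a minimal sketch of a validator for a completely filled board. The name is_valid_solution is hypothetical and not part of the solver later in this post. It assumes the board is a 9x9 list of single-character strings, matching the convention used below. The pattern-built grid in the usage line is just a convenient known-valid example.

```python
def is_valid_solution(board):
    """Return True if a completely filled 9x9 board obeys all Sudoku rules.

    Hypothetical helper: cells are single-character strings '1'..'9'.
    """
    digits = set("123456789")
    for i in range(9):
        row = {board[i][c] for c in range(9)}
        col = {board[r][i] for r in range(9)}
        # Section i has its top-left corner at (3 * (i // 3), 3 * (i % 3)).
        br, bc = 3 * (i // 3), 3 * (i % 3)
        box = {board[br + dr][bc + dc] for dr in range(3) for dc in range(3)}
        # Nine cells whose set equals {'1'..'9'} means no repeats and no gaps.
        if row != digits or col != digits or box != digits:
            return False
    return True

# A known valid grid built from a shifting pattern.
solved = [[str((3 * (r % 3) + r // 3 + c) % 9 + 1) for c in range(9)] for r in range(9)]
print(is_valid_solution(solved))  # True
```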


You may think of solving a Sudoku board as a highly complex problem, but as you'll see it's actually very straightforward, and only involves a depth-first search (DFS) over the cells and set membership checks for the constraints.

Backtracking isn't really a separate ingredient either: it happens naturally as a consequence of the DFS and the set membership checks.

Constraint checking

Each row must contain unique numbers

Checking whether a number is in a set takes O(1) time on average, so filling a row of n cells with unique numbers can be done in O(n).
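To make the row argument concrete, here is a sketch that looks at the row constraint alone and fills the empty cells of a single row with the missing digits in one pass. fill_row is a hypothetical helper for illustration; the solver in this post never fills a row this greedily, because the column and section constraints apply too.

```python
def fill_row(row):
    """Fill the '.' cells of one row with unused digits (row constraint only).

    Hypothetical helper: building `used` is O(n), and each membership test
    against it is O(1) on average, so the whole fill is O(n).
    """
    used = {v for v in row if v != "."}
    # Lazily produce the digits not yet in the row, in ascending order.
    missing = (d for d in "123456789" if d not in used)
    return [next(missing) if v == "." else v for v in row]

print(fill_row(["5", "3", ".", ".", "7", ".", ".", ".", "."]))
# ['5', '3', '1', '2', '7', '4', '6', '8', '9']
```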

Each column must contain unique numbers

The same rule applies to columns. Once again, filling up a column with unique numbers is a fairly trivial problem.
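For columns, the only extra work is gathering the cells, since a board stored as a list of rows has no direct column view. A sketch, with column_candidates as a hypothetical helper name:

```python
def column_candidates(board, c):
    """Digits 1-9 not yet used in column c of a 9x9 board ('.' marks empty)."""
    used = {board[r][c] for r in range(9)} - {"."}  # one set insert per cell
    return [d for d in "123456789" if d not in used]

board = [["."] * 9 for _ in range(9)]
board[0][2], board[5][2] = "3", "4"
print(column_candidates(board, 2))  # ['1', '2', '5', '6', '7', '8', '9']
```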

Each section must contain unique numbers

Likewise, filling a 3x3 section (a square) with unique numbers is a trivial task. Let's depict the sections to be as explicit as possible.
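As a textual counterpart to that depiction, this sketch labels every cell with the index of its 3x3 section, using the same (r // 3, c // 3) key the solver uses later. section_of and section_cells are hypothetical helper names.

```python
def section_of(r, c):
    """Key of the 3x3 section containing cell (r, c), as (section_row, section_col)."""
    return (r // 3, c // 3)

def section_cells(sr, sc):
    """The nine (r, c) cells that make up section (sr, sc)."""
    return [(3 * sr + dr, 3 * sc + dc) for dr in range(3) for dc in range(3)]

# Print the 9x9 grid with each cell replaced by its section number 0-8
# (sections numbered left to right, top to bottom).
for r in range(9):
    print(" ".join(str(3 * sr + sc) for sr, sc in (section_of(r, c) for c in range(9))))
```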

Linear scan

We'll scan the board in a linear manner, one cell at a time.

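At each cell, if it's empty, we'll:

- try each digit from 1 to 9 that is not already in the cell's row, column or section;
- place it and move on to the next cell;
- if the remaining cells can't be completed, remove the digit and try the next candidate, returning to the previous cell when none is left.

Cells that are already filled are simply skipped.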
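The scan order itself is plain row-major order. Here is a sketch that visits just the empty cells in that order; empty_cells is a hypothetical helper, and the solver below walks every cell and skips the filled ones rather than precomputing this list.

```python
def empty_cells(board):
    """Yield (r, c) for every empty ('.') cell, row by row, left to right."""
    for r in range(9):
        for c in range(9):
            if board[r][c] == ".":
                yield r, c

board = [["1"] * 9 for _ in range(9)]  # placeholder board, not a valid Sudoku
board[0][4] = board[2][1] = board[2][7] = "."
print(list(empty_cells(board)))  # [(0, 4), (2, 1), (2, 7)]
```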

Solving with all constraints

On LeetCode, this solver already ranks among the fastest 80% of submissions; it is fast and, thanks to its simplicity, easy to remember.

Below we can see the backtracking in action.

When the solver can no longer make progress, that is, when it reaches a dead end in the solution space, it has no choice but to return up the call stack and explore other paths.
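Alongside that picture, one way to see backtracking numerically is to count how many placements get rolled back. The sketch below is a hypothetical instrumented variant of the same DFS-plus-sets approach, not the post's solver: it walks the cells with a single index 0-80 (split with divmod) instead of (r, c) pairs, and undo_count tallies the dead ends it backs out of.

```python
from collections import defaultdict
from itertools import product

def solve_counting_backtracks(board):
    """Solve board in place; return (solved, undo_count).

    Hypothetical instrumented variant: undo_count is the number of
    placements rolled back because they led to a dead end.
    """
    rows, cols, sections = defaultdict(set), defaultdict(set), defaultdict(set)
    for r, c in product(range(9), repeat=2):
        v = board[r][c]
        if v != ".":
            rows[r].add(v)
            cols[c].add(v)
            sections[(r // 3, c // 3)].add(v)
    undo_count = 0

    def dfs(pos):
        nonlocal undo_count
        if pos == 81:  # moved past the last cell: the board is complete
            return True
        r, c = divmod(pos, 9)
        if board[r][c] != ".":
            return dfs(pos + 1)
        s = (r // 3, c // 3)
        for d in "123456789":
            if d in rows[r] or d in cols[c] or d in sections[s]:
                continue
            board[r][c] = d
            rows[r].add(d)
            cols[c].add(d)
            sections[s].add(d)
            if dfs(pos + 1):
                return True
            # Dead end somewhere below this placement: roll it back.
            board[r][c] = "."
            rows[r].discard(d)
            cols[c].discard(d)
            sections[s].discard(d)
            undo_count += 1
        return False

    found = dfs(0)
    return found, undo_count

# Usage: blank every cell with (r + c) % 3 == 0 in a known valid grid.
solution = [[str((3 * (r % 3) + r // 3 + c) % 9 + 1) for c in range(9)] for r in range(9)]
puzzle = [["." if (r + c) % 3 == 0 else solution[r][c] for c in range(9)] for r in range(9)]
solved, undos = solve_counting_backtracks(puzzle)
print(solved, undos)
```

Harder puzzles, such as the one at the end of this post, typically drive the undo count much higher, which is exactly the backtracking described above.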

Let's break down the solution.

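```python
from collections import defaultdict as ddict
from itertools import product

# Seed one set per row, column and 3x3 section with the board's given digits.
rows, cols, triples = ddict(set), ddict(set), ddict(set)
for r, c in product(range(9), repeat=2):
    if board[r][c] != ".":
        rows[r].add(board[r][c])
        cols[c].add(board[r][c])
        triples[(r // 3, c // 3)].add(board[r][c])
```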

Above, we begin by populating the row, column, and section sets with the pre-existing numbers. These sets let us check each constraint in O(1) time on average.

```python
def dfs(r, c):
    if r == 9:
        return True
```

This is our DFS base case: once r reaches 9 we have moved past the last row, so every cell has been filled and the board is solved.

```python
    if board[r][c] != '.':
        return dfs((r, r+1)[c==8], (c+1,0)[c==8])
```

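In case we already have a digit in the current cell, we move on to the next one. The expressions (r, r+1)[c==8] and (c+1,0)[c==8] index a two-element tuple with a boolean (False counts as 0, True as 1): after the last column (c == 8) we advance to the next row and wrap back to column 0; otherwise we stay in the same row and move one column to the right.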
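If the tuple-indexing idiom feels too clever, a conditional expression does the same job. A small sketch comparing the two; both function names are hypothetical.

```python
def next_cell_idiom(r, c):
    # The post's idiom: a bool indexes a 2-tuple (False -> 0, True -> 1).
    return (r, r + 1)[c == 8], (c + 1, 0)[c == 8]

def next_cell_plain(r, c):
    # Same row-major successor, written as a conditional expression.
    return (r, c + 1) if c < 8 else (r + 1, 0)

# The two agree on every cell; (8, 8) maps to (9, 0), which hits the base case.
assert all(next_cell_idiom(r, c) == next_cell_plain(r, c) for r in range(9) for c in range(9))
print(next_cell_plain(8, 8))  # (9, 0)
```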

```python
    t = (r // 3, c // 3)  # key of the current 3x3 section
    for dig in '123456789':
        if dig not in rows[r] and dig not in cols[c] and dig not in triples[t]:
            board[r][c] = dig
            rows[r].add(dig)
            cols[c].add(dig)
            triples[t].add(dig)
```

We try every digit from 1 through 9. A digit passes the test only if it is not already in the current row, column or section set; when it passes, we write it to the board and add it to all three sets.
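An equivalent way to phrase the three membership tests is to subtract the row, column and section sets from the full digit set, which yields every legal digit for a cell at once. A sketch, with candidates as a hypothetical helper; it assumes the same defaultdict(set) layout and (r // 3, c // 3) section keys as the solver.

```python
from collections import defaultdict

ALL_DIGITS = frozenset("123456789")

def candidates(rows, cols, sections, r, c):
    """Sorted digits absent from the row, column and section of cell (r, c)."""
    return sorted(ALL_DIGITS - rows[r] - cols[c] - sections[(r // 3, c // 3)])

rows, cols, sections = defaultdict(set), defaultdict(set), defaultdict(set)
rows[0].update({"1", "2"})
cols[0].add("3")
sections[(0, 0)].add("9")
print(candidates(rows, cols, sections, 0, 0))  # ['4', '5', '6', '7', '8']
```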

```python
            if dfs((r, r+1)[c==8], (c+1,0)[c==8]):
                return True
```

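Once a digit is placed without clashes, we recurse into the following cell. If that call manages to complete the board, we propagate True all the way up. Otherwise we undo the placement (reset the cell to "." and discard the digit from the three sets) and try the next digit; if no digit fits, dfs returns False and the caller backtracks in turn.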
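The undo half of each attempt is what makes backtracking work, since the board and all three sets must return to exactly the state they were in before the digit was tried. A sketch that pairs the two halves as hypothetical helpers, place and unplace:

```python
from collections import defaultdict

def place(board, rows, cols, sections, r, c, d):
    """Write digit d at (r, c) and record it in its row, column and section sets."""
    board[r][c] = d
    rows[r].add(d)
    cols[c].add(d)
    sections[(r // 3, c // 3)].add(d)

def unplace(board, rows, cols, sections, r, c, d):
    """Exactly reverse place(): clear the cell and forget the digit.

    Discarding is safe only because the solver places d after checking that it
    was absent from all three sets, so no pre-filled clue is ever removed here.
    """
    board[r][c] = "."
    rows[r].discard(d)
    cols[c].discard(d)
    sections[(r // 3, c // 3)].discard(d)

board = [["."] * 9 for _ in range(9)]
rows, cols, sections = defaultdict(set), defaultdict(set), defaultdict(set)
place(board, rows, cols, sections, 4, 5, "7")
unplace(board, rows, cols, sections, 4, 5, "7")
print(board[4][5], rows[4], cols[5], sections[(1, 1)])  # . set() set() set()
```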

```python
from time import time
from collections import defaultdict as ddict
from itertools import product

def solveSudoku(board):
    """Fill the empty ('.') cells of a 9x9 board in place."""
    def dfs(r, c):
        # Base case: we moved past the last row, so every cell is filled.
        if r == 9:
            return True
        # Pre-filled cell: move on to the next cell in row-major order.
        if board[r][c] != '.':
            return dfs((r, r+1)[c==8], (c+1,0)[c==8])
        t = (r // 3, c // 3)  # key of the 3x3 section containing (r, c)
        for dig in '123456789':
            if dig not in rows[r] and dig not in cols[c] and dig not in triples[t]:
                # Place the digit and record it in all three constraint sets.
                board[r][c] = dig
                rows[r].add(dig)
                cols[c].add(dig)
                triples[t].add(dig)
                if dfs((r, r+1)[c==8], (c+1,0)[c==8]):
                    return True
                else:
                    # Dead end: undo the placement and try the next digit.
                    board[r][c] = "."
                    rows[r].discard(dig)
                    cols[c].discard(dig)
                    triples[t].discard(dig)
        # No digit fits this cell: backtrack to the caller.
        return False

    # Seed the constraint sets with the pre-filled digits.
    rows, cols, triples = ddict(set), ddict(set), ddict(set)
    for r, c in product(range(9), repeat=2):
        if board[r][c] != ".":
            rows[r].add(board[r][c])
            cols[c].add(board[r][c])
            triples[(r // 3, c // 3)].add(board[r][c])
    dfs(0, 0)

def print_sudoku(A):
    for r in A:
        print(r)
    print('')

A = [[".",".","3","8",".",".","4",".","."],
     [".",".",".",".","1",".",".","7","."],
     [".","6",".",".",".","5",".",".","9"],
     [".",".",".","9",".",".","6",".","."],
     [".","2",".",".",".",".",".","1","."],
     [".",".","4",".",".","3",".",".","2"],
     [".",".","2",".",".",".","8",".","."],
     [".","1",".",".",".",".",".","5","."],
     ["9",".",".",".",".","7",".",".","3"]]

s = time()
solveSudoku(A)
print(time() - s)  # seconds spent solving
print_sudoku(A)    # the completed board
```