Sudoku is an interesting problem because ultimately it has a mechanical solution, and yet finding this solution still presents a difficult challenge. One way to solve the problem is to systematically generate all possible permutations and discard those that violate the rules. You can do this one row at a time, and if you finally fill in all the rows, you've found a valid solution.
However, this method, brute force, can be slow. There are a variety of optimisations you can apply to the above problem to improve the efficiency:
- Rather than solving rows from top to bottom, order them from least empty to most empty. This creates a bottom-heavy search tree, which ensures that as many branches as possible are culled with the smallest amount of work.
- Rotate the puzzle 90° to increase the effect of the above optimisation.
- Encode the rules into the permutation generation function to avoid generating permutations that won't work.
- Within individual rows, generate permutations for columns sorted from fewest possibilities to most possibilities.
There are many more optimisations, but here is my efficient implementation which does all of the above and more.
//
// main.cpp
// COSC329 Sudoku
//
// Created by Samuel Williams on 26/05/11.
// Copyright 2011 Samuel Williams. All rights reserved.
//
#include <time.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
//#define G_GRAPH
//#define G_DEBUG
// Inserts `item` with sort key `cost` into the parallel arrays `v` (items) and
// `c` (costs). The first `offset` entries are treated as the existing, already
// ordered prefix: entries whose cost is strictly greater than `cost` are moved
// up by one slot and the new pair is written into the gap. Both arrays must
// therefore be writable at index `offset`. Entries with an equal cost stay in
// front of the new item.
template <typename ItemT, typename CostT, typename ArrayT, typename CostArrayT>
void insertion_sort(ItemT item, CostT cost, ArrayT & v, CostArrayT &c, std::size_t offset) {
    std::size_t slot = offset;
    for (; slot > 0; --slot) {
        if (!(cost < c[slot - 1])) {
            break;
        }
        // Shift the more expensive entry one position towards the end.
        c[slot] = c[slot - 1];
        v[slot] = v[slot - 1];
    }
    v[slot] = item;
    c[slot] = cost;
}
// Turns on bit number `position` (0 = least significant bit) in `bits`.
template <typename AnyT>
inline void set_bit(AnyT & bits, unsigned position) {
    const AnyT flag = AnyT(1) << position;
    bits = bits | flag;
}
// Turns off bit number `position` (0 = least significant bit) in `bits`.
template <typename AnyT>
inline void clear_bit(AnyT & bits, unsigned position) {
    const AnyT flag = AnyT(1) << position;
    bits = bits & ~flag;
}
// Returns true when bit number `position` (0 = least significant bit) of `bits` is 1.
template <typename AnyT>
inline bool is_bit_set(const AnyT & bits, unsigned position) {
    const AnyT flag = AnyT(1) << position;
    return (bits & flag) != 0;
}
#ifdef G_DEBUG
// Global counters that exist only in G_DEBUG builds; main() prints all four on exit.
// NOTE(review): the G_GRAPH code in Permutations::recurse() also reads
// permutation_count, so G_GRAPH appears to require G_DEBUG as well.

// Incremented once per call to Permutations::permute() and Permutations::bottom().
static long long permutation_count = 0;
// Incremented once per call to SudokuBoard::solve(index, order).
static long long solution_count = 0;
// Statistics for row spaces
// Sum of the number of non-fixed cells over all rows counted in row_count.
static long long row_spaces = 0;
// Number of Permutations::generate() calls that got past the feasibility checks.
static long long row_count = 0;
#endif
#ifdef G_GRAPH
// Builds a Graphviz "dot" description of the search tree (G_GRAPH builds only).
// Call start(), then matching enter()/exit() pairs while walking the tree,
// then end(), and finally save() to write the accumulated text.
class Grapher {
private:
    // Dot source text accumulated so far.
    std::stringstream _buffer;
    // Names of the nodes on the path from the root to the current node.
    std::vector<std::string> _tree;
    // Number of nodes created so far; used to generate unique node names.
    unsigned _count;
public:
    // Starts the node counter at zero so that enter() is well defined even
    // when it is called before start().
    Grapher () : _count(0) {}

    // Zeroes the node counter and appends the digraph header.
    void start ();
    // Appends the closing brace of the digraph.
    void end ();
    // Adds a node labelled `text` below the current node and makes it current.
    void enter(std::string text);
    // Leaves the current node, returning to its parent.
    void exit();
    // Writes the accumulated dot source to the given stream.
    void save (std::ostream &);
};
// Replaces every non-overlapping occurrence of `oldStr` in `str` with `newStr`,
// scanning left to right. Text inserted by a replacement is never rescanned.
// An empty `oldStr` leaves `str` unchanged: it would otherwise match at every
// position and the loop would never terminate.
void replace(std::string& str, const std::string& oldStr, const std::string& newStr)
{
    if (oldStr.empty())
        return;

    std::size_t pos = 0;
    while ((pos = str.find(oldStr, pos)) != std::string::npos)
    {
        str.replace(pos, oldStr.length(), newStr);
        // Continue after the inserted text so it is not matched again.
        pos += newStr.length();
    }
}
// Zeroes the node counter and appends the opening lines of the dot digraph
// (top-to-bottom layout, fixed size and rank separation) to the buffer.
void Grapher::start () {
    _count = 0;
    _buffer << "digraph { rankdir=TB; size=\"8,5\";" << std::endl
            << "\t ranksep=\"1.0 equally\";" << std::endl;
}
// Appends the closing brace that terminates the digraph opened by start().
void Grapher::end () {
    const char * closing_brace = "}";
    _buffer << closing_brace << std::endl;
}
// Adds a rectangle node labelled `text` and makes it the current node.
// The node is named "n<counter>". When a current node already exists an edge
// from it to the new node is emitted, otherwise the new node is emitted alone.
// Newlines in `text` are escaped so they survive inside the dot label.
void Grapher::enter(std::string text) {
    std::stringstream name_builder;
    name_builder << "n" << _count++;
    const std::string node_name = name_builder.str();

    replace(text, "\n", "\\n");

    _buffer << "\tnode [shape=rectangle,label=\"" << text << "\"]; ";
    if (!_tree.empty()) {
        // Link the new node to its parent, the node entered most recently.
        _buffer << _tree.back() << " -> ";
    }
    _buffer << node_name << ";" << std::endl;

    _tree.push_back(node_name);
}
// Leaves the current node so that later enter() calls attach to its parent.
// An unmatched exit() (no node currently entered) is ignored, because
// pop_back() on an empty vector is undefined behaviour.
void Grapher::exit() {
    if (!_tree.empty()) {
        _tree.pop_back();
    }
}
// Writes everything accumulated in the buffer so far to `out`.
void Grapher::save (std::ostream & out) {
    const std::string dot_source = _buffer.str();
    out << dot_source;
}
// Single global grapher instance used by Permutations::recurse(),
// SudokuBoard::solve() and main() (G_GRAPH builds only).
Grapher g_grapher;
#endif
// Enumerates every way to fill N cells with the values 1..N, each value used
// exactly once, subject to per-cell restrictions.
//
// Usage: fill in `forbidden`, call generate(), then read `permutations`.
//
// forbidden[i] is a bitset for cell i: bit v (1 <= v <= N) set means value v
// may not be placed in cell i. Bit 0 and bits above N are ignored (see MASK).
// A cell with exactly one permitted value is treated as fixed and is excluded
// from the search.
template <unsigned N>
class Permutations {
public:
    // Bits 1..N set, e.g. b1111111110 for N = 9.
    static const unsigned MASK = ((1 << N) - 1) << 1;

    // Input: per-cell bitset of values that must not be used (see class comment).
    unsigned forbidden[N];

    // One complete assignment: data[i] is the value placed in cell i.
    struct Result {
        unsigned data[N];

        Result () {
        }

        // Deliberately implicit: next() pushes the working array directly.
        Result (unsigned _data[N]) {
            memcpy(data, _data, sizeof(unsigned[N]));
        }

        operator unsigned * () {
            return data;
        }

        // Writes the N values separated by single spaces (no trailing newline).
        void print (std::ostream & out) const {
            out << data[0];
            for (std::size_t i = 1; i < N; i += 1) {
                out << ' ' << data[i];
            }
        }
    };

    // Output of generate(): every assignment that satisfies `forbidden`.
    std::vector<Result> permutations;

protected:
    // out:   the assignment currently being built, indexed by cell.
    // order: order[0..count) lists the non-fixed cells, fewest candidates first.
    // count: number of non-fixed cells.
    unsigned out[N], order[N], count;
    // available: values that have not been placed yet.
    // fixed:     fixed[i] is 1 when cell i has a single permitted value.
    unsigned available[N], fixed[N];
    // Number of values still available on entry to recurse(); inside permute()
    // and bottom() it has been decremented and is the index of the last one.
    unsigned top;

    // True when `value` may be placed in cell `k`.
    bool okay (unsigned k, unsigned value) {
        return !(forbidden[k] & (1 << value));
    }

    // This permutation generation doesn't generate lexicographic order.
    // However, it might be possible if we consider a pointer into available,
    // and rather than take items from the top, take items from the front by
    // incrementing pointer.
    //
    // Assigns the cell order[k] and everything after it. When no values are
    // left the assignment is complete and is recorded via next().
    void recurse(unsigned k) {
#ifdef G_GRAPH
        std::stringstream current;
#ifdef G_DEBUG
        // permutation_count only exists in G_DEBUG builds.
        current << "(" << permutation_count << ") ";
#endif
        Result intermediate_result(out);
        intermediate_result.print(current);
        g_grapher.enter(current.str());
#endif
        if (top != 0) {
            // We reduce the number of available items by 1 and recurse.
            top--;
#ifdef G_GRAPH
            permute(k);
#else
            if (top != 1) {
                permute(k);
            } else {
                // Exactly two values remain: handle both orderings directly.
                bottom(k);
            }
#endif
            top++;
        } else {
            next();
        }
#ifdef G_GRAPH
        g_grapher.exit();
#endif
    }

    // Handles the last two cells (order[k], order[k+1]) and the last two
    // values (available[0], available[1]) by trying both pairings.
    void bottom(unsigned k) {
#ifdef G_DEBUG
        permutation_count += 1;
#endif
        unsigned a = order[k];
        unsigned b = order[k+1];

        // available[0], available[1]
        if (okay(a, available[0]) && okay(b, available[1])) {
            out[a] = available[0];
            out[b] = available[1];
            next();
        }

        if (okay(b, available[0]) && okay(a, available[1])) {
            out[b] = available[0];
            out[a] = available[1];
            next();
        }
    }

    // Tries every remaining value in cell order[k] and recurses for the rest.
    void permute(unsigned k) {
#ifdef G_DEBUG
        permutation_count += 1;
#endif
        unsigned p = order[k];

        // We need to handle the last available item specifically,
        // So this loop handles the case where available items >= 2
        for (unsigned i = 0; i < top; i += 1) {
            if (okay(p, available[i])) {
                // We take available[i] and put it in the permutation.
                out[p] = available[i];

                // We take the top value and replace available[i].
                available[i] = available[top];

                recurse(k+1);

                // We replace the item back into available[i].
                available[i] = out[p];
#ifdef G_GRAPH
                // Improve debugging
                out[p] = 0;
#endif
            }
        }

        // The value at the top needs no swap: it simply falls outside the
        // shrunken range seen by the recursive call.
        if (okay(p, available[top])) {
            out[p] = available[top];

            recurse(k+1);
#ifdef G_GRAPH
            // Improve debugging
            out[p] = 0;
#endif
        }
    }

    // Returns the index of the lowest set bit of `value` within MASK.
    // Written as a portable loop instead of the POSIX ffs() call; when no bit
    // is set the result is unsigned(-1), which is what ffs(0) - 1 produced.
    // Could be optimised to use BSF instruction.
    static inline unsigned bit_shift_left (unsigned value) {
        value &= MASK;
        if (value == 0) {
            return static_cast<unsigned>(-1);
        }

        unsigned position = 0;
        while ((value & 1u) == 0) {
            value >>= 1;
            position += 1;
        }
        return position;
    }

    // Counts the zero bits of `value` that lie within MASK, i.e. the number of
    // values still permitted by a forbidden bitset.
    // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetNaive
    static inline unsigned count_zero_bits (unsigned value) {
        unsigned int c, mask = MASK;

        // We inverse the value and count 1 bits.
        value = ~value;
        for (c = 0; mask; value >>= 1, mask >>= 1)
        {
            c += (mask & 1) & (value & 1);
        }
        return c;
    }

public:
    // Starts with nothing forbidden and no fixed cells.
    Permutations () {
        memset(forbidden, 0, sizeof(forbidden));
        memset(fixed, 0, sizeof(fixed));
    }

    // Fills `permutations` with every assignment allowed by `forbidden`.
    // The list is left empty when no assignment exists.
    void generate () {
        // Reset the permutation generation state first, so that an early
        // return below never leaves results from a previous call behind.
        permutations.clear();
        memset(fixed, 0, sizeof(fixed));

        // Check if there are any possible permutations:
        // We need mask all initial possibilities, e.g for N = 9, b1111111110
        unsigned sum = MASK;
        for (unsigned i = 0; i < N; i += 1) {
            sum &= forbidden[i];

            // If it is not possible to put any value in a particular column,
            // no permutation can be generated.
            if ((forbidden[i] & MASK) == MASK)
                return;
        }

        // If sum is non-zero then there is at least one number
        // forbidden in all positions - e.g. no valid permutations.
        if (sum != 0) {
            return;
        }

        // used[v] is 1 once value v has been claimed by a fixed cell.
        unsigned used[N+1];
        memset(used, 0, sizeof(used));

        // Number of empty spots.
        count = 0;
        unsigned cost[N];

        for (unsigned j = 0; j < N; j += 1) {
            // Go backwards through the columns
            //unsigned i = (N - 1) - j;
            // Go forwards through the columns
            unsigned i = j;

            // Position zero is not available.
            // forbidden[i] |= 1;

            // Check if only one bit is zero in the given forbidden bitset.
            // This indicates that there is only one possible value for this position (e.g. fixed)
            unsigned mask = ~forbidden[i] & MASK;
            if (mask & (mask - 1)) {
                // More than one bit is set
                out[i] = 0;
                unsigned c = count_zero_bits(forbidden[i]);

                // --- Insertion Sort ---
                // Most inefficient order:
                //c = N - c;
                insertion_sort(i, c, order, cost, count);
                count += 1;

                // --- Unordered ---
                //order[count] = i;
                //cost[count] = c;
                //order[count++] = i;
            } else {
                // Only one bit isn't set, so this value isn't part of
                // permutation generation.
                // Find power of two bit position.
                unsigned value = bit_shift_left(~forbidden[i]);

                // Save the value here - it is never updated.
                out[i] = value;

                // Used for direct recursion.
                fixed[i] = 1;

                // An item cannot be used twice in the same row.
                if (used[value] == 1) {
                    // If we encounter this case, it means that there are no possible permutations.
                    return;
                }

                // Item is not added to available list.
                used[value] = 1;
            }
        }

#ifdef G_DEBUG
        row_count += 1;
        row_spaces += count;
#endif

        // There is only one permutation, because all items are fixed.
        if (count == 0) {
            next();
            return;
        }

        //quicksort(order, cost, 0, count);

        // Setup list of sequential items excluding used
        top = 0;
        for (unsigned i = 0; i < N; i += 1) {
            if (!used[i+1]) {
                available[top] = i+1;
                top += 1;
            }
        }

        recurse(0);
    }

    // Records the assignment currently held in `out` as one result.
    void next () {
        //ResultT result(out, out+N);
        permutations.push_back(out);
    }
};
// A 9x9 sudoku board together with the bookkeeping used by a row-by-row
// backtracking search.
//
// Cell values are 0 (empty) or 1..9. In the _forbidden, _fixed and _blocks
// bitsets, value v is represented by bit (v - 1).
class SudokuBoard {
public:
    typedef unsigned RowT[9];

    // Index _pieces[row][column]
    unsigned _pieces[9][9];

    // Index _forbidden[column]
    // Values already present somewhere in that column.
    unsigned _forbidden[9];

    // Index _fixed[row]
    // Values given for that row by load(); not updated by set_row()/set_piece().
    unsigned _fixed[9];

    // Index _blocks[row][column]
    // Values already present in that 3x3 block.
    unsigned _blocks[3][3];

    // When true, solve() prints every solution it finds to std::cout.
    bool print_solution;

    // clock() value recorded when the most recent solution was found.
    clock_t solution_time;

    // Creates an empty board. Every member is initialised, including _fixed
    // and solution_time, so solve() and callers reading solution_time never
    // see indeterminate values.
    SudokuBoard () {
        memset(_pieces, 0, sizeof(_pieces));
        memset(_forbidden, 0, sizeof(_forbidden));
        memset(_fixed, 0, sizeof(_fixed));
        memset(_blocks, 0, sizeof(_blocks));

        print_solution = true;
        solution_time = 0;
    }

    // Note the argument order: x is the column, y is the row.
    unsigned & piece_at (unsigned x, unsigned y) {
        return _pieces[y][x];
    }

    // Reads 81 whitespace-separated values (9 rows of 9) from `input`.
    // Throws std::range_error when a value is greater than 9.
    void load (std::iostream & input) {
        for (unsigned i = 0; i < 9; i += 1) {
            RowT values;
            unsigned row_mask = 0;

            // Load one row at a time.
            for (unsigned j = 0; j < 9; j += 1) {
                input >> values[j];

                if (/*values[j] < 0 || */values[j] > 9) {
                    throw std::range_error("Invalid sudoku value");
                }

                if (values[j] != 0)
                    set_bit(row_mask, values[j] - 1);
            }

            // Keep track of what values are fixed so we can optimise calculating forbidden positions.
            _fixed[i] = row_mask;
            set_row(i, values);
        }
    }

    // If you try to set a row to an incorrect permutation, this function
    // will create instability in the tree search because forbidden bits
    // will be tracked incorrectly. Thus, if you don't know if your
    // permutation is valid you should check it first. However, the
    // permutation generation algorithm in this case will only generate
    // valid configurations.
    void set_row (unsigned row, unsigned values[9]) {
        unsigned * block = _blocks[row / 3];

        // We have to assume that we can set this row with no conflicts.
        // The permutation generation algorithm should ensure that we
        // don't hit a forbidden permutation. However, if we do set a
        // a row to an invalid value, we will have serious problems.
        for (unsigned b = 0, bc = 0; b < 9; b += 3, bc += 1) {
            // We have to clear and set in blocks of 3 to avoid collisions..
            // e.g. if we clear and set bits 1 in the same block.
            for (unsigned column = b; column < (b + 3); column += 1) {
                // Unset the forbidden bit for the old value, if appropriate.
                if (_pieces[row][column] != values[column] && _pieces[row][column] > 0) {
                    clear_bit(_forbidden[column], _pieces[row][column] - 1);
                    clear_bit(block[bc], _pieces[row][column] - 1);
                }
            }

            for (unsigned column = b; column < (b + 3); column += 1) {
                // Set the forbidden bit for the new value, if appropriate.
                if (_pieces[row][column] != values[column] && values[column] > 0) {
                    // We put this check to ensure that algorithm is working correctly.
                    //assert(!is_bit_set(_forbidden[column], values[column] - 1));
                    //assert(!is_bit_set(block[bc], values[column] - 1));
                    set_bit(_forbidden[column], values[column] - 1);
                    set_bit(block[bc], values[column] - 1);
                    _pieces[row][column] = values[column];
                } else if (values[column] == 0) {
                    _pieces[row][column] = 0;
                }
            }
        }
    }

    // Try to set a piece in the given row, return false if it was definitely invalid.
    // Only the column and the 3x3 block are checked; the row is not.
    bool set_piece(unsigned column, unsigned row, unsigned value) {
        unsigned & block = _blocks[row / 3][column / 3];

        if (_pieces[row][column] != value) {
            if (value > 0) {
                // Column check
                if (is_bit_set(_forbidden[column], value - 1)) {
                    return false;
                }

                // Block check
                if (is_bit_set(block, value - 1)) {
                    return false;
                }
            }

            // So far so good - we know that we don't have a conflict.
            // Unset existing forbidden bits
            if (_pieces[row][column] > 0) {
                clear_bit(_forbidden[column], _pieces[row][column] - 1);
                clear_bit(block, _pieces[row][column] - 1);
            }

            // Update the value
            _pieces[row][column] = value;

            if (value > 0) {
                set_bit(_forbidden[column], value - 1);
                set_bit(block, value - 1);
            }
        }

        return true;
    }

    void debug();
    void print(std::ostream & output);

    // (index, number of empty cells) for a row or a band of three rows.
    typedef std::pair<unsigned, unsigned> CountT;

    // Orders by ascending number of empty cells.
    static bool sort_counts(CountT lhs, CountT rhs) {
        return lhs.second < rhs.second;
    }

    // Chooses the order in which rows are filled and starts the search.
    // Bands of three rows are visited from fewest to most empty cells, and the
    // rows inside each band are ordered the same way.
    void solve () {
        // We try to organise the rows in order of least sparse
        // to most sparse so culling is more effective.
        std::vector<CountT> rows;
        std::vector<CountT> blocks;

        for (unsigned i = 0; i < 9; i += 1) {
            CountT row_count(i, 0);

            for (unsigned j = 0; j < 9; j += 1) {
                if (_pieces[i][j] == 0) {
                    row_count.second += 1;
                }
            }

            rows.push_back(row_count);
        }

        for (unsigned i = 0; i < 3; i += 1) {
            CountT block_count(i, 0);

            for (unsigned j = 0; j < 3; j += 1) {
                block_count.second += rows[i * 3 + j].second;
            }

            blocks.push_back(block_count);
        }

        //std::sort(rows.begin(), rows.end(), sort_counts);
        std::sort(blocks.begin(), blocks.end(), sort_counts);

        std::vector<int> order;
        order.reserve(9);

        for (std::size_t i = 0; i < 3; i += 1) {
            std::size_t start = blocks[i].first * 3;
            std::sort(rows.begin() + start, rows.begin() + start + 3, sort_counts);

            for (std::size_t j = start; j < start + 3; j += 1) {
                //std::cerr << "Order " << i << " is row " << count[i].first << " with count " << count[i].second << std::endl;
                order.push_back((rows.begin() + j)->first);
            }
        }

        solve(0, order);
    }

    // Fills row order[index] with each candidate permutation in turn and
    // recurses into the next row. At index 8 the board is complete: it is
    // printed when print_solution is set and solution_time records clock().
    // The search continues after a solution, so every solution is visited.
    // The row and its bookkeeping are restored before returning.
    void solve (unsigned index, const std::vector<int> & order) {
#ifdef G_DEBUG
        solution_count += 1;
#endif
        unsigned row = order[index];
        unsigned size = 0;
        unsigned * block = _blocks[row / 3];

        Permutations<9> p;

        // Update forbidden columns
        memcpy(p.forbidden, _forbidden, sizeof(p.forbidden));

        // Update forbidden blocks
        for (unsigned b = 0; b < 3; b += 1) {
            p.forbidden[(b*3)] |= block[b];
            p.forbidden[(b*3)+1] |= block[b];
            p.forbidden[(b*3)+2] |= block[b];
        }

        for (unsigned i = 0; i < 9; i += 1) {
            unsigned piece = piece_at(i, row);

            if (piece == 0) {
                size += 1;

                // For an empty space, we cannot put something that is occupied by someone else in the row.
                // This is primarily based on the initial set of fixed positions.
                // This is the inverse of the fixed positions, e.g. if a piece is fixed in one place, it cannot
                // occur anywhere else.
                p.forbidden[i] |= _fixed[row];

                // Permutations uses bit v for value v, the board uses bit (v - 1).
                p.forbidden[i] <<= 1;
            } else {
                // We set that there is only one possible option for this permutation
                p.forbidden[i] = ~(1 << piece);
            }
        }

        p.generate();

#ifdef G_DEBUG
        if (p.permutations.size() > 0) {
            std::cerr << "Row " << row << " : generated " << p.permutations.size() << " permutations" << std::endl;
            //for (std::size_t i = 0; i < p.permutations.size(); i += 1) {
            //	p.permutations[i].print(std::cout);
            //	std::cout << std::endl;
            //}
        }
#endif

        // We failed to generate any useful candidates.
        if (p.permutations.size() == 0) {
#ifdef G_DEBUG
            std::cout << "Row " << row << " : no solutions (" << permutation_count << ")" << std::endl;
#endif
            return;
        }

#define O_ROWCOPY
        // Save a copy of the row so we can jump up again.
        RowT originalRow, originalForbidden;
        memcpy(originalRow, _pieces[row], sizeof(RowT));
#ifdef O_ROWCOPY
        memcpy(originalForbidden, _forbidden, sizeof(_forbidden));
        unsigned originalBlocks[3];
        memcpy(originalBlocks, block, sizeof(originalBlocks));
#endif

        for (std::size_t i = 0; i < p.permutations.size(); i += 1) {
#ifdef G_DEBUG
            std::cout << "Row " << row << " : ";
            p.permutations[i].print(std::cout);
            std::cout << std::endl;
#endif
#ifdef G_GRAPH
            std::stringstream text;
            text << "(row:" << row << " perm:" << i << ")" << std::endl;
            p.permutations[i].print(text);
            g_grapher.enter(text.str());
#endif
            set_row(row, p.permutations[i]);

            if (index < 8) {
                solve(index+1, order);
            } else {
                if (print_solution) {
                    print(std::cout);
                }
#ifdef G_GRAPH
                std::stringstream solution;
                print(solution);
                g_grapher.enter(solution.str());
                g_grapher.exit();
#endif
                solution_time = clock();
            }
#ifdef G_GRAPH
            g_grapher.exit();
#endif
        }

        // If we are going back up one step we need to restore original row.
#ifdef O_ROWCOPY
        memcpy(_pieces[row], originalRow, sizeof(RowT));
        memcpy(_forbidden, originalForbidden, sizeof(_forbidden));
        memcpy(block, originalBlocks, sizeof(originalBlocks));
#else
        set_row(row, originalRow);
#endif
    }
};
// Prints the board to the standard error stream using print().
void SudokuBoard::debug() {
    std::ostream & sink = std::cerr;
    print(sink);
}
// Writes the board as 9 lines of 9 values each; values on a line are
// separated by single spaces and every line ends with std::endl.
void SudokuBoard::print(std::ostream & output) {
    for (std::size_t row = 0; row < 9; ++row) {
        // The first value has no leading separator.
        output << piece_at(0, row);
        for (std::size_t column = 1; column < 9; ++column) {
            output << " " << piece_at(column, row);
        }
        output << std::endl;
    }
}
// Benchmark helper: reads a case count from the file at `path`, followed by
// that many boards, and solves each one in turn, printing "Case #i" first.
// When the file cannot be opened or the count cannot be read, no case is run.
void profile_sudoku (std::string path) {
    // Open for reading only; the default fstream mode also requests write
    // access, which fails on read-only files.
    std::fstream fp(path.c_str(), std::ios::in);

    // Initialised so a failed read leaves a well-defined case count of zero.
    unsigned count = 0;
    fp >> count;

    for (unsigned i = 0; i < count; i += 1) {
        std::cout << "Case #" << i << std::endl;

        SudokuBoard b;
        b.load(fp);

        //clock_t start_time = clock();
        b.solve();
        //clock_t end_time = clock();
        //double total_time = double(end_time - start_time) / CLOCKS_PER_SEC;
        //std::cout << "Time: " << total_time << "s" << std::endl;
    }

    std::cout << std::flush;
}
// Loads a single board from the file at `path`, prints it, solves it and
// reports timings. "Solution Time" is the time until the last solution found
// (zero when there is none); "Total Time" covers the whole search.
void solve_sudoku (std::string path) {
    // Open for reading only; the default fstream mode also requests write
    // access, which fails on read-only files.
    std::fstream fp(path.c_str(), std::ios::in);

    // Without this check an unreadable file would load as an all-empty board.
    if (!fp) {
        std::cerr << "Could not open sudoku file: " << path << std::endl;
        return;
    }

    SudokuBoard b;

    try {
        b.load(fp);
    } catch (const std::range_error &) {
        std::cerr << "Invalids sudoku file: " << path << std::endl;
        return;
    }

    std::cout << "Initial Configuration" << std::endl;
    b.print(std::cout);

    clock_t start_time = clock();
    // Seed the solution timestamp so it is well defined when no solution is found.
    b.solution_time = start_time;
    b.solve();
    clock_t end_time = clock();

    double solution_time = double(b.solution_time - start_time) / CLOCKS_PER_SEC;
    double total_time = double(end_time - start_time) / CLOCKS_PER_SEC;

    std::cout << "Solution Time: " << solution_time << "s" << std::endl;
    std::cout << "Total Time: " << total_time << "s" << std::endl;
}
// Entry point. Usage: sudoku-solver input-file [-t]
//   -t  timing/benchmark mode: run profile_sudoku() REPEATS times and print
//       the average time per run.
// Without -t the file is solved once with solve_sudoku(). When no input file
// is given (including when only -t is given) the usage text is printed.
int main (int argc, char ** argv) {
#ifdef G_GRAPH
    g_grapher.start();
    g_grapher.enter("Sudoku");
#endif

    const char * input_file = NULL;
    bool timing_mode = false;

    // Any argument other than -t is taken as the input file; the last one wins.
    for (int i = 1; i < argc; i += 1) {
        std::string arg = argv[i];

        if (arg == "-t") {
            timing_mode = true;
        } else {
            input_file = argv[i];
        }
    }

    if (input_file == NULL) {
        // Also reached for "-t" alone: passing a null pointer on as a path
        // would construct a std::string from NULL, which is undefined behaviour.
        std::cerr << "Usage: " << "sudoku-solver" << " input-file [-t]" << std::endl;
        std::cerr << "\t-t Timing/Benchmark mode" << std::endl;
    } else if (timing_mode) {
        clock_t start_time = clock();

        const unsigned REPEATS = 2;
        for (unsigned repeat = 0; repeat < REPEATS; repeat += 1)
            profile_sudoku(input_file);

        clock_t end_time = clock();
        double totalTime = double(end_time - start_time) / CLOCKS_PER_SEC;
        std::cout << "Total Time: " << totalTime / REPEATS << "s" << std::endl;
    } else {
        solve_sudoku(input_file);
    }

    std::cout << std::flush;

#ifdef G_DEBUG
    std::cout << "Perm count: " << permutation_count << " Solve count: " << solution_count << std::endl;
    std::cout << "Row count: " << row_count << " Row spaces: " << row_spaces << std::endl;
#endif

#ifdef G_GRAPH
    g_grapher.exit();
    g_grapher.end();

    std::ofstream graph_file("graph.dot");
    g_grapher.save(graph_file);

    // Generate a PDF graph
    system("/opt/local/bin/dot -Tpdf:quartz -ograph.pdf graph.dot");
    system("open graph.pdf");
#endif

    return 0;
}