Solving Sudoku Efficiently

Author: Samuel Williams When: Thursday, 26 May 2011

Sudoku is an interesting problem because ultimately it has a mechanical solution, and yet finding this solution still presents a difficult challenge. One way to solve the problem is to systematically generate all possible permutations and discard those that violate the rules. You can do this one row at a time, and if you finally fill in all the rows, you've found a valid solution.

However, this method, brute force, can be slow. There are a variety of optimisations you can apply to the above problem to improve the efficiency:

There are many more optimisations, but here is my efficient implementation which does all of the above and more.

//
//  main.cpp
//  COSC329 Sudoku
//
//  Created by Samuel Williams on 26/05/11.
//  Copyright 2011 Samuel Williams. All rights reserved.
//

#include <time.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

//#define G_GRAPH
//#define G_DEBUG

// Inserts `item` (with sort key `cost`) into the parallel arrays `v` / `c`,
// whose first `offset` entries are already ordered by ascending cost.
// Entries with a strictly greater cost are moved one slot to the right, so
// items with an equal cost keep their insertion order. The caller must make
// sure index `offset` is writable in both arrays.
template <typename ItemT, typename CostT, typename ArrayT, typename CostArrayT>
void insertion_sort(ItemT item, CostT cost, ArrayT & v, CostArrayT &c, std::size_t offset) {
	std::size_t slot = offset;

	for (; slot > 0; --slot) {
		const std::size_t previous = slot - 1;

		// Stop at the first entry that is not more expensive than the new one.
		if (!(cost < c[previous])) {
			break;
		}

		c[slot] = c[previous];
		v[slot] = v[previous];
	}

	v[slot] = item;
	c[slot] = cost;
}

// Turns on the bit at index `position` (0 = least significant) in `bits`.
template <typename AnyT>
inline void set_bit(AnyT & bits, unsigned position) {
	const AnyT selector = AnyT(1) << position;
	bits |= selector;
}

// Turns off the bit at index `position` (0 = least significant) in `bits`.
template <typename AnyT>
inline void clear_bit(AnyT & bits, unsigned position) {
	const AnyT selector = AnyT(1) << position;
	bits &= ~selector;
}

// Reports whether the bit at index `position` (0 = least significant) is on.
template <typename AnyT>
inline bool is_bit_set(const AnyT & bits, unsigned position) {
	const AnyT selector = AnyT(1) << position;
	return static_cast<bool>(bits & selector);
}

// Search statistics, only compiled in when G_DEBUG is defined.
#ifdef G_DEBUG
// Incremented on every call to Permutations::permute / Permutations::bottom.
static long long permutation_count = 0;
// Incremented on every call to SudokuBoard::solve(index, order).
static long long solution_count = 0;

// Statistics for row spaces
// row_spaces accumulates the number of non-fixed positions per generated row;
// row_count is the number of rows that reached permutation generation.
static long long row_spaces = 0;
static long long row_count = 0;
#endif

#ifdef G_GRAPH
// Records the search tree as Graphviz "dot" text (G_GRAPH builds only).
// start()/end() write the digraph header and footer, enter()/exit() push and
// pop one node of the tree, and save() writes the accumulated text to a stream.
class Grapher {
		// Accumulated dot source.
		std::stringstream _buffer;
		// Names of the nodes on the path from the root to the current node.
		std::vector<std::string> _tree;
		// Number of nodes created so far; used to build unique node names.
		unsigned _count;

	public:
		// Resets the node counter and writes the digraph header.
		void start ();
		// Writes the closing brace of the digraph.
		void end ();

		// Adds a node labelled `text` below the current node and descends into it.
		void enter(std::string text);
		// Returns to the parent of the current node.
		void exit();

		// Writes everything recorded so far to the given stream.
		void save (std::ostream &);
};

// Replaces every non-overlapping occurrence of `oldStr` in `str` with
// `newStr`, scanning left to right. Text inserted by a replacement is never
// rescanned, so `newStr` may itself contain `oldStr`.
// An empty `oldStr` leaves `str` untouched: it would otherwise match at every
// position and the loop would never terminate.
void replace(std::string& str, const std::string& oldStr, const std::string& newStr)
{
  if (oldStr.empty())
    return;

  std::size_t pos = 0;
  while((pos = str.find(oldStr, pos)) != std::string::npos)
  {
     str.replace(pos, oldStr.length(), newStr);
     // Skip past the inserted text so it is not matched again.
     pos += newStr.length();
  }
}

// Begins a new graph: node numbering restarts at zero and the digraph header
// (top-to-bottom layout, fixed size and rank separation) is appended to the
// buffer.
void Grapher::start () {
	_count = 0;

	_buffer << "digraph { rankdir=TB; size=\"8,5\";" << std::endl
	        << "\t ranksep=\"1.0 equally\";" << std::endl;
}

// Terminates the digraph opened by start() by appending its closing brace.
void Grapher::end () {
	_buffer << "}";
	_buffer << std::endl;
}

// Creates a new rectangular node labelled `text` and makes it the current
// node. If a node is already open, an edge from that node to the new one is
// emitted; otherwise the new node is declared on its own as a root.
void Grapher::enter(std::string text) {
	// Nodes are named n0, n1, n2, ... in creation order.
	std::stringstream name_builder;
	name_builder << "n" << _count++;
	const std::string node_name = name_builder.str();

	// Real newlines are written as the two-character sequence \n in the label.
	replace(text, "\n", "\\n");

	_buffer << "\tnode [shape=rectangle,label=\"" << text << "\"]; ";

	if (!_tree.empty()) {
		_buffer << _tree.back() << " -> ";
	}

	_buffer << node_name << ";" << std::endl;

	_tree.push_back(node_name);
}

// Leaves the current node, making its parent the current node again.
// An exit() without a matching enter() is ignored; calling pop_back() on an
// empty vector would be undefined behaviour.
void Grapher::exit() {
	if (!_tree.empty()) {
		_tree.pop_back();
	}
}


// Writes all dot text recorded so far to `out`. The buffer itself is left
// unchanged.
void Grapher::save (std::ostream & out) {
	const std::string recorded = _buffer.str();
	out << recorded;
}

// Single graph recorder shared by the permutation generator, the solver and
// main() in G_GRAPH builds.
Grapher g_grapher;
#endif

// Generates every arrangement of the values 1..N over N positions that
// respects a per-position set of forbidden values.
//
// Usage: fill in `forbidden`, call generate(), then read `permutations`.
//
// forbidden[i] is a bitset in which bit v (1 <= v <= N) means "value v may not
// be placed at position i". Bit 0 is not a legal value and is ignored. A
// position with exactly one permitted value is treated as fixed: it is written
// once and excluded from the search.
template <unsigned N>
class Permutations {
public:
	// Bits 1..N set, e.g. for N = 9: b1111111110.
	static const unsigned MASK = ((1 << N) - 1) << 1;
	unsigned forbidden[N];

	// One complete arrangement; data[i] is the value at position i.
	struct Result {
		unsigned data[N];
		
		// Leaves `data` uninitialised.
		Result () {
		}
		
		// Copies N values from `_data`. Deliberately not explicit: next()
		// relies on the implicit conversion from the working array.
		Result (const unsigned _data[N]) {
			std::memcpy(data, _data, sizeof(data));
		}
		
		operator unsigned * () {
			return data;
		}
		
		// Prints the values separated by single spaces, without a newline.
		void print (std::ostream & out) const {
			out << data[0];
			
			for (std::size_t i = 1; i < N; i += 1) {
				out << ' ' << data[i];
			}
		}
	};
	
	// Output of generate(): every valid arrangement that was found.
	std::vector<Result> permutations;
	
protected:
	// out:       the arrangement being built (indexed by position).
	// order:     the non-fixed positions, in the order they are filled.
	// count:     number of entries used in `order`.
	// available: values not yet placed; entries [0, top) are live on entry to
	//            recurse(), which then treats `top` as the index of the last one.
	// fixed:     fixed[i] == 1 when position i has a single permitted value.
	unsigned out[N], order[N], count;
	unsigned available[N], fixed[N];
	unsigned top;
	
	// True when `value` may be placed at position `k`.
	bool okay (unsigned k, unsigned value) const {
		return !(forbidden[k] & (1u << value));
	}
	
	// This permutation generation doesn't generate lexicographic order.
	// However, it might be possible if we consider a pointer into available,
	// and rather than take items from the top, take items from the front by
	// incrementing pointer.
	//
	// Fills the position order[k] and everything after it. When no values are
	// left the current arrangement is complete and is recorded.
	void recurse(unsigned k) {
#ifdef G_GRAPH
		std::stringstream current;
#ifdef G_DEBUG
		// permutation_count only exists in G_DEBUG builds.
		current << "(" << permutation_count << ") ";
#endif
		Result intermediate_result(out);
		intermediate_result.print(current);
		g_grapher.enter(current.str());
#endif

		if (top != 0) {
			// We reduce the number of available items by 1 and recurse.
			// From here on `top` is the index of the last available item.
			top--;
#ifdef G_GRAPH
				permute(k);
#else
			if (top != 1) {
				permute(k);
			} else {
				// Exactly two values left: try both arrangements directly.
				bottom(k);
			}
#endif
			top++;
		} else {
			next();
		}
		
#ifdef G_GRAPH		
		g_grapher.exit();
#endif
	}
	
	// Handles the last two open positions (order[k], order[k+1]) with the two
	// remaining values available[0] and available[1], recording each of the
	// two possible assignments that is permitted.
	void bottom(unsigned k) {
#ifdef G_DEBUG
		permutation_count += 1;
#endif
		
		unsigned a = order[k];
		unsigned b = order[k+1];
		
		if (okay(a, available[0]) && okay(b, available[1])) {
			out[a] = available[0];
			out[b] = available[1];
			next();
		}
		
		if (okay(b, available[0]) && okay(a, available[1])) {
			out[b] = available[0];
			out[a] = available[1];
			next();
		}
	}
	
	// Tries every available value at position order[k] and recurses for the
	// remaining positions. `top` is the index of the last available value.
	void permute(unsigned k) {
#ifdef G_DEBUG
		permutation_count += 1;
#endif
		
		unsigned p = order[k];

		// We need to handle the last available item specifically,
		// so this loop handles the items in front of it.
		for (unsigned i = 0; i < top; i += 1) {
			if (okay(p, available[i])) {
				// We take available[i] and put it in the permutation.
				out[p] = available[i];

				// We take the top value and replace available[i], so that
				// [0, top) again holds exactly the unused values.
				available[i] = available[top];

				recurse(k+1);

				// We put the item back into available[i].
				available[i] = out[p];

#ifdef G_GRAPH
				// Improve debugging
				out[p] = 0;
#endif
			}
		}

		// The last item needs no swap: dropping it leaves [0, top) untouched.
		if (okay(p, available[top])) {
			out[p] = available[top];

			recurse(k+1);

#ifdef G_GRAPH
			// Improve debugging
			out[p] = 0;
#endif
		}
	}
	
	// Returns the index of the lowest set bit of `value` within MASK.
	// Implemented portably rather than with POSIX ffs(), which needs
	// <strings.h> and is not available on every platform.
	// generate() only calls this with at least one bit set inside MASK; if
	// none is set the result is 0.
	static inline unsigned bit_shift_left (unsigned value) {
		value &= MASK;
		
		unsigned position = 0;
		
		while (value != 0 && (value & 1u) == 0) {
			value >>= 1;
			position += 1;
		}
		
		return position;
	}
	
	// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetNaive
	// Counts the zero bits of `value` inside MASK, i.e. the number of values
	// that are still permitted by a forbidden bitset.
	static inline unsigned count_zero_bits (unsigned value) {
		unsigned int c, mask = MASK;
		
		// We inverse the value and count 1 bits.
		value = ~value;

		for (c = 0; mask; value >>= 1, mask >>= 1)
		{
			c += (mask & 1) & (value & 1);
		}
		
		return c;
	}
	
public:
	// Starts with nothing forbidden and nothing fixed.
	Permutations () {
		std::memset(forbidden, 0, sizeof(forbidden));
		std::memset(fixed, 0, sizeof(fixed));
	}
	
	// Fills `permutations` with every arrangement permitted by `forbidden`.
	// `permutations` is left empty when no arrangement exists.
	void generate () {
		// Reset the result first, so that every early return below reports
		// "no permutations" instead of leaving results from an earlier call.
		permutations.clear();
		
		// Check if there are any possible permutations:
		
		// We need mask all initial possibilities, e.g for N = 9, b1111111110
		unsigned sum = MASK;
		
		for (unsigned i = 0; i < N; i += 1) {
			sum &= forbidden[i];
			
			// If it is not possible to put any value in a particular column,
			// no permutation can be generated.
			if ((forbidden[i] & MASK) == MASK)
				return;
		}
		
		// If sum is non-zero then there is at least one number
		// forbidden in all positions - e.g. no valid permutations.
		if (sum != 0) {
			return;
		}

		// used[v] == 1 once value v has been claimed by a fixed position.
		unsigned used[N+1];
		std::memset(used, 0, sizeof(used));
		
		// Number of empty spots.
		count = 0;
		unsigned cost[N];
				
		for (unsigned i = 0; i < N; i += 1) {
			// Check if only one bit is zero in the given forbidden bitset.
			// This indicates that there is only one possible value for this position (e.g. fixed)
			unsigned mask = ~forbidden[i] & MASK;
			if (mask & (mask - 1)) {
				// More than one bit is set
				out[i] = 0;
				
				unsigned c = count_zero_bits(forbidden[i]);
				
				// Keep `order` sorted by the number of permitted values, so
				// the most constrained positions are filled first.
				insertion_sort(i, c, order, cost, count);
				count += 1;
			} else {
				// Only one bit isn't set, so this value isn't part of
				// permutation generation.
				
				// Find power of two bit position.
				unsigned value = bit_shift_left(~forbidden[i]);
				
				// Save the value here - it is never updated.
				out[i] = value;

				// Used for direct recursion.
				fixed[i] = 1;

				// An item cannot be used twice in the same row.
				if (used[value] == 1) {
					// If we encounter this case, it means that there are no possible permutations.
					return;
				}
				
				// Item is not added to available list.
				used[value] = 1;
			}
		}
		
#ifdef G_DEBUG
		row_count += 1;
		row_spaces += count;
#endif
		
		// There is only one permutation, because all items are fixed.
		if (count == 0) {
			next();
			return;
		}
		
		// Setup list of sequential items excluding used
		top = 0;
		for (unsigned i = 0; i < N; i += 1) {
			if (!used[i+1]) {
				available[top] = i+1;
				top += 1;
			}
		}
		
		recurse(0);
	}
	
	// Records the arrangement currently held in `out`.
	void next () {
		permutations.push_back(out);
	}
};

// A 9x9 sudoku board together with the bookkeeping needed to solve it one row
// at a time.
//
// Values are 1..9, with 0 meaning "empty". All bitsets in this class use bit
// (value - 1) for a value; solve() shifts them left by one when handing them
// to Permutations, which uses bit `value`.
class SudokuBoard {
	public:
		typedef unsigned RowT[9];
		
		// Index _pieces[row][column]
		unsigned _pieces[9][9];
		
		// Index _forbidden[column]: values already present in that column.
		unsigned _forbidden[9];
		
		// Index _fixed[row]: values given for that row by load().
		unsigned _fixed[9];
		
		// Index _blocks[row][column]: values already present in that 3x3 block.
		unsigned _blocks[3][3];
		
		// When true, every solution found is printed to std::cout.
		bool print_solution;
		// clock() value taken when the most recent solution was found;
		// stays 0 until a solution is found.
		clock_t solution_time;
		
		// Creates an empty board. Every member is initialised, so the board
		// can be solved or inspected even if load() is never called.
		SudokuBoard () {
			std::memset(_pieces, 0, sizeof(_pieces));
			std::memset(_forbidden, 0, sizeof(_forbidden));
			std::memset(_fixed, 0, sizeof(_fixed));
			std::memset(_blocks, 0, sizeof(_blocks));
			print_solution = true;
			solution_time = 0;
		}
		
		// Note the argument order: x is the column, y is the row.
		unsigned & piece_at (unsigned x, unsigned y) {
			return _pieces[y][x];
		}
		
		// Reads 81 whitespace-separated numbers, row by row.
		// Throws std::range_error if a number is greater than 9.
		// A number that cannot be read is treated as 0 (empty).
		void load (std::iostream & input) {
			for (unsigned i = 0; i < 9; i += 1) {
				RowT values;
				
				unsigned row_mask = 0;
				
				// Load one row at a time.
				for (unsigned j = 0; j < 9; j += 1) {
					// Give the cell a defined value in case extraction fails
					// and leaves it untouched.
					values[j] = 0;
					input >> values[j];
					
					if (values[j] > 9) {
						throw std::range_error("Invalid sudoku value");
					}
					
					if (values[j] != 0)
						set_bit(row_mask, values[j] - 1);
				}
				
				// Keep track of what values are fixed so we can optimise calculating forbidden positions.
				_fixed[i] = row_mask;
				set_row(i, values);
			}
		}
		
		// If you try to set a row to an incorrect permutation, this function
		// will create instability in the tree search because forbidden bits
		// will be tracked incorrectly. Thus, if you don't know if your
		// permutation is valid you should check it first. However, the
		// permutation generation algorithm in this case will only generate
		// valid configurations.
		void set_row (unsigned row, unsigned values[9]) {
			unsigned * block = _blocks[row / 3];
			
			// We have to assume that we can set this row with no conflicts.
			// The permutation generation algorithm should ensure that we
			// don't hit a forbidden permutation. However, if we do set a
			// a row to an invalid value, we will have serious problems.
			for (unsigned b = 0, bc = 0; b < 9; b += 3, bc += 1) {
				
				// We have to clear and set in blocks of 3 to avoid collisions..
				// e.g. if we clear and set bits 1 in the same block.
				for (unsigned column = b; column < (b + 3); column += 1) {
					// Unset the forbidden bit for the old value, if appropriate.
					if (_pieces[row][column] != values[column] && _pieces[row][column] > 0) {
						clear_bit(_forbidden[column], _pieces[row][column] - 1);
						clear_bit(block[bc], _pieces[row][column] - 1);
					}
				}
				
				for (unsigned column = b; column < (b + 3); column += 1) {
					// Set the forbidden bit for the new value, if appropriate.
					if (_pieces[row][column] != values[column] && values[column] > 0) {
						set_bit(_forbidden[column], values[column] - 1);
						set_bit(block[bc], values[column] - 1);
						
						_pieces[row][column] = values[column];
					} else if (values[column] == 0) {
						_pieces[row][column] = 0;
					}
				}
			}
		}
		
		// Try to set a piece in the given row, return false if it was definitely invalid.
		// Only the column and the 3x3 block are checked, not the row.
		// A value of 0 clears the cell.
		bool set_piece(unsigned column, unsigned row, unsigned value) {
			unsigned & block = _blocks[row / 3][column / 3];
			
			if (_pieces[row][column] != value) {
				if (value > 0) {
					// Column check
					if (is_bit_set(_forbidden[column], value - 1)) {
						return false;
					}
					
					// Block check
					if (is_bit_set(block, value - 1)) {
						return false;
					}
				}
				
				// So far so good - we know that we don't have a conflict.
				
				// Unset existing forbidden bits
				if (_pieces[row][column] > 0) {
					clear_bit(_forbidden[column], _pieces[row][column] - 1);
					clear_bit(block, _pieces[row][column] - 1);
				}
				
				// Update the value
				_pieces[row][column] = value;
				
				if (value > 0) {
					set_bit(_forbidden[column], value - 1);
					set_bit(block, value - 1);
				}
			}
			
			return true;
		}
		
		// Prints the board to std::cerr.
		void debug();
		// Prints the board, one row per line, values separated by spaces.
		void print(std::ostream & output);
		
		// (index, number of empty cells)
		typedef std::pair<unsigned, unsigned> CountT;
		
		// Orders by ascending number of empty cells.
		static bool sort_counts(CountT lhs, CountT rhs) {
			return lhs.second < rhs.second;
		}
		
		// Searches for all solutions of the current board. Each one is printed
		// if print_solution is set, and solution_time is updated.
		void solve () {
			// We try to organise the rows in order of least sparse
			// to most sparse so culling is more effective.
			std::vector<CountT> rows;
			std::vector<CountT> blocks;
			
			// Count the empty cells of every row.
			for (unsigned i = 0; i < 9; i += 1) {
				CountT row_count(i, 0);
				
				for (unsigned j = 0; j < 9; j += 1) {
					if (_pieces[i][j] == 0) {
						row_count.second += 1;
					}
				}
				
				rows.push_back(row_count);
			}
			
			// Count the empty cells of every band of three rows.
			for (unsigned i = 0; i < 3; i += 1) {
				CountT block_count(i, 0);
				
				for (unsigned j = 0; j < 3; j += 1) {
					block_count.second += rows[i * 3 + j].second;
				}
				
				blocks.push_back(block_count);
			}
			
			std::sort(blocks.begin(), blocks.end(), sort_counts);
			
			std::vector<int> order;
			order.reserve(9);
			
			// Fullest band first; within a band, fullest row first. Rows of
			// one band stay together so block constraints apply early.
			for (std::size_t i = 0; i < 3; i += 1) {
				std::size_t start = blocks[i].first * 3;
				std::sort(rows.begin() + start, rows.begin() + start + 3, sort_counts);
				
				for (std::size_t j = start; j < start + 3; j += 1) {
					order.push_back((rows.begin() + j)->first);
				}
			}
						
			solve(0, order);
		}
		
		// Fills row order[index] with every permutation that is consistent
		// with the board so far and recurses into the next row. The row, the
		// column bitsets and the band's block bitsets are restored before
		// returning.
		void solve (unsigned index, const std::vector<int> & order) {
#ifdef G_DEBUG
			solution_count += 1;
#endif
			unsigned row = order[index];
			unsigned size = 0;
			unsigned * block = _blocks[row / 3];
			
			Permutations<9> p;
			
			// Update forbidden columns
			std::memcpy(p.forbidden, _forbidden, sizeof(p.forbidden));
			
			// Update forbidden blocks
			for (unsigned b = 0; b < 3; b += 1) {
				p.forbidden[(b*3)] |= block[b];
				p.forbidden[(b*3)+1] |= block[b];
				p.forbidden[(b*3)+2] |= block[b];
			}
			
			for (unsigned i = 0; i < 9; i += 1) {
				unsigned piece = piece_at(i, row);
				if (piece == 0) {
					size += 1;
					
					// For an empty space, we cannot put something that is occupied by someone else in the row.
					// This is primarily based on the initial set of fixed positions.
					// This is the inverse of the fixed positions, e.g. if a piece is fixed in one place, it cannot
					// occur anywhere else.
					p.forbidden[i] |= _fixed[row];
					
					// Convert from bit (value - 1) to bit (value), as used by Permutations.
					p.forbidden[i] <<= 1;
				} else {
					// We set that there is only one possible option for this permutation
					p.forbidden[i] = ~(1 << piece);
				}
			}
			
			p.generate();

#ifdef G_DEBUG
			if (p.permutations.size() > 0) {
				std::cerr << "Row " << row << " : generated " << p.permutations.size() << " permutations" << std::endl;
			}
#endif
			
			// We failed to generate any useful candidates.
			if (p.permutations.size() == 0) {
#ifdef G_DEBUG
				std::cout << "Row " << row << " : no solutions (" << permutation_count << ")" << std::endl;
#endif
				return;
			}

// When defined, state is restored from saved copies; otherwise via set_row().
#define O_ROWCOPY

			// Save a copy of the row so we can jump up again.
			RowT originalRow, originalForbidden;
			std::memcpy(originalRow, _pieces[row], sizeof(RowT));

#ifdef O_ROWCOPY
			std::memcpy(originalForbidden, _forbidden, sizeof(_forbidden));
			unsigned originalBlocks[3];
			std::memcpy(originalBlocks, block, sizeof(originalBlocks));
#endif
			
			for (std::size_t i = 0; i < p.permutations.size(); i += 1) {
#ifdef G_DEBUG
				std::cout << "Row " << row << " : ";
				p.permutations[i].print(std::cout);
				std::cout << std::endl;
#endif
		
#ifdef G_GRAPH
				std::stringstream text;
				text << "(row:" << row << " perm:" << i << ")" << std::endl;
				p.permutations[i].print(text);
				g_grapher.enter(text.str());
#endif
				
				set_row(row, p.permutations[i]);
				if (index < 8) {
					solve(index+1, order);
				} else {
					// All nine rows are filled: this is a solution.
					if (print_solution) {
						print(std::cout);
					}
					
#ifdef G_GRAPH
					std::stringstream solution;
					print(solution);
					g_grapher.enter(solution.str());
					g_grapher.exit();
#endif
					
					solution_time = clock();
				}

#ifdef G_GRAPH				
				g_grapher.exit();
#endif
			}
			
			// If we are going back up one step we need to restore original row.
#ifdef O_ROWCOPY
			std::memcpy(_pieces[row], originalRow, sizeof(RowT));
			std::memcpy(_forbidden, originalForbidden, sizeof(_forbidden));
			std::memcpy(block, originalBlocks, sizeof(originalBlocks));
#else
			set_row(row, originalRow);
#endif
		}
};

void SudokuBoard::debug() {
	print(std::cerr);
}

void SudokuBoard::print(std::ostream & output) {
	for (std::size_t y = 0; y < 9; y += 1) {
		for (std::size_t x = 0; x < 9; x += 1) {
			if (x != 0)
				output << " ";
				
			output << piece_at(x, y);
		}
		
		output << std::endl;
	}	
}

// Benchmark driver: `path` names a file that starts with the number of boards
// and is followed by that many boards of 81 numbers each. Every board is
// loaded and solved in turn.
// Problems with the file are reported on std::cerr and end the run early.
void profile_sudoku (std::string path) {
	// Open for reading only, so read-only files can be used as well.
	std::fstream fp(path.c_str(), std::ios::in);

	if (!fp) {
		std::cerr << "Unable to open sudoku file: " << path << std::endl;
		return;
	}

	unsigned count = 0;

	if (!(fp >> count)) {
		std::cerr << "Missing case count in sudoku file: " << path << std::endl;
		return;
	}

	for (unsigned i = 0; i < count; i += 1) {
		std::cout << "Case #" << i << std::endl;
		
		SudokuBoard b;

		try {
			b.load(fp);
		} catch (const std::range_error &) {
			// The stream position is now in the middle of a board, so the
			// remaining cases cannot be read reliably.
			std::cerr << "Invalid sudoku file: " << path << std::endl;
			break;
		}
		
		b.solve();
	}
	
	std::cout << std::flush;
}

// Loads a single board from the file `path`, prints it, solves it and reports
// the time until the last solution was found as well as the total search time.
// Problems with the file are reported on std::cerr.
void solve_sudoku (std::string path) {
	// Open for reading only, so read-only files can be used as well.
	std::fstream fp(path.c_str(), std::ios::in);

	if (!fp) {
		std::cerr << "Unable to open sudoku file: " << path << std::endl;
		return;
	}
	
	SudokuBoard b;
	try {
		b.load(fp);
	} catch (const std::range_error &) {
		std::cerr << "Invalid sudoku file: " << path << std::endl;
		return;
	}

	std::cout << "Initial Configuration" << std::endl;
	b.print(std::cout);
	
	clock_t start_time = clock();
	// Give solution_time a defined value in case the board has no solution;
	// the reported solution time is then zero.
	b.solution_time = start_time;
	b.solve();
	clock_t end_time = clock();
	
	double solution_time = double(b.solution_time - start_time) / CLOCKS_PER_SEC;
	double total_time = double(end_time - start_time) / CLOCKS_PER_SEC;
	
	std::cout << "Solution Time: " << solution_time << "s" << std::endl;
	std::cout << "Total Time: " << total_time << "s" << std::endl;
}

// Entry point.
//   sudoku-solver input-file       solve a single board and print timings
//   sudoku-solver input-file -t    benchmark mode, see profile_sudoku()
// Without an input file the usage text is printed.
int main (int argc, char ** argv) {
#ifdef G_GRAPH
	g_grapher.start();
	g_grapher.enter("Sudoku");
#endif
	
	const char * input_file = NULL;
	bool timing_mode = false;
	
	// "-t" switches to timing mode; any other argument is the input file
	// (the last one wins).
	for (int i = 1; i < argc; i += 1) {
		const std::string arg = argv[i];
		
		if (arg == "-t") {
			timing_mode = true;
		} else {
			input_file = argv[i];
		}
	}
	
	if (input_file == NULL) {
		// Also covers "-t" on its own, which must not reach the solver with
		// a null path.
		std::cerr << "Usage: " << "sudoku-solver" << " input-file [-t]" << std::endl;
		std::cerr << "\t-t Timing/Benchmark mode" << std::endl;
	} else if (timing_mode) {
		clock_t start_time = clock();
		const unsigned REPEATS = 2;

		for (unsigned repeat = 0; repeat < REPEATS; repeat += 1)
			profile_sudoku(input_file);
			
		clock_t end_time = clock();
		
		// Average time of one pass over the input file.
		double totalTime = double(end_time - start_time) / CLOCKS_PER_SEC;
		std::cout << "Total Time: " << totalTime / REPEATS << "s" << std::endl;
	} else {
		solve_sudoku(input_file);
	}
	
	std::cout << std::flush;
	
#ifdef G_DEBUG
	std::cout << "Perm count: " << permutation_count << " Solve count: " << solution_count << std::endl;
	std::cout << "Row count: " << row_count << " Row spaces: " << row_spaces << std::endl;
#endif

#ifdef G_GRAPH
	g_grapher.exit();
	g_grapher.end();
	
	std::ofstream graph_file("graph.dot");
	g_grapher.save(graph_file);
	
	// Generate a PDF graph
	system("/opt/local/bin/dot -Tpdf:quartz -ograph.pdf graph.dot");
	system("open graph.pdf");
#endif

	return 0;
}

Comments

Leave a comment

Please note, comments must be formatted using Markdown. Links can be enclosed in angle brackets, e.g. <www.codeotaku.com>.