# python-ai / src / sudoku / csp.py

 #!/usr/bin/env python
"""Constraint satisfaction problem (CSP) solvers.

A problem is given as a list of variables, a dict mapping each variable to an
ordered list of candidate values (its domain), and a list of constraint
objects.  Each constraint provides:

* ``satisfied(m)`` -- True if the complete mapping ``m`` meets the constraint;
* ``possible(m)`` -- False if the partial mapping ``m`` can never be extended
  into a mapping that meets the constraint.
"""

from itertools import count


class SolutionNotFound(Exception):
    """Raised when the search space is exhausted without finding a solution."""

    def __str__(self):
        return "Solution not found."


class MaxDepthExceeded(Exception):
    """Raised when a solver exceeds its expansion budget.

    :param depth: the budget that was exceeded.
    """

    def __init__(self, depth):
        # Forward to Exception so ``args`` is populated (repr/pickling).
        super(MaxDepthExceeded, self).__init__(depth)
        self.depth = depth

    def __str__(self):
        return "Max solution depth exceeded %d." % (self.depth,)


class AllDiffConstraint(object):
    """Requires the listed variables to take pairwise-distinct values."""

    def __init__(self, variables):
        self.variables = variables

    def satisfied(self, m):
        """Return True if the full mapping `m` satisfies the constraint.

        Every listed variable must be assigned in `m`, and the values assigned
        to them must all be different.
        """
        seen = set()
        for v in self.variables:
            if v not in m:
                return False
            seen.add(m[v])
        return len(seen) == len(self.variables)

    def possible(self, m):
        """Return True if the partial mapping `m` could still be extended to
        satisfy the constraint.

        For all-different this holds exactly when no two of the listed
        variables that are already assigned in `m` share a value.
        """
        members = set(self.variables)
        vs = [m[v] for v in m if v in members]
        return len(vs) == len(set(vs))


class CSPSolver(object):
    """Base CSP solver; subclasses supply the search via ``searchSolution``."""

    # Budget of backtracking steps enforced by BacktrackingSolver and its
    # subclasses before they raise MaxDepthExceeded.
    MAX_CHECKS = 1e6

    def __init__(self, variables, domains, constraints):
        self.variables = variables
        self.domains = domains
        self.constraints = constraints

    def solved(self, m):
        """Return True if `m` assigns every variable and satisfies every
        constraint."""
        if any(v not in m for v in self.variables):
            return False
        return all(c.satisfied(m) for c in self.constraints)

    def constraintsUnsatisfied(self, m):
        """Return how many constraints the mapping `m` does not satisfy."""
        return sum(1 for c in self.constraints if not c.satisfied(m))

    def solve(self):
        """Solve the CSP and return the solution as a dict mapping each
        variable to a member of its domain.

        Raises SolutionNotFound if no solution exists, or MaxDepthExceeded if
        the search strategy gives up after ``MAX_CHECKS`` backtracking steps.
        """
        mapping, _ = self.solveWithCount()
        return mapping

    def solveWithCount(self):
        """Solve the CSP and return the 2-tuple ``(mapping, exps)``.

        `mapping` is the solution as a dict mapping each variable to a member
        of its domain.  `exps` is the number of backtracking steps the search
        took before reaching the solution.

        Raises SolutionNotFound if no solution exists (including when some
        variable has a missing or empty domain), or MaxDepthExceeded if the
        search strategy gives up after ``MAX_CHECKS`` backtracking steps.
        """
        if not self.variables:
            # Nothing to assign: the empty mapping is the only candidate.
            # (Search strategies below assume at least one variable.)
            if self.solved({}):
                return {}, 0
            raise SolutionNotFound
        for v in self.variables:
            if v not in self.domains or not self.domains[v]:
                raise SolutionNotFound
        return self.searchSolution()

    def searchSolution(self):
        """Search hook for subclasses; the base solver finds nothing."""
        raise SolutionNotFound


class BacktrackingSolver(CSPSolver):
    """Depth-first search assigning variables in list order and trying each
    domain's values in list order.

    The search only checks/backtracks once the last variable is assigned or
    when ``mustBacktrack`` returns True.  This class's hook never fires, so on
    its own it generates and tests complete assignments.
    """

    def __init__(self, *args, **kwargs):
        super(BacktrackingSolver, self).__init__(*args, **kwargs)

    def searchSolution(self):
        """Search for a solution and return ``(mapping, exps)``.

        Raises SolutionNotFound once every assignment has been ruled out, and
        MaxDepthExceeded once more than ``MAX_CHECKS`` backtracking steps have
        been taken.
        """
        variables = self.variables
        last_var = variables[-1]
        domains = self.prunedDomains()
        # Pruning may empty a domain, in which case no assignment can exist.
        if any(not domains[v] for v in variables):
            raise SolutionNotFound
        var = variables[0]
        # Seed from the *pruned* domains: the lookups below require every
        # assigned value to be a member of ``domains[var]``.
        mapping = {var: domains[var][0]}
        exps = 0
        while True:
            if exps > self.MAX_CHECKS:
                raise MaxDepthExceeded(self.MAX_CHECKS)
            if var == last_var or self.mustBacktrack(domains, mapping, var):
                if self.solved(mapping):
                    return mapping, exps
                exps += 1
                # Unassign trailing variables whose candidates are exhausted.
                while mapping[var] == domains[var][-1]:
                    del mapping[var]
                    idx = variables.index(var)
                    if idx == 0:
                        raise SolutionNotFound
                    var = variables[idx - 1]
                # Move the current variable on to its next candidate value.
                candidates = domains[var]
                mapping[var] = candidates[candidates.index(mapping[var]) + 1]
            else:
                # Assign the next variable its first candidate value.
                var = variables[variables.index(var) + 1]
                mapping[var] = domains[var][0]

    def prunedDomains(self):
        """Return the domains to search; this class searches them unpruned."""
        return self.domains

    def mustBacktrack(self, domains, mapping, var):
        """Hook: return True to force a backtrack after assigning `var`.

        This class never forces one.
        """
        return False


class ForwardCheckingSolver(BacktrackingSolver):
    """Backtracking solver that backtracks as soon as some unassigned variable
    has no value left that keeps every constraint possible."""

    def mustBacktrack(self, domains, mapping, var):
        """Return True if some variable not yet in `mapping` has no value in
        `domains` that keeps every constraint possible."""
        for future in self.variables:
            if future in mapping:
                continue
            # One consistent value is enough to keep this variable alive.
            if not any(self._canExtend(mapping, future, val)
                       for val in domains[future]):
                return True
        return False

    def _canExtend(self, mapping, future, val):
        """Return True if adding ``future = val`` to `mapping` keeps every
        constraint possible."""
        hypo_mapping = dict(mapping)
        hypo_mapping[future] = val
        return all(c.possible(hypo_mapping) for c in self.constraints)


class ArcConsistencySolver(BacktrackingSolver):
    """Backtracking solver that first prunes values lacking pairwise support.

    The pruning is a single pass over the original domains; it is not
    iterated to a fixpoint.
    """

    def prunedDomains(self):
        """Return, for every variable in ``self.domains``, the values that are
        compatible with at least one value of every other variable."""
        return {
            varA: [valA for valA in self.domains[varA]
                   if self.compatibleWithAllOthers(varA, valA)]
            for varA in self.domains
        }

    def compatibleWithAllOthers(self, varA, valA):
        """Return True if ``varA = valA`` has a supporting value in every other
        variable's domain."""
        return all(self.existsPossibleValue(varA, valA, varB)
                   for varB in self.domains if varB != varA)

    def existsPossibleValue(self, varA, valA, varB):
        """Return True if some value of `varB` is compatible with
        ``varA = valA``."""
        return any(self.compatibleWithConstraints(varA, valA, varB, valB)
                   for valB in self.domains[varB])

    def compatibleWithConstraints(self, varA, valA, varB, valB):
        """Return True if the two-variable mapping ``{varA: valA, varB: valB}``
        keeps every constraint possible."""
        pair = {varA: valA, varB: valB}
        return all(c.possible(pair) for c in self.constraints)