# python-ai / src / sudoku / csp.py

 ``` 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157``` ```#!/usr/bin/env python from itertools import count class SolutionNotFound(Exception): def __str__(self): return "Solution not found." class MaxDepthExceeded(Exception): def __init__(self, depth): self.depth = depth def __str__(self): return "Max solution depth exceeded %d." % (self.depth,) class AllDiffConstraint(object): def __init__(self, variables): self.variables = variables def satisfied(self, m): """Given a full mapping `m`, tests whether or not the constraint is fully satisfied. In this case, tests whether the values assigned to every specified variable are different.""" vs = set() for v in self.variables: if not v in m: return False vs.add(m[v]) return len(vs) == len(self.variables) def possible(self, m): """Given a partial mapping `m`, tests whether or not there exists an extension of `m` that satisfied the constraint. In this case, if two values are ever the same, the constraint cannot be satisfied.""" vs = [m[v] for v in m if v in self.variables] return len(vs) == len(set(vs)) class CSPSolver(object): MAX_CHECKS = 1e6 def __init__(self, variables, domains, constraints): self.variables = variables self.domains = domains self.constraints = constraints def solved(self, m): for v in self.variables: if not v in m: return False for c in self.constraints: if not c.satisfied(m): return False return True def constraintsUnsatisfied(self, m): num = 0 for c in self.constraints: if not c.satisfied(m): num += 1 return num def solve(self): """Solves the CSP problem, returning the solution as a dict with variables as keys and domain members as values. Raises SolutionNotFound if no solution found or MaxDepthExceeded if solution not found after MAX_DEPTH expansions.""" (mapping, num) = self.solveWithCount() return mapping def solveWithCount(self): """Solves the CSP problem, returning the 2-tuple (mapping, exps). `mapping` is the solution as a dict with variables as keys and domain members as values. `exps` is the number of expansions required to find a solution. Raises SolutionNotFound if no solution found or MaxDepthExceeded if solution not found after MAX_DEPTH expansions.""" if len(self.constraints) == 0 and self.solved(dict()): return dict(), 0 for v in self.variables: if v not in self.domains or len(self.domains[v]) == 0: raise SolutionNotFound return self.searchSolution() def searchSolution(self): raise SolutionNotFound class BacktrackingSolver(CSPSolver): def __init__(self, *args, **kwargs): super(BacktrackingSolver, self).__init__(*args, **kwargs) def searchSolution(self): last_var = self.variables[-1] var = self.variables[0] domains = self.prunedDomains() mapping = {var: self.domains[var][0]} exps = 0 while True: if exps > self.MAX_CHECKS: raise MaxDepthExceeded(self.MAX_CHECKS) elif var == last_var or self.mustBacktrack(domains, mapping, var): if self.solved(mapping): return mapping, exps exps += 1 while mapping[var] == domains[var][-1]: del mapping[var] idx = self.variables.index(var) if idx == 0: raise SolutionNotFound var = self.variables[idx-1] last_val = mapping[var] val = domains[var][domains[var].index(last_val)+1] mapping[var] = val else: var = self.variables[self.variables.index(var)+1] if len(domains[var]) == 0: raise SolutionNotFound mapping[var] = domains[var][0] def prunedDomains(self): return self.domains def mustBacktrack(self, domains, mapping, var): return False class ForwardCheckingSolver(BacktrackingSolver): def mustBacktrack(self, domains, mapping, var): forwards = [v for v in self.variables if v not in mapping] for var in forwards: legal_vals = set(domains[var]) for val in domains[var]: hypo_mapping = dict(mapping) hypo_mapping[var] = val for c in self.constraints: if not c.possible(hypo_mapping): legal_vals.discard(val) break if len(legal_vals) == 0: return True return False class ArcConsistencySolver(BacktrackingSolver): def prunedDomains(self): domains = {v: list() for v in self.domains} for varA in self.domains: for valA in self.domains[varA]: if self.compatibleWithAllOthers(varA, valA): domains[varA].append(valA) return domains def compatibleWithAllOthers(self, varA, valA): for varB in self.domains: if varA == varB: continue if not self.existsPossibleValue(varA, valA, varB): return False return True def existsPossibleValue(self, varA, valA, varB): for valB in self.domains[varB]: if self.compatibleWithConstraints(varA, valA, varB, valB): return True return False def compatibleWithConstraints(self, varA, valA, varB, valB): for c in self.constraints: if not c.possible({varA: valA, varB: valB}): return False return True ```