Source

python-ai / src / sudoku / csp.py

Full commit
#!/usr/bin/env python

from itertools import count

class SolutionNotFound(Exception):
    """Raised by a solver when it cannot produce a satisfying assignment."""

    def __str__(self):
        message = "Solution not found."
        return message

class MaxDepthExceeded(Exception):
    """Signals that a search exceeded its depth limit.

    The limit passed to the constructor is kept on ``self.depth`` and is
    reported by ``str()``.
    """

    def __init__(self, depth):
        self.depth = depth

    def __str__(self):
        template = "Max solution depth exceeded %d."
        return template % (self.depth,)

class NotZeroConstraint(object):
    """Constraint requiring one variable to be assigned a value other than 0."""

    def __init__(self, variable):
        self.variable = variable

    def satisfied(self, m):
        """Return True when mapping ``m`` assigns ``self.variable`` a non-zero value.

        An unassigned variable does not satisfy the constraint.
        """
        if self.variable not in m:
            return False
        return m[self.variable] != 0

class AllDiffConstraint(object):
    """Constraint requiring every listed variable to hold a distinct value."""

    def __init__(self, variables):
        self.variables = variables

    def satisfied(self, m):
        """Return True when mapping ``m`` assigns all ``self.variables`` and no
        two of them share a value; any unassigned variable fails the check.
        """
        if any(var not in m for var in self.variables):
            return False
        # Equal values collapse in the set, so a size mismatch means a repeat.
        distinct = set(m[var] for var in self.variables)
        return len(distinct) == len(self.variables)

class CSPSolver(object):
    """Base class for constraint-satisfaction solvers.

    ``variables`` is an ordered sequence of variables, ``domains`` maps each
    variable to a sequence of candidate values, and ``constraints`` is a
    sequence of objects exposing ``satisfied(mapping)``. Subclasses supply
    the search strategy by overriding ``searchSolution``.
    """
    # NOTE(review): not read by any code visible here; presumably intended as
    # a work limit for subclasses that raise MaxDepthExceeded -- confirm.
    MAX_CHECKS = 1e7

    def __init__(self, variables, domains, constraints):
        self.variables = variables
        self.domains = domains
        self.constraints = constraints

    def solved(self, m):
        """Return True if ``m`` assigns every variable and satisfies every
        constraint. Variable coverage is checked before any constraint runs."""
        return (all(v in m for v in self.variables)
                and all(c.satisfied(m) for c in self.constraints))

    def constraintsUnsatisfied(self, m):
        """Return the number of constraints that report ``m`` as unsatisfied."""
        return sum(1 for c in self.constraints if not c.satisfied(m))

    def solve(self):
        """Solve the CSP, returning the solution as a dict with variables as
        keys and domain members as values.

        Raises SolutionNotFound if no solution is found. A ``searchSolution``
        implementation that bounds its work may also raise MaxDepthExceeded."""
        mapping, _ = self.solveWithCount()
        return mapping

    def solveWithCount(self):
        """Solve the CSP, returning the 2-tuple ``(mapping, exps)``.

        ``mapping`` is the solution as a dict with variables as keys and
        domain members as values. ``exps`` is the number of expansions
        required to find it.

        Raises SolutionNotFound if no solution is found, including when a
        variable has no domain or an empty one. A ``searchSolution``
        implementation that bounds its work may also raise MaxDepthExceeded."""
        if len(self.variables) == 0:
            # Nothing to assign: the empty mapping is the only candidate, so
            # decide here rather than asking a search strategy to handle it.
            if self.solved({}):
                return {}, 0
            raise SolutionNotFound
        for v in self.variables:
            if v not in self.domains or len(self.domains[v]) == 0:
                raise SolutionNotFound
        return self.searchSolution()

    def searchSolution(self):
        """Search for a solution; the base class has no strategy and always
        raises SolutionNotFound. Subclasses return ``(mapping, exps)``."""
        raise SolutionNotFound


class BacktrackingSolver(CSPSolver):
    """Exhaustive depth-first solver over ``self.variables`` in list order.

    Variables receive values one at a time in order, and each step asks
    ``self.solved`` whether the current mapping is a solution (in CSPSolver
    that requires every variable to be assigned, so partial assignments are
    never pruned). Once the last variable holds a value, the search advances
    the deepest variable that still has untried values, unassigning any
    variables whose domains are exhausted.
    """

    def searchSolution(self):
        """Enumerate assignments until one satisfies ``self.solved``.

        Returns the 2-tuple ``(mapping, exps)``, where ``exps`` is the number
        of loop iterations (expansions) performed before the solution was
        found.

        Raises SolutionNotFound when every combination has been tried or a
        variable's domain is empty.
        """
        variables = self.variables
        domains = self.domains
        if len(variables) == 0:
            # Nothing to assign; only the empty mapping can be a solution.
            if self.solved({}):
                return {}, 0
            raise SolutionNotFound
        last = len(variables) - 1
        # choice[i] is the position within domains[variables[i]] of the value
        # currently assigned. Tracking positions, rather than looking values
        # back up with list.index, keeps duplicate domain values distinct and
        # makes every step O(1).
        choice = [0] * len(variables)
        depth = 0
        first = variables[0]
        if len(domains[first]) == 0:
            raise SolutionNotFound
        mapping = {first: domains[first][0]}
        for exps in count():
            if self.solved(mapping):
                return mapping, exps
            if depth == last:
                # Unassign trailing variables whose last value has been tried.
                while choice[depth] == len(domains[variables[depth]]) - 1:
                    del mapping[variables[depth]]
                    if depth == 0:
                        raise SolutionNotFound
                    depth -= 1
                # Move the deepest variable with values left to its next one.
                choice[depth] += 1
                var = variables[depth]
                mapping[var] = domains[var][choice[depth]]
            else:
                # Descend: give the next variable its first value.
                depth += 1
                var = variables[depth]
                if len(domains[var]) == 0:
                    raise SolutionNotFound
                choice[depth] = 0
                mapping[var] = domains[var][0]

    def getNextVar(self, remaining):
        """Return the first item of ``remaining``, or None if it is empty."""
        return next(iter(remaining), None)

    def getNextVal(self, remaining):
        """Return the first item of ``remaining``, or None if it is empty."""
        return next(iter(remaining), None)