Source

python-ai / src / sudoku / csp.py

#!/usr/bin/env python

from itertools import count

class SolutionNotFound(Exception):
    """Raised when the CSP has no solution.

    Solvers raise it when a variable's domain is empty or when the search
    space has been exhausted without finding a complete assignment."""

    def __str__(self):
        message = "Solution not found."
        return message

class MaxDepthExceeded(Exception):
    """Raised when a search gives up after exceeding its expansion limit.

    The limit that was exceeded is available as the `depth` attribute."""

    def __init__(self, depth):
        # Kept so __str__ (and callers) can report the exceeded limit.
        self.depth = depth

    def __str__(self):
        template = "Max solution depth exceeded %d."
        return template % (self.depth,)

class AllDiffConstraint(object):
    """Constraint requiring every listed variable to take a distinct value."""

    def __init__(self, variables):
        self.variables = variables

    def satisfied(self, m):
        """Return True if the complete mapping `m` fulfils the constraint.

        Every listed variable must be assigned in `m`, and no two of them
        may share a value."""
        if any(v not in m for v in self.variables):
            return False
        distinct = set(m[v] for v in self.variables)
        return len(distinct) == len(self.variables)

    def possible(self, m):
        """Return True if the partial mapping `m` could still be extended to
        satisfy the constraint, i.e. no two already-assigned listed
        variables share a value."""
        assigned = [val for var, val in m.items() if var in self.variables]
        return len(set(assigned)) == len(assigned)

class CSPSolver(object):
    """Base class for constraint-satisfaction solvers.

    `variables` is the ordered sequence of variables to assign, `domains`
    maps each variable to the sequence of values it may take, and
    `constraints` holds objects exposing `satisfied(mapping)`. Concrete
    solvers supply the search strategy by overriding `searchSolution`."""

    # Expansion budget used by subclass searches; exceeding it raises
    # MaxDepthExceeded. An int because it bounds a count (compares equal to
    # the previous 1e6).
    MAX_CHECKS = 10 ** 6

    def __init__(self, variables, domains, constraints):
        self.variables = variables
        self.domains = domains
        self.constraints = constraints

    def solved(self, m):
        """Return True if mapping `m` assigns every variable and satisfies
        every constraint."""
        return (all(v in m for v in self.variables)
                and all(c.satisfied(m) for c in self.constraints))

    def constraintsUnsatisfied(self, m):
        """Return the number of constraints not satisfied by mapping `m`."""
        return sum(1 for c in self.constraints if not c.satisfied(m))

    def solve(self):
        """Solves the CSP problem, returning the solution as a dict with
        variables as keys and domain members as values.

        Raises SolutionNotFound if no solution exists, or MaxDepthExceeded
        if the search performs more than MAX_CHECKS expansions."""
        mapping, _ = self.solveWithCount()
        return mapping

    def solveWithCount(self):
        """Solves the CSP problem, returning the 2-tuple (mapping, exps).
        `mapping` is the solution as a dict with variables as keys and
        domain members as values. `exps` is the number of expansions
        required to find a solution.

        Raises SolutionNotFound if no solution exists (including when some
        variable has a missing or empty domain), or MaxDepthExceeded if the
        search performs more than MAX_CHECKS expansions."""
        if not self.variables:
            # Nothing to assign: the empty mapping is the only candidate, so
            # it is the answer exactly when it satisfies every constraint.
            # (Previously a variable-free problem with constraints reached
            # searchSolution, which cannot handle an empty variable list.)
            if self.solved({}):
                return {}, 0
            raise SolutionNotFound
        for v in self.variables:
            if v not in self.domains or len(self.domains[v]) == 0:
                raise SolutionNotFound
        return self.searchSolution()

    def searchSolution(self):
        """Search hook for subclasses; must return (mapping, exps).

        The base class has no search strategy and always raises
        SolutionNotFound."""
        raise SolutionNotFound


class BacktrackingSolver(CSPSolver):
    """Depth-first search over assignments, in `self.variables` order.

    Values are tried in domain order. The search backs up when the last
    variable has been assigned without producing a solution, or when the
    `mustBacktrack` hook reports a dead end. Subclasses customise the search
    through the `prunedDomains` and `mustBacktrack` hooks."""

    def searchSolution(self):
        """Return (mapping, exps) for the first solution found.

        `exps` counts the dead ends visited before reaching the solution.

        Raises SolutionNotFound when the search space is exhausted or some
        (pruned) domain is empty, and MaxDepthExceeded once more than
        MAX_CHECKS dead ends have been visited."""
        variables = self.variables
        if not variables:
            # Nothing to assign: the empty mapping is the only candidate.
            if self.solved({}):
                return {}, 0
            raise SolutionNotFound
        domains = self.prunedDomains()
        # An empty (possibly pruned) domain means no complete assignment
        # exists; check up front so the first variable is covered too.
        for v in variables:
            if len(domains[v]) == 0:
                raise SolutionNotFound
        last = len(variables) - 1
        # choice[i] is the index into domains[variables[i]] of the value
        # currently assigned to variables[i]. Tracking positions (instead of
        # re-finding values with list.index) stays correct when a domain
        # holds repeated values and avoids O(n) lookups per step.
        choice = [0] * len(variables)
        pos = 0
        # Start from the *pruned* domain: the first raw value may have been
        # pruned away, which previously broke the later domain lookup.
        mapping = {variables[0]: domains[variables[0]][0]}
        exps = 0
        while True:
            if exps > self.MAX_CHECKS:
                raise MaxDepthExceeded(self.MAX_CHECKS)
            var = variables[pos]
            if pos == last or self.mustBacktrack(domains, mapping, var):
                if self.solved(mapping):
                    return mapping, exps
                exps += 1
                # Unassign trailing variables whose values are exhausted.
                while choice[pos] == len(domains[variables[pos]]) - 1:
                    del mapping[variables[pos]]
                    if pos == 0:
                        raise SolutionNotFound
                    pos -= 1
                choice[pos] += 1
                var = variables[pos]
                mapping[var] = domains[var][choice[pos]]
            else:
                pos += 1
                choice[pos] = 0
                var = variables[pos]
                mapping[var] = domains[var][0]

    def prunedDomains(self):
        """Hook: return the domains to search. The base class uses
        self.domains unchanged."""
        return self.domains

    def mustBacktrack(self, domains, mapping, var):
        """Hook: return True to abandon the partial `mapping` right after
        `var` was assigned. The base class never prunes, so every branch is
        explored down to the last variable."""
        return False


class ForwardCheckingSolver(BacktrackingSolver):
    """Backtracking solver with forward checking.

    After each assignment the branch is abandoned as soon as some
    still-unassigned variable has no value that keeps every constraint
    possible together with the current partial mapping."""

    def mustBacktrack(self, domains, mapping, var):
        """Return True if some unassigned variable has no viable value.

        `var` (the variable just assigned) is not consulted directly: every
        variable of self.variables that is not yet in `mapping` is checked
        against the whole mapping. An empty domain counts as having no
        viable value."""
        for future in self.variables:
            if future in mapping:
                continue
            # Only existence matters, so stop at the first viable value
            # instead of collecting all of them (values need not be hashable).
            if not any(self._possibleWith(mapping, future, val)
                       for val in domains[future]):
                return True
        return False

    def _possibleWith(self, mapping, var, val):
        """Return True if `mapping` extended with var=val is still possible
        under every constraint."""
        extended = dict(mapping)
        extended[var] = val
        return all(c.possible(extended) for c in self.constraints)


class ArcConsistencySolver(BacktrackingSolver):
    """Backtracking solver that first prunes the domains to arc consistency.

    A value `valA` of `varA` survives only if, for every other variable
    `varB`, some value still in `varB`'s domain is jointly possible with it
    under every constraint. Pruning repeats until no domain changes, because
    removing one value can leave other values without support."""

    def prunedDomains(self):
        """Return a new dict of pruned domains; self.domains is not modified.

        Only values that cannot appear in any solution are removed, and the
        surviving values keep their original relative order."""
        domains = {v: list(self.domains[v]) for v in self.domains}
        changed = True
        while changed:
            changed = False
            for varA in domains:
                # Support is checked against the *current* pruned domains, so
                # values whose only supports were already removed go too.
                kept = [valA for valA in domains[varA]
                        if self.compatibleWithAllOthers(varA, valA, domains)]
                if len(kept) != len(domains[varA]):
                    # Rebinding an existing key is safe during iteration.
                    domains[varA] = kept
                    changed = True
        return domains

    def compatibleWithAllOthers(self, varA, valA, domains=None):
        """Return True if every other variable has a value in `domains`
        that is compatible with varA=valA. `domains` defaults to
        self.domains."""
        if domains is None:
            domains = self.domains
        for varB in domains:
            if varB == varA:
                continue
            if not self.existsPossibleValue(varA, valA, varB, domains):
                return False
        return True

    def existsPossibleValue(self, varA, valA, varB, domains=None):
        """Return True if some value of varB in `domains` is compatible with
        varA=valA. `domains` defaults to self.domains."""
        if domains is None:
            domains = self.domains
        return any(self.compatibleWithConstraints(varA, valA, varB, valB)
                   for valB in domains[varB])

    def compatibleWithConstraints(self, varA, valA, varB, valB):
        """Return True if the two-variable mapping {varA: valA, varB: valB}
        is possible under every constraint."""
        pair = {varA: valA, varB: valB}
        return all(c.possible(pair) for c in self.constraints)