Commits

Geoff Hill committed c228b45

moved CSP stuff out of sudoku and into own file

  • Participants
  • Parent commits 12ea18a

Comments (0)

Files changed (4)

File src/sudoku/csp.py

 #!/usr/bin/env python
 
+from itertools import count
+
 class SolutionNotFound(Exception):
     def __str__(self):
         return "Solution not found."
 
-class MaxSolutionDepthExceeded(Exception):
+class MaxDepthExceeded(Exception):
     def __init__(self, depth):
         self.depth = depth
     def __str__(self):
     def __init__(self, variable):
         self.variable = variable
     def satisfied(self, m):
-        return m[self.variable] != 0
+        v = self.variable
+        return v in m and m[v] != 0
 
 class AllDiffConstraint(object):
     def __init__(self, variables):
     def satisfied(self, m):
         vs = set()
         for v in self.variables:
+            if not v in m: return False
             vs.add(m[v])
         return len(vs) == len(self.variables)
 
 class CSPSolver(object):
     MAX_CHECKS = 1e7
-
-    def __init__(self, numVariables, domains, constraints):
-        self.numVariables = numVariables
+    def __init__(self, variables, domains, constraints):
+        self.variables = variables
         self.domains = domains
         self.constraints = constraints
-
     def solved(self, m):
+        for v in self.variables:
+            if not v in m: return False
         for c in self.constraints:
             if not c.satisfied(m): return False
         return True
+    def constraintsUnsatisfied(self, m):
+        num = 0
+        for c in self.constraints:
+            if not c.satisfied(m): num += 1
+        return num
+    def solve(self):
+        """Solves the CSP problem, returning the solution as a dict with
+        variables as keys and domain members as values.
+        
+        Raises SolutionNotFound if no solution found or MaxDepthExceeded if
+        solution not found after MAX_DEPTH expansions."""
+        (mapping, num) = self.solveWithCount()
+        return mapping
+    def solveWithCount(self):
+        """Solves the CSP problem, returning the 2-tuple (mapping, exps).
+        `mapping` is the solution as a dict with variables as keys and
+        domain members as values. `exps` is the number of expansions
+        required to find a solution.
+        
+        Raises SolutionNotFound if no solution found or MaxDepthExceeded if
+        solution not found after MAX_DEPTH expansions."""
+        if len(self.constraints) == 0 and self.solved(dict()):
+            return dict(), 0
+        for v in self.variables:
+            if v not in self.domains or len(self.domains[v]) == 0:
+                raise SolutionNotFound
+        return self.searchSolution()
+    def searchSolution(self):
+        raise SolutionNotFound
+
+
+class BacktrackingSolver(CSPSolver):
+    def __init__(self, *args, **kwargs):
+        super(BacktrackingSolver, self).__init__(*args, **kwargs)
+    def searchSolution(self):
+        last_var = self.variables[-1]
+        var = self.variables[0]
+        mapping = {var: self.domains[var][0]}
+        for exps in count():
+            if self.solved(mapping):
+                return mapping, exps
+            elif var == last_var:
+                while mapping[var] == self.domains[var][-1]:
+                    del mapping[var]
+                    idx = self.variables.index(var)
+                    if idx == 0:
+                        raise SolutionNotFound
+                    var = self.variables[idx-1]
+                last_val = mapping[var]
+                val = self.domains[var][self.domains[var].index(last_val)+1]
+                mapping[var] = val
+            else:
+                var = self.variables[self.variables.index(var)+1]
+                if len(self.domains[var]) == 0:
+                    raise SolutionNotFound
+                mapping[var] = self.domains[var][0]
+    def getNextVar(self, remaining):
+        for r in remaining: return r
+    def getNextVal(self, remaining):
+        for r in remaining: return r
+
+

File src/sudoku/csp_test.py

+#!/usr/bin/env python
+
+import unittest
+from csp import *
+
+class TestNotZeroConstraint(unittest.TestCase):
+    def testNotZero(self):
+        c = NotZeroConstraint(4)
+        self.assertTrue(c.satisfied({4: 3}))
+    def testZero(self):
+        c = NotZeroConstraint(2)
+        self.assertFalse(c.satisfied({2: 0}))
+    def testNotInMapping(self):
+        c = NotZeroConstraint(4)
+        self.assertFalse(c.satisfied({2: 3}))
+
+class TestAllDiffConstraint(unittest.TestCase):
+    def testEmpty(self):
+        c = AllDiffConstraint(list())
+        self.assertTrue(c.satisfied({0: 0}))
+    def testUnivariateAllDiff(self):
+        c = AllDiffConstraint([0])
+        self.assertTrue(c.satisfied({0: 6}))
+    def testAllDiff(self):
+        c = AllDiffConstraint([0, 1, 2, 3])
+        self.assertTrue(c.satisfied({0: 4, 1: 7, 2: 5, 3: 6}))
+    def testAllSame(self):
+        c = AllDiffConstraint([0, 1, 2, 3])
+        self.assertFalse(c.satisfied({0: 6, 1: 6, 2: 6, 3: 6}))
+    def testSomeSame(self):
+        c = AllDiffConstraint([0, 1, 2, 3])
+        self.assertFalse(c.satisfied({0: 4, 1: 7, 2: 3, 3: 7}))
+    def testNotInMapping(self):
+        c = AllDiffConstraint([0, 1, 2, 3])
+        self.assertFalse(c.satisfied({0: 4, 1: 7}))
+
+
+class TestCSPSolver(unittest.TestCase):
+    def testTriviallySolveNoConstraintsNoVariables(self):
+        vs = []
+        ds = {}
+        cs = []
+        s = CSPSolver(vs, ds, cs)
+        self.assertEqual(s.solve(), {})
+    def testTriviallyUnsolvableUnspecifiedDomain(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [1]}
+        cs = []
+        s = CSPSolver(vs, ds, cs)
+        with self.assertRaises(SolutionNotFound):
+            s.solve()
+    def testTriviallyUnsolvableEmptyDomain(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [1], 2: []}
+        cs = []
+        s = CSPSolver(vs, ds, cs)
+        with self.assertRaises(SolutionNotFound):
+            s.solve()
+
+
+class TestBacktrackingSolver(unittest.TestCase):
+    def testNoConstraintsLimitedDomain(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [1], 2: [2]}
+        cs = []
+        s = BacktrackingSolver(vs, ds, cs)
+        self.assertEqual(s.solve(), {0:0, 1:1, 2:2})
+    def testNoConstraintsWideDomain(self):
+        vs = [0,1,2]
+        ds = {0: [0,1,2], 1: [0,1,2], 2: [0,1,2]}
+        cs = []
+        s = BacktrackingSolver(vs, ds, cs)
+        solution = s.solve()
+        self.assertIn(solution[0], [0,1,2])
+        self.assertIn(solution[1], [0,1,2])
+        self.assertIn(solution[2], [0,1,2])
+    def testAllDiffConstraintsLimitedDomain(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [1], 2: [2]}
+        cs = [AllDiffConstraint([0,1,2])]
+        s = BacktrackingSolver(vs, ds, cs)
+        self.assertEqual(s.solve(), {0:0, 1:1, 2:2})
+    def testImpossibleAllDiffConstraint(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [0,2], 2: [0]}
+        cs = [AllDiffConstraint([0,1,2])]
+        s = BacktrackingSolver(vs, ds, cs)
+        with self.assertRaises(SolutionNotFound):
+            s.solve()
+    def testOneBacktrack(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [1], 2: [1,2]}
+        cs = [AllDiffConstraint([0,1,2])]
+        s = BacktrackingSolver(vs, ds, cs)
+        self.assertEqual(s.solve(), {0:0, 1:1, 2:2})
+    def testTwoBacktrack(self):
+        vs = [0,1,2]
+        ds = {0: [0], 1: [1,2], 2: [1]}
+        cs = [AllDiffConstraint([0,1,2])]
+        s = BacktrackingSolver(vs, ds, cs)
+        self.assertEqual(s.solve(), {0:0, 1:2, 2:1})
+    def testThreeBacktrack(self):
+        vs = [0,1,2]
+        ds = {0: [0,1], 1: [1,2], 2: [0]}
+        cs = [AllDiffConstraint([0,1,2])]
+        s = BacktrackingSolver(vs, ds, cs)
+        self.assertEqual(s.solve(), {0:1, 1:2, 2:0})
+
+
+if __name__ == "__main__":
+    unittest.main()
+
+

File src/sudoku/sudoku.py

 #!/usr/bin/env python
 
-import math, array, copy
+import math, array, copy, csp
 
 def unique(seq):
     """Determines whether or not every element in a sequence is unique."""
         """Constructor. See class doc string."""
         if filename:
             f = open(filename, 'r')
-            self._initBoard(int(f.readline()))
+            self.initBoard(int(f.readline()))
             vals = int(f.readline())
             for i in range(vals):
-                row = int(f.read(2)) - 1
-                col = int(f.read(2)) - 1
-                val = int(f.read(2))
+                line = f.readline()
+                chars = line.split()
+                row = int(chars[0]) - 1
+                col = int(chars[1]) - 1
+                val = int(chars[2])
                 self.set(row, col, val)
             f.close()
         else:
             if not size:
                 size = 9
-            self._initBoard(size)
+            self.initBoard(size)
     
-    def _initBoard(self, size):
-        self._size = size
-        self._width = int(math.sqrt(size))
-        self._board = array.array('H', [0 for i in range(size*size)])
+    def initBoard(self, size):
+        self.size = size
+        self.width = int(math.sqrt(size))
+        self.board = array.array('H', [0 for i in range(size*size)])
+    
+    def idx(self, row, col):
+        return row*self.size + col
+        
+    def boxIdx(self, box):
+        w = self.width
+        rs = w*(box//w)
+        cs = w*(box%w)
+        return [(i, j) for i in range(rs, rs+w) for j in range(cs, cs+w)]
     
     def get(self, row, col):
-        return self._board[row*self._size + col]
+        return self.board[self.idx(row, col)]
     
     def set(self, row, col, val):
-        self._board[row*self._size + col] = val
+        self.board[self.idx(row, col)] = val
+    
+    def toCSP(self, constructor):
+        variables = list(range(len(self.board)))
+        domains = dict()
+        for v in variables:
+            if self.board[v]:
+                domains[v] = [v]
+            else:
+                domains[v] = list(range(1,self.size+1))
+        constraints = list()
+        for row in self.size:
+            c = AllDiffConstraint([self.idx(row, col) for col in self.size])
+            constraints.add(c)
+        for col in self.size:
+            c = AllDiffConstraint([self.idx(row, col) for row in self.size])
+            constraints.add(c)
+        for box in self.size:
+            c = AllDiffConstraint([self.idx(row, col) for (row, col) in self.boxIdx(box)])
+            constraints.add(c)
+        return constructor(variables, domains, constraints)
     
     # binary-decision constraint tests
     
         return self.filled() and self.valid()
     
     def filled(self):
-        return not 0 in self._board
+        return not 0 in self.board
 
     def valid(self):
-        for i in range(self._size):
+        for i in range(self.size):
             if not self.rowValid(i): return False
             if not self.colValid(i): return False
             if not self.boxValid(i): return False
         return True
 
     def rowValid(self, row):
-        s = self._size
+        s = self.size
         for i in range(s):
             for j in range(s):
                 if i == j:
         return True
 
     def colValid(self, col):
-        s = self._size
+        s = self.size
         for i in range(s):
             for j in range(s):
                 if i == j:
         return True
 
     def boxValid(self, box):
-        s = self._size
-        w = self._width
+        s = self.size
+        w = self.width
         rs = w*(box//w)
         cs = w*(box%w)
         for i in range(rs, rs+w):
         return self.unfilled() + self.diff()
 
     def numUnfilled(self):
-        return self._board.count(0)
+        return self.board.count(0)
     
     def numDiff(self):
         invalid = 0
-        for i in range(self._size):
+        for i in range(self.size):
             if not self.numDiffRow(i): invalid += 1
             if not self.numDiffColl(i): invalid += 1
             if not self.numDiffBox(i): invalid += 1
         return invalid
     
     def numDiffRow(self, row):
-        s = self._size
+        s = self.size
         num = 0
         for i in s:
             for j in s:
         return num
     
     def numDiffColl(self, col):
-        return unique([self.get(i, col) for i in range(self._size)])
+        return unique([self.get(i, col) for i in range(self.size)])
     
     def numDiffBox(self, box):
         return unique([self.get(i, j) for (i, j) in self.boxIndices(box)])
     
     def boxIndices(self, box):
-        w = self._width
+        w = self.width
         rs = w*(box//w)
         cs = w*(box%w)
         return [(i, j) for i in range(rs, rs+w) for j in range(cs, cs+w)]
     
     def __str__(self):
-        s = self._size
+        s = self.size
         t = '+--'*s + '+\n'
         for i in range(s):
             for j in range(s):

File src/sudoku/sudoku_test.py

     
     def test_board4(self):
         b = sudoku.Board(size=4)
-        self.assertEqual(b._size, 4)
-        self.assertEqual(b._width, 2)
+        self.assertEqual(b.size, 4)
+        self.assertEqual(b.width, 2)
         self.assertFalse(b.complete())
         self.assertFalse(b.filled())
         self.assertFalse(b.valid())
-
+    
     def test_board25(self):
         b = sudoku.Board(size=25)
-        self.assertEqual(b._size, 25)
-        self.assertEqual(b._width, 5)
+        self.assertEqual(b.size, 25)
+        self.assertEqual(b.width, 5)
         self.assertFalse(b.complete())
         self.assertFalse(b.filled())
         self.assertFalse(b.valid())
         f.close()
         
         b = sudoku.Board(filename=self.path)
-        self.assertEqual(b._size, 4)
-        self.assertEqual(b._width, 2)
+        self.assertEqual(b.size, 4)
+        self.assertEqual(b.width, 2)
         self.assertFalse(b.complete())
         self.assertFalse(b.filled())
         self.assertFalse(b.valid())
         f.close()
         
         b = sudoku.Board(filename=self.path)
-        self.assertEqual(b._size, 4)
-        self.assertEqual(b._width, 2)
+        self.assertEqual(b.size, 4)
+        self.assertEqual(b.width, 2)
         self.assertFalse(b.complete())
         self.assertTrue(b.filled())
         self.assertFalse(b.valid())
         f.close()
         
         b = sudoku.Board(filename=self.path)
-        self.assertEqual(b._size, 4)
-        self.assertEqual(b._width, 2)
+        self.assertEqual(b.size, 4)
+        self.assertEqual(b.width, 2)
         self.assertTrue(b.complete())
         self.assertTrue(b.filled())
         self.assertTrue(b.valid())
-        
-        
+    
     def test_board9_unfilled(self):
         f = open(self.path, 'w')
         f.write("9\n")
         f.write("4 4 3\n")
         f.write("5 9 6\n")
         f.close()
-
+        
         b = sudoku.Board(filename=self.path)
-        self.assertEqual(b._size, 9)
-        self.assertEqual(b._width, 3)
+        self.assertEqual(b.size, 9)
+        self.assertEqual(b.width, 3)
         self.assertFalse(b.complete())
         self.assertFalse(b.filled())
         self.assertFalse(b.valid())
 
 
-class TestBoardMore(unittest.TestCase):
-    
-    def setUp(self):
-        self.b = sudoku.Board(size=4)
-        self.b.set(0, 0, 3)
-        self.b.set(0, 1, 4)
-        self.b.set(0, 2, 1)
-        self.b.set(0, 3, 2)
-        self.b.set(1, 0, 2)
-        self.b.set(1, 1, 1)
-        self.b.set(1, 2, 4)
-        self.b.set(1, 3, 3)
-        self.b.set(2, 0, 4)
-        self.b.set(2, 1, 3)
-        self.b.set(2, 2, 2)
-        self.b.set(2, 3, 1)
-        self.b.set(3, 0, 1)
-        self.b.set(3, 1, 2)
-        self.b.set(3, 2, 3)
-        self.b.set(3, 3, 4)
-    
-    def test_board4_complete(self):
-        self.assertTrue(self.b.complete())
-        self.assertTrue(self.b.filled())
-        self.assertTrue(self.b.valid())
-    
-    def test_board4_invalid(self):    
-        self.b.set(2, 2, 3)
-        self.assertFalse(self.b.complete())
-        self.assertTrue(self.b.filled())
-        self.assertFalse(self.b.valid())
-    
-    def test_board4_unfilled(self):    
-        self.b.set(2, 2, 0)
-        self.assertFalse(self.b.complete())
-        self.assertFalse(self.b.filled())
-        self.assertTrue(self.b.valid())
-    
-    def test_board4_str(self):
-        expected = """
-+--+--+--+--+
-| 3| 4| 1| 2|
-+--+--+--+--+
-| 2| 1| 4| 3|
-+--+--+--+--+
-| 4| 3| 2| 1|
-+--+--+--+--+
-| 1| 2| 3| 4|
-+--+--+--+--+
-"""
-        self.assertEqual(str(self.b), expected[1:])
-        self.assertEqual(repr(self.b), expected[1:])
-
-
-class TestSolve(unittest.TestCase):
-    
-    def setUp(self):
-        self.b = sudoku.Board(size=4)
-        self.b.set(0, 0, 3)
-        self.b.set(0, 1, 4)
-        self.b.set(0, 2, 1)
-        self.b.set(0, 3, 2)
-        self.b.set(1, 0, 2)
-        self.b.set(1, 1, 1)
-        self.b.set(1, 2, 4)
-        self.b.set(1, 3, 3)
-        self.b.set(2, 0, 4)
-        self.b.set(2, 1, 3)
-        self.b.set(2, 2, 2)
-        self.b.set(2, 3, 1)
-        self.b.set(3, 0, 1)
-        self.b.set(3, 1, 2)
-        self.b.set(3, 2, 3)
-        self.b.set(3, 3, 4)
-
-    
-    def testInitialDomains(self):
-        v = sudoku.BacktrackingSolver()
-        ds = v.initialDomains(self.b)
-        for i in range(4*4):
-            self.assertEqual(len(ds[i]), 1)
-
-    def testPartialInitialDomains(self):
-        self.b.set(0, 0, 0)
-        v = sudoku.BacktrackingSolver()
-        ds = v.initialDomains(self.b)
-        for i in range(4*4):
-            if i == 0:
-                self.assertEqual(len(ds[i]), 4)
-            else:
-                self.assertEqual(len(ds[i]), 1)
-    
-    def testBacktrackOneMissing(self):
-        self.b.set(3, 3, 0)
-        v = sudoku.BacktrackingSolver()
-        self.assertEqual(self.b.get(3, 3), 0)
-        self.assertTrue(v.solve(self.b))
-        self.assertEqual(self.b.get(3, 3), 4)
-    
-    def testBacktrackMultipleMissing(self):
-        self.b.set(0, 3, 0)
-        self.b.set(2, 1, 0)
-        v = sudoku.BacktrackingSolver()
-        self.assertEqual(self.b.get(0, 3), 0)
-        self.assertEqual(self.b.get(2, 1), 0)
-        self.assertTrue(v.solve(self.b))
-        self.assertEqual(self.b.get(0, 3), 2)
-        self.assertEqual(self.b.get(2, 1), 3)
-    
-    def testForwardCheckingOneMissing(self):
-        self.b.set(3, 3, 0)
-        v = sudoku.ForwardCheckingSolver()
-        self.assertEqual(self.b.get(3, 3), 0)
-        self.assertTrue(v.solve(self.b))
-        self.assertEqual(self.b.get(3, 3), 4)
-
-    def testForwardCheckingMultipleMissing(self):
-        self.b.set(0, 3, 0)
-        self.b.set(2, 1, 0)
-        v = sudoku.ForwardCheckingSolver()
-        self.assertEqual(self.b.get(0, 3), 0)
-        self.assertEqual(self.b.get(2, 1), 0)
-        self.assertTrue(v.solve(self.b))
-        self.assertEqual(self.b.get(0, 3), 2)
-        self.assertEqual(self.b.get(2, 1), 3)
-    
-    def testForwardCheckingManyMissing(self):
-        self.b.set(2, 1, 0)
-        self.b.set(2, 0, 0)
-        self.b.set(2, 3, 0)
-        v = sudoku.ForwardCheckingSolver()
-        self.assertEqual(self.b.get(2, 1), 0)
-        self.assertEqual(self.b.get(2, 0), 0)
-        self.assertEqual(self.b.get(2, 3), 0)
-        self.assertTrue(v.solve(self.b))
-        self.assertEqual(self.b.get(2, 1), 3)
-        self.assertEqual(self.b.get(2, 0), 4)
-        self.assertEqual(self.b.get(2, 3), 1)
-    
-
 if __name__ == "__main__":
     unittest.main()