Commits

Geoff Hill committed 00b1f77

more pair programming work

Comments (0)

Files changed (2)

src/sudoku/sudoku.py

 #!/usr/bin/env python
 
-import math, array
+import math, array, copy
 
 def unique(seq):
     """Determines whether or not every element in a sequence is unique."""
     return len(seq) == len(set(seq))
 
-def selectFromSet(s):
-    """Returns an arbitrary element from a set."""
-    for i in s:
-        return i
 
 
 class Board(object):
         cs = w*(box%w)
         return [(i, j) for i in range(rs, rs+w) for j in range(cs, cs+w)]
     
-    # solution algorithms
-
-    def _initialDomains(self):
+    def __str__(self):
         s = self._size
-        domains = [set(range(1, s+1)) for i in range(s*s)]
-        for row in range(s):
-            for col in range(s):
-                val = self.get(row, col)
-                if val != 0:
-                    domains[row*s + col] = set([val,])
-        return domains
+        t = '+--'*s + '+\n'
+        for i in range(s):
+            for j in range(s):
+                t += '|%2d' % (self.get(i, j),)
+            t += '|\n' + '+--'*s + '+\n'
+        return t
     
-    def backtrackSolve(self):
-        s = self._size
+    def __repr__(self):
+        return str(self)
+
+
+class SolutionNotFound(Exception):
+    def __str__(self):
+        return "Solution not found."
+
+class MaxSolutionDepthExceeded(Exception):
+    def __init__(self, depth):
+        self.depth = depth
+    def __str__(self):
+        return "Max solution depth exceeded %d." % (self.depth,)
+
+
+class BacktrackingSolver(object):
+
+    MAX_CHECKS = 1e7
+
+    def solve(self, b):
+        s = b._size
         m = s*s - 1
-        domain = self._initialDomains()
-        tried = [set() for i in range(s*s)]
-        backup = array.array('H', self._board)
+        domains = [self.initialDomains(b),]
+        backup = array.array('H', b._board)
         checks = 0
         
         pos = 0
         while True:
-            if pos == 0 and domain[0] == tried[0]:
-                self._board = backup
-                return False
-            elif domain[pos] == tried[pos]:
-                tried[pos] = set()
-                self._board[pos] = 0
+            if pos == 0 and not domains[-1][0]:
+                b._board = backup
+                raise SolutionNotFound
+            if not domains[-1][pos]:
+                domains.pop()
                 pos -= 1
                 continue
-            val = selectFromSet(domain[pos] - tried[pos])
-            tried[pos].add(val)
-            self._board[pos] = val
-            checks += 1
-            if self.complete():
+            val = self.valueSelect(domains[-1][pos])
+            b._board[pos] = val
+            if b.complete():
                 return checks
-            elif pos == m:
+            domains[-1][pos].remove(val)
+            checks += 1
+            if checks > self.MAX_CHECKS:
+                raise MaxSolutionDepthExceeded(self.MAX_CHECKS)
+            if pos == m:
                 continue
             else:
+                new_domain = self.inference(b, domains[-1], pos)
+                domains.append(new_domain)
                 pos += 1
-        
     
-    def __str__(self):
-        s = self._size
-        t = '+--'*s + '+\n'
-        for i in range(s):
-            for j in range(s):
-                t += '|%2d' % (self.get(i, j),)
-            t += '|\n' + '+--'*s + '+\n'
-        return t
+    def initialDomains(self, b):
+        s = b._size
+        domains = [set(range(1, s+1)) for i in range(s*s)]
+        for row in range(s):
+            for col in range(s):
+                val = b.get(row, col)
+                if val != 0:
+                    domains[row*s + col] = set([val,])
+        return domains
+
+    def valueSelect(self, s):
+        for i in s:
+            return i
+
+    def inference(self, b, domain, pos):
+        return copy.deepcopy(domain)
+
+
+class ForwardCheckingSolver(BacktrackingSolver):
     
-    def __repr__(self):
-        return str(self)
+    def inference(self, b, domain, pos):
+        new_domain = copy.deepcopy(domain)
+        s = b._size
+        w = b._width
+        val = b._board[pos]
+
+        row = pos // s
+        col = pos % s
+        for i in range(row+1, s):
+            idx = row*s + i
+            new_domain[idx].discard(val)
+        for i in range(col+1, s):
+            idx = i*s + col
+            new_domain[idx].discard(val)
 
+        box = w*(row//w) + col//w
+        rs = w*(box//w)
+        cs = w*(box%w)
+        for i in range(rs, rs+w):
+            for j in range(cs, cs+w):
+                if i > row and j > col:
+                    idx = i*s + j
+                    new_domain[idx].discard(val)
+        return new_domain
+    
 
 
 if __name__ == "__main__":

src/sudoku/sudoku_test.py

         self.b.set(3, 2, 3)
         self.b.set(3, 3, 4)
 
+    
     def testInitialDomains(self):
-        ds = self.b._initialDomains()
+        v = sudoku.BacktrackingSolver()
+        ds = v.initialDomains(self.b)
         for i in range(4*4):
             self.assertEqual(len(ds[i]), 1)
 
     def testPartialInitialDomains(self):
         self.b.set(0, 0, 0)
-        ds = self.b._initialDomains()
+        v = sudoku.BacktrackingSolver()
+        ds = v.initialDomains(self.b)
         for i in range(4*4):
             if i == 0:
                 self.assertEqual(len(ds[i]), 4)
             else:
                 self.assertEqual(len(ds[i]), 1)
-
+    
     def testBacktrackOneMissing(self):
         self.b.set(3, 3, 0)
+        v = sudoku.BacktrackingSolver()
         self.assertEqual(self.b.get(3, 3), 0)
-        self.assertTrue(self.b.backtrackSolve())
+        self.assertTrue(v.solve(self.b))
         self.assertEqual(self.b.get(3, 3), 4)
-
+    
     def testBacktrackMultipleMissing(self):
         self.b.set(0, 3, 0)
         self.b.set(2, 1, 0)
+        v = sudoku.BacktrackingSolver()
         self.assertEqual(self.b.get(0, 3), 0)
         self.assertEqual(self.b.get(2, 1), 0)
-        self.assertTrue(self.b.backtrackSolve())
+        self.assertTrue(v.solve(self.b))
         self.assertEqual(self.b.get(0, 3), 2)
         self.assertEqual(self.b.get(2, 1), 3)
-        
+    
+    def testForwardCheckingOneMissing(self):
+        self.b.set(3, 3, 0)
+        v = sudoku.ForwardCheckingSolver()
+        self.assertEqual(self.b.get(3, 3), 0)
+        self.assertTrue(v.solve(self.b))
+        self.assertEqual(self.b.get(3, 3), 4)
+
+    def testForwardCheckingMultipleMissing(self):
+        self.b.set(0, 3, 0)
+        self.b.set(2, 1, 0)
+        v = sudoku.ForwardCheckingSolver()
+        self.assertEqual(self.b.get(0, 3), 0)
+        self.assertEqual(self.b.get(2, 1), 0)
+        self.assertTrue(v.solve(self.b))
+        self.assertEqual(self.b.get(0, 3), 2)
+        self.assertEqual(self.b.get(2, 1), 3)
+    
+    def testForwardCheckingManyMissing(self):
+        self.b.set(2, 1, 0)
+        self.b.set(2, 0, 0)
+        self.b.set(2, 3, 0)
+        v = sudoku.ForwardCheckingSolver()
+        self.assertEqual(self.b.get(2, 1), 0)
+        self.assertEqual(self.b.get(2, 0), 0)
+        self.assertEqual(self.b.get(2, 3), 0)
+        self.assertTrue(v.solve(self.b))
+        self.assertEqual(self.b.get(2, 1), 3)
+        self.assertEqual(self.b.get(2, 0), 4)
+        self.assertEqual(self.b.get(2, 3), 1)
+    
 
 if __name__ == "__main__":
     unittest.main()