Commits

Charlie Arnold committed 5ede2a3

second run; have same easy points done and enough structure to start with the next level

Comments (0)

Files changed (2)

 import logging
 from itertools import chain, count
 
+__all__ = ['SudokuPuzzle', 'RowCellGroup', 'ColCellGroup', 'SqrCellGroup', 'Cell']
 
+class ExhaustedError(Exception):
+    "Thrown when the solver gives up"
 
-class BohError(Exception):
-    "Couldn't solve the puzzle :("
-
-
+class ProcessingError(Exception):
+    "Raised when a logic error occurs"
 
 
 class SudokuPuzzle(object):
     
-    _logger = logging.getLogger('Puzzle Logger')
+    indices = range(9)
     
-    def __init__(self, cells):
-        assert len(cells) == 81, 'expected 81 cells (got %d)'% len(cells)
-        assert all(map(lambda c: type(c) is int, cells)), 'all cells need to be int vals'
+    def __init__(self):
+        self.rowGroups = [RowCellGroup(i) for i in self.indices]
+        self.colGroups = [ColCellGroup(i) for i in self.indices]
+        self.sqrGroups = [SqrCellGroup(i) for i in self.indices]
+        self.rows = []
         
-#        self.rows = [[set() if c == 0 else set(c) for c in row]
-#                     for row in cells]
-        self.cells = [set() if c == 0 else set(c)
-                      for c in cells]
+        for rg in self.rowGroups:
+            row = []
+            self.rows.append(row)
+            for cg in self.colGroups:
+                sg = self._getSquareGrp(rg.idx, cg.idx)
+                row.append(Cell(rg, cg, sg))
+        
+        for row in self.rows:
+            for cell in row:
+                cell._connectCells()
     
-    @classmethod
-    def fromRawFile(cls, filename):
-        'Factory method to construct from bytes file ..'
-        bl = file(filename, 'rb').read()
-        return cls([ord(b) for b in bl])
+    def _getSquareGrp(self, rowi, coli):
+        'create index 0 - 9 from rowi, coli'
+        GROUP_SIZE = 3
+        grpi = GROUP_SIZE*(rowi/GROUP_SIZE)+(coli/GROUP_SIZE)
+        return self.sqrGroups[grpi]
     
-    def _getRow(self, rowi):
-        assert rowi < 9
-        return self.cells[9*rowi:(9*rowi)+9]
-    
-    def _getCol(self, coli):
-        assert coli < 9
-        rows = [self._getRow(i) for i in range(9)]
-        return list(zip(*rows)[coli])
-    
-    def _getGroup(self, rowi, coli):
-        assert rowi < 9 and coli < 9
-        if rowi < 3:
-            rows = [self._getRow(ridx) for ridx in range(3)]
-        elif rowi < 6:
-            rows = [self._getRow(ridx) for ridx in range(3, 6)]
-        else:
-            rows = [self._getRow(ridx) for ridx in range(6, 9)]
-        colStart = 3 * (coli / 3)
-        return list(chain(*[r[colStart:colStart+3] for r in rows]))
-    
-    
+    def initFromBinFile(self, filename):
+        with file(filename, 'rb') as fobj:
+            vals = [ord(b) for b in fobj.read()]
+        assert len(vals) == 81, 'expected 81 rows (got %d)'% len(vals)
+        assert all(map(lambda c: type(c) is int, vals)), 'all rows need to be int vals'
+        
+        possibles = set(range(1,10))
+        for i, v in enumerate(vals):
+            if v == 0:
+                continue
+            rowi, coli = i/9, i%9
+            self.rows[rowi][coli].excludeVals(possibles.difference((v,)))
+
     def _rowStr(self, rowi):
-        row = self._getRow(rowi)
-        itos = lambda l: ' '.join([str(i) for i in l])
-        return '|%s | %s | %s |'% (itos(row[:3]), itos(row[3:6]), itos(row[6:]))
-    
-    possibles = range(1,10)
-    def _getRowPossibles(self, rowi):
-        r = self._getRow(rowi)
-#        assert r[cellidx] == 0, 'getting possibles for solved cell'
-        return [i for i in self.possibles if i not in r]
-    
-    def _getColPossibles(self, coli):
-        c = self._getCol(coli)
-        return [i for i in self.possibles if i not in c]
-    
-    def _getGroupPossibles(self, rowi, coli):
-        g = self._getGroup(rowi, coli)
-        return [i for i in self.possibles if i not in g]
+        row = [c._shortStr() for c in self.rows[rowi]]
+        join = lambda l: ' '.join([str(i) for i in l])
+        return '|%s | %s | %s |'% (join(row[:3]), join(row[3:6]), join(row[6:]))
     
     def __str__(self):
         border = '-' * 24
             if (rowi + 1) % 3 == 0:
                 rows.append(border)
         return '\n'.join(rows)
+
+
+class CellGroup(object):
     
-    def _unsolved(self):
-        return self.cells.count(0) > 0
+    def __init__(self, idx):
+        self.idx = idx
+        self.cells = []
     
-    def _easyEliminate(self, rowi, coli):
-        pl = set(self.possibles)
-        pl.difference_update(self._getRow(rowi))
-        pl.difference_update(self._getCol(coli))
-        group = self._getGroup(rowi, coli)
-        pl.difference_update(group)
-        return pl
+    def _addCell(self, cell):
+        self.cells.append(cell)
     
-    def _eliminatePairs(self, rowi, coli, pl):
-        pass
+    def findExclusives(self):
+        'Find cells that contain all of a certain set of values'
+        cells = list(self.cells)
+        for i, c in enumerate(cells):
+            cl = [test for test in cells[i:]
+                  if test.possibles == c.possibles]
+
+class RowCellGroup(CellGroup):
+    pass
+
+class ColCellGroup(CellGroup):
+    pass
+
+class SqrCellGroup(CellGroup):
+    pass
+
+
+class Cell(object):
     
-    def _setupPossibles(self):
-        pass
+    def __init__(self, rowGrp, colGrp, sqrGrp):
+        self.rowGrp = rowGrp
+        self.rowGrp._addCell(self)
+        self.colGrp = colGrp
+        self.colGrp._addCell(self)
+        self.sqrGrp = sqrGrp
+        self.sqrGrp._addCell(self)
+        
+        self.possibles = set(range(1,10))
+        self.val = None
+        self.connected = None
     
-    def bruteSolve2(self):
+    @property
+    def certainty(self):
+        return 1 if self.val else (9-len(self.possibles))/9
+    
+    def _checkKnown(self):
+        if len(self.possibles) == 1:
+            self.val = self.possibles.pop()
+            assert len(self.possibles) == 0
+            for connectedCell in list(self.connected):
+                # recursion may have already disconnected:
+                if connectedCell in self.connected:
+                    connectedCell.disconnect(self)
+    
+    def excludeVals(self, vals):
+        self.possibles.difference_update(vals)
+        if len(self.possibles) == 0:
+            raise ProcessingError('Removing all possible values')
+        self._checkKnown()
+    
+    def excludeVal(self, val):
+        if self.val:
+            if val == self.val:
+                raise ProcessingError('Excluding only possible')
+            return
         
-        self._logger.info(' Solving:\n%s'% self)
-        self._setupPossibles()
+        if val in self.possibles:
+            self.possibles.remove(val)
+        if len(self.possibles) == 0:
+            raise ProcessingError('Removing all possible values')
+        self._checkKnown()
     
-    def bruteSolve(self):
-        
-        self._logger.info(' Solving:\n%s'% self)
-        
-        for i in count():
-            if not self._unsolved():
-                self._logger.info(' Solved :) .. \n%s'% self)
-                return
-            
-            gotOne = False
-            self._logger.debug('Solve Pass %d'% i)
-            for rowi in range(9):
-                for coli in range(9):
-                    if self.cells[rowi*9+coli] != 0:
-                        continue
-                    pl = self._easyEliminate(rowi, coli)
-                    pl = self._eliminatePairs(rowi, coli, pl)
-                    assert len(pl) > 0, 'not 0 but no possibles??'
-                    if len(pl) == 1:
-                        gotOne = True
-                        self.cells[rowi*9+coli] = pl.pop()
-                    else:
-                        self._logger.debug(
-                            'Possibles for (%d, %d): %s'% (rowi, coli, pl)
-                        )
-            
-            if not gotOne:
-                self._logger.info(" Failed :( \n%s"% self)
-                raise BohError("Naive solve didn't work")
+    def disconnect(self, cell):
+        "called by a connected cell when it's value becomes known"
+        self.connected.remove(cell)
+        self.excludeVal(cell.val)
+    
+    def _connectCells(self):
+        "called after construction to create cell connections"
+        assert self.connected == None, 'already connected'
+        self.connected = set()
+        for grp in self.rowGrp, self.colGrp, self.sqrGrp:
+            for cell in grp.cells:
+                if cell is self:
+                    continue
+                self.connected.add(cell)
+    
+    def _shortStr(self):
+        if self.val:
+            return '%d'% self.val
+        if len(self.possibles) <= 3:
+            return str(tuple(self.possibles))
+        return '%s, ...'% ', '.join(map(str, list(self.possibles)[:3]))
+    
+    def __str__(self):
+        return 'Cell @ (%d, %d) (%s)'% (
+            self.rowGrp.idx, self.colGrp.idx, self._shortStr()
+        )
+
+
+
 
 if __name__ == '__main__':
-#    logging.basicConfig(level=logging.INFO)
-#    pzl = SudokuPuzzle.fromRawFile(r'C:\sudoku\11-1.in')
-#    pzl.bruteSolve()
     
     testd = r'\\ZD\Users\charlie\Documents\eclipse_workspace\sudoku'
     
-    for fn in os.listdir(testd):
-        if fn.endswith('.in'):
-            pth = p.join(testd, fn)
-            outfobj = file(p.splitext(pth)[0] + '.out', 'wb')
-            pzl = SudokuPuzzle.fromRawFile(pth)
-            try:
-                pzl.bruteSolve()
-            except BohError, e:
-                for c in pzl.cells:
-                    outfobj.write(chr(c))
-            except Exception, e:
-                logging.exception(e)
-                outfobj.write(chr(32))
-                outfobj.write(str(e))
-            else:
-                for c in pzl.cells:
-                    outfobj.write(chr(c))
+    logging.basicConfig(level=logging.INFO)
+    pzl = SudokuPuzzle()
+    pzl.initFromBinFile(p.join(testd, '11-1.in'))
+    logging.info(' Result: \n%s'% pzl)
+    
+    
+#    for fn in os.listdir(testd):
+#        if fn.endswith('.in'):
+#            pth = p.join(testd, fn)
+#            outfobj = file(p.splitext(pth)[0] + '.out', 'wb')
+#            pzl = SudokuPuzzle()
+#            pzl.initFromBinFile(pth)
+#            logging.info(' Result: \n%s'% pzl)
+            
+#            try:
+#                pzl.bruteSolve()
+#            except Exhausted, e:
+#                for c in pzl.rows:
+#                    outfobj.write(chr(c))
+#            except Exception, e:
+#                logging.exception(e)
+#                outfobj.write(chr(32))
+#                outfobj.write(str(e))
+#            else:
+#                for c in pzl.rows:
+#                    outfobj.write(chr(c))
             
 
 
 import logging
 
 import unittest
-from entry import SudokuPuzzle
+from entry import *
 
 
 
 class Test(unittest.TestCase):
 
-    testd = r'C:\sudoku'
+    testd = r'\\ZD\Users\charlie\Documents\eclipse_workspace\sudoku'
 
-    def testName(self):
-        pass
-
-    def testRead(self):
-        for fn in os.listdir(self.testd):
-            if fn.endswith('.in'):
-                pth = p.join(self.testd, fn)
-                print '%s: \n%s'% (fn, SudokuPuzzle.fromRawFile(pth))
+#    def testRead(self):
+#        for fn in os.listdir(self.testd):
+#            if fn.endswith('.in'):
+#                pth = p.join(self.testd, fn)
+#                print '%s: \n%s'% (fn, SudokuPuzzle().initFromBinFile(pth))
     
-    def testGetRow(self):
-        fn = '11-1.in'
-        pzl = SudokuPuzzle.fromRawFile(p.join(self.testd, fn))
-        expected = [4, 6, 2, 0, 0, 5, 0, 0, 0]
-        found = pzl._getRow(0)
-        self.assertTrue(found == expected)
-        expected = [7, 0, 8, 0, 0, 9, 0, 0, 0]
-        found = pzl._getRow(2)
-        self.assertTrue(found == expected)
+    def setup2cells(self):
+        rgrp, cgrp, sgrp = RowCellGroup(0), ColCellGroup(0), SqrCellGroup(0)
+        c1 = Cell(rgrp, cgrp, sgrp)
+        c2 = Cell(rgrp, cgrp, sgrp)
+        c1._connectCells()
+        c2._connectCells()
+        self.assert_(c1.connected == set([c2]))
+        self.assert_(c2.connected == set([c1]))
+        return c1, c2
     
-    def testGetCol(self):
-        fn = '11-1.in'
-        pzl = SudokuPuzzle.fromRawFile(p.join(self.testd, fn))
-        expected = [0, 0, 0, 2, 7, 0, 1, 0, 9]
-        found = pzl._getCol(8)
-        self.assertTrue(found == expected)
-
-    def testGetGroup(self):
-        fn = '11-1.in'
-        pzl = SudokuPuzzle.fromRawFile(p.join(self.testd, fn))
-        expected = [0, 5, 0, 0, 0, 0, 0, 2, 0]
-        found = pzl._getGroup(5, 5)
-        self.assertTrue(found == expected)
+    def testCellUpdate(self):
+        c1, c2 = self.setup2cells()
+        c1.excludeVals(range(1,9))
+        self.assert_(c1.val == 9)
+        self.assert_(9 not in c2.possibles)
     
-    def testBruteSolve(self):
-        bruteSolve
+    def testCircularUpdate(self):
+        c1, c2 = self.setup2cells()
+        c1.excludeVals(range(1,8))
+        c2.excludeVals(range(1,8))
+        self.assert_(c1.possibles == set([8, 9]))
+        self.assert_(c2.possibles == c1.possibles)
+        c1.excludeVal(8)
+        self.assert_(c1.val == 9)
+        self.assert_(c2.val == 8)
+    
+#    def testCellUpdate(self):
+#        fn = '11-1.in'
+#        pzl = SudokuPuzzle()
+#        pzl.initFromBinFile(p.join(self.testd, fn))
+        
 
 if __name__ == "__main__":
     #import sys;sys.argv = ['', 'Test.testName']