Commits

Charlie Arnold committed 470bb79

first shot at general solve algo .. Doesn't work.

Comments (0)

Files changed (1)

 import os
 import os.path as p
 import logging
+import copy
 from itertools import chain, count
 
 __all__ = ['SudokuPuzzle', 'RowCellGroup', 'ColCellGroup', 'SqrCellGroup', 'Cell']
         grpi = GROUP_SIZE*(rowi/GROUP_SIZE)+(coli/GROUP_SIZE)
         return self.sqrGroups[grpi]
     
+    def _rowStr(self, colWidths, row):
+        row = [c.center(colWidths[i]) for i, c in enumerate(row)]
+        return '| %s | %s | %s |'% (
+            ' '.join(row[:3]), ' '.join(row[3:6]), ' '.join(row[6:])
+        )
+    
+    def __str__(self):
+        rows = [[c._shortStr() for c in row]
+                for row in self.rows]
+        colWidths = [max([len(c) for c in col])
+                     for col in zip(*rows)]
+        
+        border = '_'*(sum(colWidths)+16)
+        rowstrs = [border]
+        for rowi, row in enumerate(rows):
+            rowstrs.append(self._rowStr(colWidths, row))
+            if (rowi + 1) % 3 == 0:
+                rowstrs.append(border)
+        
+        return '\n'.join(rowstrs)
+    
+    @property
+    def solved(self):
+        for cell in self.cells:
+            if not cell.solved:
+                return False
+        return True
+    
+    @property
+    def cells(self):
+        for row in self.rows:
+            for cell in row:
+                yield cell
+    
     def initFromBinFile(self, filename):
         with file(filename, 'rb') as fobj:
             vals = [ord(b) for b in fobj.read()]
         assert len(vals) == 81, 'expected 81 rows (got %d)'% len(vals)
         assert all(map(lambda c: type(c) is int, vals)), 'all rows need to be int vals'
         
-        possibles = set(range(1,10))
         for i, v in enumerate(vals):
             if v == 0:
                 continue
             rowi, coli = i/9, i%9
-            self.rows[rowi][coli].excludeVals(possibles.difference((v,)))
+            self.rows[rowi][coli].setValue(v)
     
     def outputBinFile(self, filename, error=None):
         with open(filename, 'wb') as fobj:
             cells = chain(*(rowgrp.cells for rowgrp in self.rowGroups))
             fobj.write(''.join([chrval(cell) for cell in cells]))
     
-    def _rowStr(self, rowi):
-        row = [c._shortStr() for c in self.rows[rowi]]
-        join = lambda l: ' '.join([str(i) for i in l])
-        return '|%s | %s | %s |'% (join(row[:3]), join(row[3:6]), join(row[6:]))
-    
-    def __str__(self):
-        border = '-' * 24
-        rows = [border]
-        for rowi in range(9):
-            rows.append(self._rowStr(rowi))
-            if (rowi + 1) % 3 == 0:
-                rows.append(border)
-        return '\n'.join(rows)
-    
     def solveCertain(self):
         ''' Solve as much of the puzzle as we can with certainty
         '''
         
         while True:
-            foundClosedSet = False
+            modifiedPuzzle = False
             for grpList in (self.rowGroups, self.colGroups, self.sqrGroups):
                 for grp in grpList:
                     for cellset in grp.findClosedSets():
-                        foundClosedSet = True
                         closedSetVals = iter(cellset).next().possibles
                         uncertainCells = set(c for c in grp.cells if c.certainty < 1)
                         for cell in uncertainCells.difference(cellset):
-                            if cell.certainty < 1: # could be changed from previous exclusions
-                                cell.excludeVals(closedSetVals)
-            if not foundClosedSet:
+                            if not cell.solved: # could be changed from previous exclusions
+                                if len(closedSetVals - cell.possibles) < len(closedSetVals):
+                                    modifiedPuzzle = True
+                                    cell.excludeVals(closedSetVals)
+            if not modifiedPuzzle:
                 break
     
+    def clone(self):
+        return copy.deepcopy(self) # Not sure if this works ...
+#        clone = SudokuPuzzle()
+#        for cell in self.cells:
+#            if cell.solved:
+#                clonecell = clone.rows[cell.rowGrp.idx][cell.colGrp.idx]
+#                clonecell.setValue(cell.val)
+#        return clone
+    
     def solve(self):
-        raise NotImplementedError
-
+        '''
+        Solve the puzzle, even if it involves guessing ...
+        '''
+        
+        def recTestAllPossibles(pzl, cell, unsolvedCellIter):
+            ''' recursively try all possible values for all unsolved cells
+            '''
+            
+            if cell.solved:
+                return pzl
+            
+            for val in cell.possibles:
+                clone = pzl.clone()
+                clonecell = clone.rows[cell.rowGrp.idx][cell.colGrp.idx]
+                try:
+                    clonecell.setValue(val)
+                    clone.solveCertain()
+                except ProcessingError:
+                    # incorrect value .. try a different one
+                    continue
+                else:
+                    if clone.solved:
+                        break
+                    # FIXME: must be in try/catch, just like above ..
+                    nextcell = unsolvedCellIter.next() # Can't be empty if not solved ..
+                    clone = recTestAllPossibles(clone, nextcell, unsolvedCellIter)
+            else:
+                raise ProcessingError("no val in possibles worked")
+            
+            return clone # Solved Test ..
+        
+        self.solveCertain()
+        if self.solved:
+            return
+        
+        # Create an iterator of unsolved cells, putting the ones with the
+        #   highest probability of being solved correctly and their being
+        #   guessed correctly first ...
+        groupCertainty = {}
+        def cellGroupCertainty(cell):
+            if not groupCertainty.has_key(cell):
+                groups = (cell.rowGrp, cell.colGrp, cell.sqrGrp)
+                groupCertainty[cell] = sum([grp.certainty for grp in groups])
+            return groupCertainty[cell]
+        
+        sortkey = lambda cell: (cell.certainty, cellGroupCertainty(cell))
+        unsolvedCells = [cell for cell in self.cells if not cell.solved]
+        sortedUnsolvedCells = sorted(unsolvedCells, key=sortkey, reverse=True)
+        unsolvedCellIter = iter(sortedUnsolvedCells)
+        cell = unsolvedCellIter.next()
+        solved = recTestAllPossibles(self, cell, unsolvedCellIter)
+        
+        # Copy solved vals back into self ..
+        for unsolved, solved in zip(self.cells, solved.cells):
+            unsolved.setValue(solved)
 
 
 class CellGroup(object):
             cells.difference_update(cs)
         
         return closedsets
+    
+    def _cellsByUncertainty(self):
+        ''' Get unsolved cells sorted by least uncertainty
+        '''
+        unsolvedCells = filter(lambda c: c.certainty < 1, self.cells)
+        return sorted(unsolvedCells, key=lambda c: c.certainty, reverse=True)
+    
+    @property
+    def certainty(self):
+        return sum(c.certainty for c in self.cells)/len(self.cells)
 
 
 class RowCellGroup(CellGroup):
     
     @property
     def certainty(self):
-        return 1 if self.val else (9-len(self.possibles))/9
+        return 1 if self.val else (9.0-len(self.possibles))/9.0
+    
+    @property
+    def solved(self):
+        return self.val is not None
     
     def _checkKnown(self):
         if len(self.possibles) == 1:
                 if connectedCell in self.connected:
                     connectedCell.disconnect(self)
     
+    def setValue(self, val):
+        if self.val is not None:
+            if self.val != val:
+                raise ProcessingError('Setting val in already set cell')
+            return
+        self.excludeVals(self.possibles.difference((val,)))
+    
     def excludeVals(self, vals):
         self.possibles.difference_update(vals)
         if len(self.possibles) == 0:
             return '%d'% self.val
         if len(self.possibles) <= 3:
             return str(tuple(self.possibles))
-        return '%s, ...'% ', '.join(map(str, list(self.possibles)[:3]))
+        return '(%s, ...)'% ', '.join(map(str, list(self.possibles)[:3]))
     
     def __str__(self):
         return 'Cell @ (%d, %d) (%s)'% (
         )
 
 
-def main():
-    import sys
-    
-    if len(sys.argv) < 3:
-        print 'usage: %s [input_file] [output_file]'
-    
-    logging.basicConfig(level=logging.INFO)
-    inpth, outpth = sys.argv[1:3]
-    
+def main(inpth, outpth):
     pzl = SudokuPuzzle()
     pzl.initFromBinFile(inpth)
     logging.info(' Result: \n%s'% pzl)
-    pzl.solveCertain()
+    pzl.solve()
     logging.info(' Result: \n%s'% pzl)
+    raw_input()
     pzl.outputBinFile(outpth)
 
 if __name__ == '__main__':
-    try:
-        main()
-    except Exception, e:
-        logging.exception(e)
-        raw_input()
+    
+    logging.basicConfig(level=logging.INFO)
+    
+    testd = r'D:\Users\charlie\Documents\eclipse_workspace\sudoku'
+    infn, outfn = p.join(testd, 'snail1.in'), p.join(testd, 'snail1.out')
+    main(infn, outfn)
+    
+#    import sys
+#    
+#    if len(sys.argv) < 3:
+#        print 'usage: %s [input_file] [output_file]'
+#    
+#    inpth, outpth = sys.argv[1:3]
+#    
+#    try:
+#        main(inpth, outpth)
+#    except Exception, e:
+#        logging.exception(e)
+#        raw_input()
 
-#    testd = r'\\ZD\Users\charlie\Documents\eclipse_workspace\sudoku'
-#    
-#    logging.basicConfig(level=logging.INFO)
-#    pzl = SudokuPuzzle()
-#    pzl.initFromBinFile(p.join(testd, 'snail2.in'))
-#    logging.info(' Result: \n%s'% pzl)
-#    pzl.solveCertain()
-#    logging.info(' Result: \n%s'% pzl)
-    
-    
-#    for fn in os.listdir(testd):
-#        if fn.endswith('.in'):
-#            pth = p.join(testd, fn)
-#            outfobj = file(p.splitext(pth)[0] + '.out', 'wb')
-#            pzl = SudokuPuzzle()
-#            pzl.initFromBinFile(pth)
-#            logging.info(' Result: \n%s'% pzl)
-            
-#            try:
-#                pzl.bruteSolve()
-#            except Exhausted, e:
-#                for c in pzl.rows:
-#                    outfobj.write(chr(c))
-#            except Exception, e:
-#                logging.exception(e)
-#                outfobj.write(chr(32))
-#                outfobj.write(str(e))
-#            else:
-#                for c in pzl.rows:
-#                    outfobj.write(chr(c))