Commits

Pierre Carbonnelle committed 672c204

partial improvement of rank, running_sum

  • Participants
  • Parent commits cd0e77b

Comments (0)

Files changed (3)

File pyDatalog/examples/test.py

 
     @pyDatalog.program()
     def rank(): 
+        + score('Superman', 10)
+        + score('Tom', 3)
+        + score('Jerry', 8)
+        
+        (place[Person]==rank(order_by=Score)) <= score(Person, Score)
+        assert ask(place[Person]==Rank) == set([('Superman', 2),('Tom', 0),('Jerry', 1)])
+        #print ask(place['Jerry']==Rank) # TODO
+        assert ask(place[Person]==1) == set([('Jerry',)])
+
         (a_rank1[Z] == rank(for_each=Z, order_by=Z)) <= q(X, Y, Z)
         assert ask(a_rank1[X]==Y) == set([(1, 0), (2, 0), (4, 0)])
         assert ask(a_rank1[X]==0) == set([(1,), (2,), (4,)])
         assert ask(a_rank1[1]==0) == set([()])
         assert ask(a_rank1[1]==1) == None
 
+        (a_rank2[X,Y] == rank(order_by=Z)) <= q(X, Y, Z)
+        assert ask(a_rank2[X,Y]==Z) == set([('a','c', 0), ('a','b', 1), ('b', 'b', 2)])
+        assert ask(a_rank2[X,Y]==1) #TODO== set([(a,b)])
+        assert ask(a_rank2[a,Y]==Z) == set([('c',0),('b',1)])
+        assert ask(a_rank2[a,Y]==1) == set([('b',)])
+        assert ask(a_rank2[a,Y]==0) == set([('c',)])
+
         # rank
-        (a_rank[X,Y] == rank(for_each=(X,Y2), order_by=Z2)) <= q(X, Y, Z) & q(X,Y2,Z2)
-        assert ask(a_rank[X,Y]==Z) == set([('a', 'b', 1), ('a', 'c', 0), ('b', 'b', 0)])
-        assert ask(a_rank[a,b]==1) == set([()])
-        assert ask(a_rank[a,b]==Y) == set([(1,)])
-        assert ask(a_rank[a,X]==0) == set([('c',)])
-        assert ask(a_rank[a,X]==Y) == set([('b', 1), ('c', 0)])
-        assert ask(a_rank[X,Y]==1) == set([('a', 'b')])
-        assert ask(a_rank[a,y]==Y) == None
+        (a_rank[X,Y] == rank(for_each=(X,Y), order_by=Z)) <= q(X, Y, Z) & q(X,Y2,Z2)
+        assert ask(a_rank[X,Y]==Z) == set([('a', 'b', 0), ('a', 'c', 0), ('b', 'b', 0)])
+        assert ask(a_rank[a,b]==1) == None
+        assert ask(a_rank[a,b]==Y) == set([(0,)])
+        assert ask(a_rank[a,X]==0) == set([('b',), ('c',)])
+        assert ask(a_rank[a,X]==Y) == set([('b', 0), ('c', 0)])
+        assert ask(a_rank[X,Y]==1) == None
+
         # reversed
         (b_rank[X,Y] == rank(for_each=(X,Y2), order_by=-Z2)) <= q(X, Y, Z) & q(X,Y2,Z2)
         assert ask(b_rank[X,Y]==Z) == set([('a', 'b', 0), ('a', 'c', 1), ('b', 'b', 0)])
 
     @pyDatalog.program()
     def running_sum(): 
+        +movement('Account1', 'date1', 10)
+        +movement('Account1', 'date2', -6)
+        +movement('Account1', 'date2', -2) #TODO ?
+        +movement('Account1', 'date3', -2)
+        +movement('Account2', 'date1', 10)
+        +movement('Account2', 'date2', -5)
+        
+        (balance[Account, Date] == running_sum(Amount, for_each=Account, order_by=Date)) <= movement(Account, Date, Amount)
+        
+        assert ask(balance[Account, Date]==Amount) == set([('Account1', 'date1', 10),('Account1', 'date2', 2),('Account1', 'date3', 0),('Account2', 'date1', 10),('Account2', 'date2', 5)])
+        assert ask(balance['Account1', Date]==Amount) == set([('date1', 10), ('date2', 2), ('date3', 0)])
+        #TODO assert ask(balance[Account, 'date2']==Amount) #TODO !
+        assert ask(balance[Account, Date]==0) == set([('Account1', 'date3')])
+
+        (a_run_sum1[Z] == running_sum(Z, for_each=Z, order_by=Z)) <= q(X, Y, Z)
+        assert ask(a_run_sum1[X]==Y) == set([(1, 1), (2, 2), (4, 4)])
+        assert ask(a_run_sum1[X]==1) == set([(1,)])
+        assert ask(a_run_sum1[1]==X) == set([(1,)])
+        assert ask(a_run_sum1[1]==1) == set([()])
+        assert ask(a_run_sum1[1]==0) == None
+
         # running_sum
-        (a_run_sum[X,Y] == running_sum(Z2, for_each=(X,Y2), order_by=Z2)) <= q(X, Y, Z) & q(X,Y2,Z2)
-        assert ask(a_run_sum[X,Y]==Z) == set([('a', 'b', 3), ('a', 'c', 1), ('b', 'b', 4)])
-        assert ask(a_run_sum[a,b]==3) == set([()])
-        assert ask(a_run_sum[a,b]==Y) == set([(3,)])
-        assert ask(a_run_sum[a,X]==1) == set([('c',)])
-        assert ask(a_run_sum[a,X]==Y) == set([('b', 3), ('c', 1)])
-        assert ask(a_run_sum[X,Y]==4) == set([('b', 'b')])
+        (a_run_sum[X,Y] == running_sum(Z, for_each=(Y), order_by=Z2)) <= q(X, Y, Z) & q(X,Y,Z2)
+        assert ask(a_run_sum[X,Y]==Z) == set([('a', 'b', 2), ('a', 'c', 1), ('b', 'b', 6)])
+        #assert ask(a_run_sum[b,b]==6) == set([()])
+        #assert ask(a_run_sum[b,b]==Y) == set([(6,)])
+        assert ask(a_run_sum[X,b]==Z) == set([('a',2),('b',6)])
+        #assert ask(a_run_sum[a,X]==Y) == set([('b', 3), ('c', 1)])
+        assert ask(a_run_sum[X,Y]==6) == set([('b', 'b')])
         assert ask(a_run_sum[a,y]==Y) == None
 
-        (b_run_sum[X,Y] == running_sum(Z2, for_each=(X,Y2), order_by=-Z2)) <= q(X, Y, Z) & q(X,Y2,Z2)
-        assert ask(b_run_sum[X,Y]==Z) == set([('a', 'b', 2), ('a', 'c', 3), ('b', 'b', 4)])
+        (b_run_sum[X,Y] == running_sum(Z, for_each=(X,Y2), order_by=-Z)) <= q(X, Y, Z) & q(X,Y2,Z)
+        assert ask(b_run_sum[X,Y]==Z) == set([('a', 'b', 2), ('a', 'c', 1), ('b', 'b', 4)])
         assert ask(b_run_sum[a,b]==2) == set([()])
         assert ask(b_run_sum[a,b]==Y) == set([(2,)])
-        assert ask(b_run_sum[a,X]==3) == set([('c',)])
-        assert ask(b_run_sum[a,X]==Y) == set([('b', 2), ('c', 3)])
+        assert ask(b_run_sum[a,X]==1) == set([('c',)])
+        assert ask(b_run_sum[a,X]==Y) == set([('b', 2), ('c', 1)])
         assert ask(b_run_sum[X,Y]==4) == set([('b', 'b')])
         assert ask(b_run_sum[a,y]==Y) == None
 

File pyDatalog/pyEngine.py

         for k, v in groupby(result, aggregate.key):
             aggregate.reset()
             for r in v:
-                if aggregate.add(r):
-                    break
-            k = aggregate.fact(k)
-            fact_candidate(subgoal, class0, k)
+                row = aggregate.add(r)
+                if row is not None:
+                    fact_candidate(subgoal, class0, row)
+            row = aggregate.fact(k)
+            if row is not None:
+                fact_candidate(subgoal, class0, row)
 
 def search(subgoal):
     """ 

File pyDatalog/pyParser.py

     def arity(self):
         """returns the arity of the aggregate function, not of the full predicate """
         return len(self.args)
+
+    def get_slices(self, result):
+        """ significant indexes in the result rows"""
+        # this cannot be determined at __init__ because the predicate keys are not known
+        # it also varies depending on the number of constants in the function keys
+        self.order_by_start = len(result[0]) - len(self.order_by) - self.sep_arity
+        self.for_each_start = self.order_by_start - len(self.for_each)
+        self.to_add = self.for_each_start-1
         
+        self.slice_for_each = slice(self.for_each_start, self.order_by_start)
+        self.reversed_order_by = range(len(result[0])-1-self.sep_arity, self.order_by_start-1,  -1)
+        self.slice_group_by = slice(1, self.for_each_start-self.Y_arity)  #prefixed
+                 
     def sort_result(self, result):
         """ sort result according to the aggregate argument """
-        # significant indexes in the result rows
-        order_by_start = len(result[0]) - len(self.order_by) - self.sep_arity
-        for_each_start = order_by_start - len(self.for_each)
-        self.to_add = for_each_start-1
-        
-        self.slice_for_each = slice(for_each_start, order_by_start)
-        self.reversed_order_by = range(len(result[0])-1-self.sep_arity, order_by_start-1,  -1)
-        self.slice_group_by = slice(1, for_each_start-self.Y_arity)  #prefixed
+        self.get_slices(result)
         # first sort per order_by, allowing for _pyD_negated
         for i in self.reversed_order_by:
             result.sort(key=lambda literal, i=i: literal[i].id,
-                reverse = self.order_by[i-order_by_start]._pyD_negated)
+                reverse = self.order_by[i-self.order_by_start]._pyD_negated)
         # then sort per group_by
         result.sort(key=lambda literal, self=self: [id(term) for term in literal[self.slice_group_by]])
         pass
 
 class Rank_aggregate(Aggregate):
     """ represents rank_(for_each=Z, order_by=T)"""
-    required_kw = ('for_each', 'order_by')
+    required_kw = ('order_by',)
     
-    def reset(self):
-        self.count = 0
-        self._value = None
+    def key(self, result):
+        """ return the grouping key of a result """
+        return list(result[self.slice_for_each])
+    
+    def sort_result(self, result):
+        """ sort result according to the aggregate argument """
+        self.get_slices(result)
+        # first sort per order_by, allowing for _pyD_negated
+        for i in self.reversed_order_by:
+            result.sort(key=lambda literal, i=i: literal[i].id,
+                reverse = self.order_by[i-self.order_by_start]._pyD_negated)
+        # then sort per for_each
+        result.sort(key=lambda literal, self=self: [id(term) for term in literal[self.slice_for_each]])
 
     def add(self, row):
-        # retain the value if (X,) == (Z,)
-        if row[self.slice_group_by] == row[self.slice_for_each]:
-            self._value = [row[0],] + list(row[self.slice_group_by]) + [pyEngine.Const(self.count),] #prefixed
-            return self._value
-        self.count += 1
+        self._value += 1
+        return list(row[:len(row)-self.arity]) + [pyEngine.Const(self._value-1)] #TODO
         
     def fact(self, k):
-        return self._value
+        return None
 
 class Running_sum(Rank_aggregate):
     """ represents running_sum(Y, for_each=Z, order_by=T"""
     required_kw = ('Y', 'for_each', 'order_by')
     
-    def add(self,row):
-        self.count += row[self.to_add].id # TODO
-        if row[1:self.to_add] == row[self.slice_for_each]: #prefixed
-            self._value = list(row[:self.to_add]) + [pyEngine.Const(self.count),]
-            return self._value
-
+    def add(self, row):
+        self._value += row[self.to_add].id
+        return list(row[:len(row)-self.arity]) + [pyEngine.Const(self._value)] #TODO
+        
         
 """                             Parser methods                                                   """