Pierre Carbonnelle avatar Pierre Carbonnelle committed 32a2311

Refactor verification of presence of arguments in aggregate functions

Comments (0)

Files changed (1)

pyDatalog/pyParser.py

     def __call__ (self, *args, **kwargs):
         """ called when compiling p(args) """
         "time to create a literal !"
-        def check(kwargs, template):
-            for arg in template:
-                if not [kw for kw in arg if kw in kwargs]:
-                    raise pyDatalog.DatalogError("Error: argument missing in aggregate", None, None)
         if self._pyD_name == 'ask':
             if 1<len(args):
                 raise RuntimeError('Too many arguments for ask !')
             return pyEngine.toAnswer(literal.lua, literal.lua.ask(fast))
         elif self._pyD_name == '__sum__':
             if isinstance(args[0], Symbol):
-                check(kwargs, (('key', 'for_each'),))
-                return Sum_aggregate(args[0], for_each=kwargs.get('for_each', kwargs.get('key')))
+                return Sum_aggregate(args[0], for_each=kwargs.get('for_each', kwargs.get('key', [])))
             else:
                 return sum(args)
         elif self._pyD_name == 'concat':
-            check(kwargs, (('key','order_by'),('sep',)))
-            return Concat_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key')), sep=kwargs['sep'])
+            return Concat_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key', [])), sep=kwargs['sep'])
         elif self._pyD_name == '__min__':
             if isinstance(args[0], Symbol):
-                check(kwargs, (('key', 'order_by'),))
-                return Min_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key')),)
+                return Min_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key', [])),)
             else:
                 return min(args)
         elif self._pyD_name == '__max__':
             if isinstance(args[0], Symbol):
-                check(kwargs, (('key', 'order_by'),))
-                return Max_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key')),)
+                return Max_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key', [])),)
             else:
                 return max(args)
         elif self._pyD_name == 'rank':
-            return Rank_aggregate(None, for_each=kwargs['for_each'], order_by=kwargs['order_by'])
+            return Rank_aggregate(None, for_each=kwargs.get('for_each', []), order_by=kwargs.get('order_by', []))
         elif self._pyD_name == 'running_sum':
-            return Running_sum(args[0], for_each=kwargs['for_each'], order_by=kwargs['order_by'])
+            return Running_sum(args[0], for_each=kwargs.get('for_each', []), order_by=kwargs.get('order_by', []))
         elif self._pyD_name == '__len__':
             if isinstance(args[0], Symbol):
                 return Len_aggregate(args[0])
     e.g. 'sum(Y,key=Z)' in '(a[X]==sum(Y,key=Z))'
     pyEngine calls sort_result(), key(), reset(), add() and fact() to compute the aggregate
     """
-
+    
     def __init__(self, Y=None, for_each=tuple(), order_by=tuple(), sep=None):
         # convert for_each=Z to for_each=(Z,)
         self.Y = Y
             raise pyDatalog.DatalogError("Separator in aggregation must be a string", None, None)
         self.sep = sep
         
+        # verify presence of keyword arguments
+        if any([kw for kw in self.required_kw if getattr(self, kw) in (None, tuple())]):
+            raise pyDatalog.DatalogError("Error: argument missing in aggregate", None, None)
+        
         # used to create literal. TODO : filter on symbols
         self.args = ((Y,) if Y is not None else tuple()) + self.for_each + self.order_by + ((sep,) if sep is not None else tuple())
         self.Y_arity = 1 if Y is not None else 0
         return k + [pyEngine.Const(self.value)]
        
 class Sum_aggregate(Aggregate):
-    """ represents sum(X, key=(Y,Z))"""
+    """ represents sum(Y, for_each=(Z,T))"""
+    required_kw = ('Y', 'for_each')
+
     def add(self, row):
         self._value += row[-self.arity].id
         
 class Len_aggregate(Aggregate):
     """ represents len(X)"""
+    required_kw = ('Y')
+
     def add(self, row):
         self._value += 1
 
 class Concat_aggregate(Aggregate):
     """ represents concat(Y, order_by=(Z1,Z2), sep=sep)"""
+    required_kw = ('Y', 'order_by', 'sep')
         
     def reset(self):
         self._value = []
         return self.sep.join(self._value)
 
 class Min_aggregate(Aggregate):
-    """ represents min(X, order_by=(Y,Z))"""
+    """ represents min(Y, order_by=(Z,T))"""
+    required_kw = ('Y', 'order_by')
+
     def reset(self):
         self._value = None
         
         self._value = row[-self.arity].id if self._value is None else self._value
 
 class Max_aggregate(Min_aggregate):
-    """ represents max(X, order_by=(Y,Z))"""
+    """ represents max(Y, order_by=(Z,T))"""
     def __init__(self, *args, **kwargs):
         Min_aggregate.__init__(self, *args, **kwargs)
         for a in self.order_by:
 
 class Rank_aggregate(Aggregate):
     """ represents rank(for_each=(Z), order_by(T))"""
+    required_kw = ('for_each', 'order_by')
+    
     def reset(self):
         self.count = 0
         self._value = None
         return self._value
 
 class Running_sum(Rank_aggregate):
-    """ represents running_sum(N, for_each=(Z), order_by(T)"""
+    """ represents running_sum(Y, for_each=(Z), order_by(T)"""
+    required_kw = ('Y', 'for_each', 'order_by')
+    
     def add(self,row):
         self.count += row[self.to_add].id # TODO
         if row[:self.to_add] == row[self.slice_for_each]:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.