Commits

Pierre Carbonnelle  committed 0155cb8

support unprefixed aggregate inline queries

  • Participants
  • Parent commits 5975d2b

Comments (0)

Files changed (3)

File pyDatalog/examples/datalog.py

     2. define business rules
     3. Query the datalog engine
 """
-from pyDatalog.pyDatalog import load, ask
+from pyDatalog import pyDatalog
 
 """ 1. create facts for 3 employees in the datalog engine """
+pyDatalog.create_atoms('salary', 'manager')
+
 # John is the manager of Mary, who is the manager of Sam
-load("+ (salary['John'] == 6800)")
++ (salary['John'] == 6800)
 
-load("+ (manager['Mary'] == 'John')")
-load("+ (salary['Mary'] == 6300)")
++ (manager['Mary'] == 'John')
++ (salary['Mary'] == 6300)
 
-load("+ (manager['Sam'] == 'Mary')")
-load("+ (salary['Sam'] == 5900)")
++ (manager['Sam'] == 'Mary')
++ (salary['Sam'] == 5900)
 
 """ 2. define business rules """
+pyDatalog.create_atoms('salary_class', 'indirect_manager', 'report_count', 'budget', 'lowest',
+                       'X', 'Y', 'Z', 'N')
 # the salary class of employee X is computed as a function of his/her salary
-load("salary_class[X] = salary[X]//1000")
+salary_class[X] = salary[X]//1000
     
 # all the indirect managers of employee X are derived from his manager, recursively
-load("indirect_manager(X,Y) <= (manager[X] == Y) & (Y != None)")
-load("indirect_manager(X,Y) <= (manager[X] == Z) & indirect_manager(Z,Y) & (Y != None)")
+indirect_manager(X,Y) <= (manager[X] == Y) & (Y != None)
+indirect_manager(X,Y) <= (manager[X] == Z) & indirect_manager(Z,Y) & (Y != None)
 
 # count the number of reports of X
-load("(report_count[X] == len(Y)) <= indirect_manager(Y,X)")
+(report_count[X] == _len(Y)) <= indirect_manager(Y,X)
 
 """ 3. Query the datalog engine """
 
 # what is the salary class of John ?
-print(ask("salary_class['John'] == Y")) # prints set([('John', 6)])
+print(salary_class['John'] == Y) # prints [6]
 
 # who has a salary of 6300 ?
-print(ask("salary[X] == 6300")) # prints set([('Mary', 6300)])
+print(salary[X] == 6300) # prints Mary
 
 # who are the indirect managers of Mary ?
-print(ask("indirect_manager('Mary', X)")) # prints set([('Mary', 'John')])
+print(indirect_manager('Mary', X)) # prints [('John',)]
 
 # Who are the employees of John with a salary below 6000 ?
-print(ask("(salary[X] < 6000) & indirect_manager(X, 'John')")) # prints set([('Sam', )])
+print((salary[X] < 6000) & indirect_manager(X, 'John')) # prints [('Sam',)]
 
 # who is his own indirect manager ?
-print(ask("indirect_manager('X', X)")) # prints None
+print(indirect_manager('X', X)) # prints []
 
 # who has 2 reports ?
-print(ask("report_count[X] == 2")) # prints set([('John', 2)])
+print(report_count[X] == 2) # prints [('John',)]
 
 # what is the total salary of the employees of John ? 
-load("(Budget[X] == sum(N, for_each=Y)) <= (indirect_manager(Y, X)) & (salary[Y]==N)")
-print(ask("Budget['John']==N")) # prints set([('John', 12200)])
+(budget[X] == _sum(N, for_each=Y)) <= (indirect_manager(Y, X)) & (salary[Y]==N)
+print(budget['John']==N) # prints [(12200,)]
 
 # who has the lowest salary ?
-load("(Lowest[1] == min(X, order_by=N)) <= (salary[X]==N)")
-print(ask("Lowest[1]==N")) # prints set([(1, 'Sam')])
+(lowest[1] == _min(X, order_by=N)) <= (salary[X]==N)
+print(lowest[1]==N) # prints [('Sam',)]
 
 # start the datalog console, for interactive querying 
-from pyDatalog import pyDatalog
 from pyDatalog.examples import console
 console = console.datalogConsole(locals=locals())
 console.interact('Type exit() when done.')

File pyDatalog/pyDatalog.py

     stack = inspect.stack()
     try:
         locals_ = stack[1][0].f_locals
-        for arg in set(args + ('__sum__','__min__', '__max__')):
+        for arg in set(args + ('_sum','_min', '_max', '_len')):
             if arg in locals_ and not isinstance(locals_[arg], (pyParser.Symbol, pyDatalog.Variable)):
                 raise BaseException("Name conflict.  Can't redefine %s as atom" % arg)
             if arg[0] not in string.ascii_uppercase:

File pyDatalog/pyParser.py

         """rename builtins to allow customization"""
         self.generic_visit(node)
         if hasattr(node.func, 'id'):
-            node.func.id = '__sum__' if node.func.id == 'sum' else node.func.id
-            node.func.id = '__len__' if node.func.id == 'len' else node.func.id
-            node.func.id = '__min__' if node.func.id == 'min' else node.func.id
-            node.func.id = '__max__' if node.func.id == 'max' else node.func.id
+            node.func.id = '_sum' if node.func.id == 'sum' else node.func.id
+            node.func.id = '_len' if node.func.id == 'len' else node.func.id
+            node.func.id = '_min' if node.func.id == 'min' else node.func.id
+            node.func.id = '_max' if node.func.id == 'max' else node.func.id
         return node
     
     def visit_Compare(self, node):
             fast = kwargs['_fast'] if '_fast' in list(kwargs.keys()) else False
             literal = args[0] if not isinstance(args[0], Body) else args[0].literal()
             return pyEngine.toAnswer(literal.lua, literal.lua.ask(fast))
-        elif self._pyD_name == '__sum__':
+        elif self._pyD_name == '_sum':
             if isinstance(args[0], (Symbol, pyDatalog.Variable)):
                 return Sum_aggregate(args[0], for_each=kwargs.get('for_each', kwargs.get('key', [])))
             else:
                 return sum(args)
         elif self._pyD_name == 'concat':
             return Concat_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key', [])), sep=kwargs['sep'])
-        elif self._pyD_name == '__min__':
-            if isinstance(args[0], Symbol):
+        elif self._pyD_name == '_min':
+            if isinstance(args[0], (Symbol, pyDatalog.Variable)):
                 return Min_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key', [])),)
             else:
                 return min(args)
-        elif self._pyD_name == '__max__':
-            if isinstance(args[0], Symbol):
+        elif self._pyD_name == '_max':
+            if isinstance(args[0], (Symbol, pyDatalog.Variable)):
                 return Max_aggregate(args[0], order_by=kwargs.get('order_by',kwargs.get('key', [])),)
             else:
                 return max(args)
             return Rank_aggregate(None, for_each=kwargs.get('for_each', []), order_by=kwargs.get('order_by', []))
         elif self._pyD_name == 'running_sum':
             return Running_sum(args[0], for_each=kwargs.get('for_each', []), order_by=kwargs.get('order_by', []))
-        elif self._pyD_name == '__len__':
-            if isinstance(args[0], Symbol):
+        elif self._pyD_name == '_len':
+            if isinstance(args[0], (Symbol, pyDatalog.Variable)):
                 return Len_aggregate(args[0])
             else: 
                 return len(args[0]) 
             self.has_symbols = self.has_symbols or isinstance(t, Symbol)
             self.is_fact = self.is_fact and not(isinstance(t, pyDatalog.Variable) and not(isinstance(t, Symbol) and t._pyD_type == 'variable'))
         
-        self.args = terms
+        self.args = terms # TODO simplify
         self.todo = self
         cls_name = predicate_name.split('.')[0].replace('~','') if 1< len(predicate_name.split('.')) else ''
         terms, env = [], {}
                 raise pyDatalog.DatalogError("Syntax error: Literals cannot have a literal as argument : %s%s" % (predicate_name, terms), None, None)
             elif i==0 and cls_name and arg.__class__.__name__ != cls_name: # TODO use __mro__ !
                 raise TypeError("Object is incompatible with the class that is queried.")
+            elif isinstance(arg, Aggregate):
+                raise pyDatalog.DatalogError("Syntax error: Incorrect use of aggregation.", None, None)
             else:
                 terms.append(arg)
         self.terms = terms
         for a in terms:
             if isinstance(a, Symbol):
                 tbl.append(a._pyD_lua)
-            elif isinstance(a, Aggregate):
-                raise pyDatalog.DatalogError("Syntax error: Incorrect use of aggregation.", None, None)
             else:
                 tbl.append(pyEngine.Const(a))
         # now create the literal for the head of a clause
 
     def __le__(self, body):
         " head <= body"
-        global ProgramMode
-        #TODO assert ProgramMode # '<=' cannot be used with literal containing pyDatalog.Variable instances
         if isinstance(body, Literal):
             newBody = body.pre_calculations & body
             if isinstance(body, Literal) and body.predicate_name[-1]=='!':
     def __init__(self, Y=None, for_each=tuple(), order_by=tuple(), sep=None):
         # convert for_each=Z to for_each=(Z,)
         self.Y = Y
-        self.for_each = (for_each,) if isinstance(for_each, Symbol) else tuple(for_each)
-        self.order_by = (order_by,) if isinstance(order_by, Symbol) else tuple(order_by)
+        self.for_each = (for_each,) if isinstance(for_each, (Symbol, pyDatalog.Variable)) else tuple(for_each)
+        self.order_by = (order_by,) if isinstance(order_by, (Symbol, pyDatalog.Variable)) else tuple(order_by)
         if sep and not isinstance(sep, six.string_types):
             raise pyDatalog.DatalogError("Separator in aggregation must be a string", None, None)
         self.sep = sep