Commits

Pierre Carbonnelle committed 45cea26

support unary operators (+, -)

Comments (0)

Files changed (2)

pyDatalog/examples/test.py

     pyDatalog.clear()
     @pyDatalog.program()
     def factorial(): 
-        (factorial[N] == F) <= (N < 2) & (F==1)
+        (factorial[N] == F) <= (N < 1) & (F== -factorial[-N])
+        + (factorial[1]==1)
         (factorial[N] == F) <= (N > 1) & (F == N*factorial[N-1])
         assert ask(factorial[1] == F) == set([(1, 1)])
         assert ask(factorial[4] == F) == set([(4, 24)])
+        assert ask(factorial[-4] == F) == set([(-4, -24)])
     
     # Fibonacci
     pyDatalog.clear()
     assert_error('ask(z(a),True)', 'Too many arguments for ask \!')
     assert_error('ask(z(a))', 'Predicate without definition \(or error in resolver\): z/1')
     assert_error("+ farmer(farmer(moshe))", "Syntax error: Literals cannot have a literal as argument : farmer\[\]")
-    assert_error("+ manager[Mary]==John", "bad operand type for unary \+: 'Function'. Please consider adding parenthesis")
+    assert_error("+ manager[Mary]==John", "Left-hand side of equality must be a symbol or function, not an expression.")
     assert_error("manager[X]==Y <= (X==Y)", "Syntax error: please verify parenthesis around \(in\)equalities")
     assert_error("p(X) <= (Y==2)", "Can't create clause")
     assert_error("p(X) <= X==1 & X==2", "Syntax error: please verify parenthesis around \(in\)equalities")
     assert_error("p(X) <= (manager[X]== max(X, order_by=X))", "Aggregation cannot appear in the body of a clause")
     assert_error("q(min(X, order_by=X)) <= p(X)", "Syntax error: Incorrect use of aggregation\.")
     assert_error("manager[X]== min(X, order_by=X) <= manager(X)", "Syntax error: please verify parenthesis around \(in\)equalities")
+    assert_error("(manager[X]== min(X, order_by=X+2)) <= manager(X)", "order_by argument of aggregate must be variable\(s\), not expression\(s\).")
     assert_error("ask(X<1)", 'Error: left hand side of comparison must be bound: =X<1/1')
     assert_error("ask(X<Y)", 'Error: left hand side of comparison must be bound: =X<Y/2')
     assert_error("ask(1<Y)", 'Error: left hand side of comparison must be bound: =Y>1/1')
     assert_error('(a_sum[X] == sum(Y, key=Y)) <= p(X, Z, Y)', "Error: Duplicate definition of aggregate function.")
     assert_error('(two(X)==Z) <= (Z==X+(lambda X: X))', 'Syntax error near equality: consider using brackets. two\(X\)')
     assert_error('p(X) <= sum(X, key=X)', 'Invalid body for clause')
-    assert_error('ask(- manager[X]==1)', "bad operand type for unary -: 'Function'. Please consider adding parenthesis")
+    assert_error('ask(- manager[X]==1)', "Left-hand side of equality must be a symbol or function, not an expression.")
     assert_error("p(X) <= (X=={})", "Syntax error: Symbol or Expression expected")
 
     """ SQL Alchemy                    """

pyDatalog/pyParser.py

         return Body() # by default, there is no precalculation needed to evaluate an expression
     
     def __eq__(self, other):
+        assert isinstance(self, (VarSymbol, Function)), "Left-hand side of equality must be a symbol or function, not an expression."
         if self._pyD_type == 'variable' and not isinstance(other, VarSymbol):
             return Literal.make_for_comparison(self, '==', other)
         else:
         """ called when compiling (X not in (1,2)) """
         return Literal.make_for_comparison(self, 'not in', values)
     
+    def __pos__(self):
+        """ called when compiling -X """
+        return 0 + self
+    def __neg__(self):
+        """ called when compiling -X """
+        return 0 - self
+
     def __add__(self, other):
         return Operation(self, '+', other)
     def __sub__(self, other):
         """ called when compiling -X """
         neg = Symbol(self._pyD_name)
         neg._pyD_negated = True
-        return neg
+
+        expr = 0 - self
+        expr.variable = neg
+        return expr
     
     def lua_expr(self, variables):
         if self._pyD_type == 'variable':
     def __eq__(self, other):
         return Literal.make_for_comparison(self, '==', other)
     
-    def __pos__(self):
-        raise pyDatalog.DatalogError("bad operand type for unary +: 'Function'. Please consider adding parenthesis", None, None)
-    
-    def __neg__(self):
-        raise pyDatalog.DatalogError("bad operand type for unary -: 'Function'. Please consider adding parenthesis", None, None)
-    
     # following methods are used when the function is used in an expression
     def _variables(self):
         return {self.dummy_variable_name : self.symbol}
     def __init__(self, Y=None, for_each=tuple(), order_by=tuple(), sep=None):
         # convert for_each=Z to for_each=(Z,)
         self.Y = Y
-        self.for_each = (for_each,) if isinstance(for_each, (Symbol, pyDatalog.Variable)) else tuple(for_each)
-        self.order_by = (order_by,) if isinstance(order_by, (Symbol, pyDatalog.Variable)) else tuple(order_by)
+        self.for_each = (for_each,) if isinstance(for_each, Expression) else tuple(for_each)
+        self.order_by = (order_by,) if isinstance(order_by, Expression) else tuple(order_by)
+        
+        # try to recast expressions to variables
+        self.for_each = tuple([e.__dict__.get('variable', e) for e in self.for_each]) 
+        self.order_by = tuple([e.__dict__.get('variable', e) for e in self.order_by])
+        
+        assert all([isinstance(e, VarSymbol) for e in self.for_each]), "for_each argument of aggregate must be variable(s), not expression(s)."
+        assert all([isinstance(e, VarSymbol) for e in self.order_by]), "order_by argument of aggregate must be variable(s), not expression(s)."
+        
         if sep and not isinstance(sep, six.string_types):
             raise pyDatalog.DatalogError("Separator in aggregation must be a string", None, None)
         self.sep = sep