Commits

Pierre Carbonnelle committed 3c0ae43

move parser of pyDatalog program to bottom of file

Comments (0)

Files changed (1)

pyDatalog/pyParser.py

 This is done in the load() and add_program() method of Parser class.
 
 Classes hierarchy contained in this file: see class diagram on http://bit.ly/YRnMPH
-* ProgramContext : class to safely differentiate between In-line queries and pyDatalog program / ask(), using ProgramMode global variable
-* _transform_ast : performs some modifications of the abstract syntax tree of the datalog program
 * LazyList : a subclassable list that is populated when it is accessed. 
     * LazyListOfList : Mixin for Query and Body
 * Literal : made of a predicate and a list of arguments.  Instantiated when a symbol is called while executing the datalog program
     * Max_aggregate
     * Rank_aggregate
     * Running_sum
+* ProgramContext : class to safely differentiate between In-line queries and pyDatalog program / ask(), using ProgramMode global variable
+* _transform_ast : performs some modifications of the abstract syntax tree of the datalog program
 """
 
 import ast
 # global variable to differentiate between in-line queries and pyDatalog program / ask
 ProgramMode = False
 
-class ProgramContext(object):
-    """class to safely use ProgramMode within the "with" statement"""
-    def __enter__(self):
-        global ProgramMode
-        ProgramMode = True
-    def __exit__(self, exc_type, exc_value, traceback):
-        global ProgramMode
-        ProgramMode = False
- 
-"""                             Parser methods                                                   """
-
-def add_symbols(names, variables):
-    """ add the names to the variables dictionary"""
-    for name in names:
-        variables[name] = Symbol(name)            
-    
-class _transform_ast(ast.NodeTransformer):
-    """ does some transformation of the Abstract Syntax Tree of the datalog program """
-    def visit_Call(self, node):
-        """rename builtins to allow customization"""
-        self.generic_visit(node)
-        if hasattr(node.func, 'id'):
-            node.func.id = 'sum_' if node.func.id == 'sum' else node.func.id
-            node.func.id = 'len_' if node.func.id == 'len' else node.func.id
-            node.func.id = 'min_' if node.func.id == 'min' else node.func.id
-            node.func.id = 'max_' if node.func.id == 'max' else node.func.id
-        return node
-    
-    def visit_Compare(self, node):
-        """ rename 'in' to allow customization of (X in (1,2))"""
-        self.generic_visit(node)
-        if 1 < len(node.comparators): 
-            raise util.DatalogError("Syntax error: please verify parenthesis around (in)equalities", node.lineno, None) 
-        if not isinstance(node.ops[0], (ast.In, ast.NotIn)): return node
-        var = node.left # X, an _ast.Name object
-        comparators = node.comparators[0] # (1,2), an _ast.Tuple object
-        newNode = ast.Call(
-                ast.Attribute(var, 'in_' if isinstance(node.ops[0], ast.In) else 'not_in_', var.ctx), # func
-                [comparators], # args
-                [], # keywords
-                None, # starargs
-                None # kwargs
-                )
-        return ast.fix_missing_locations(newNode)
-
-def load(code, newglobals=None, defined=None, function='load'):
-    """ code : a string or list of string 
-        newglobals : global variables for executing the code
-        defined : reserved symbols
-    """
-    newglobals, defined = newglobals or {}, defined or set([])
-    # remove indentation based on first non-blank line
-    lines = code.splitlines() if isinstance(code, six.string_types) else code
-    r = re.compile('^\s*')
-    for line in lines:
-        spaces = r.match(line).group()
-        if spaces and line != spaces:
-            break
-    code = '\n'.join([line.replace(spaces,'') for line in lines])
-    
-    tree = ast.parse(code, function, 'exec')
-    try:
-        tree = _transform_ast().visit(tree)
-    except util.DatalogError as e:
-        e.function = function
-        e.message = e.value
-        e.value = "%s\n%s" % (e.value, lines[e.lineno-1])
-        six.reraise(*sys.exc_info())
-    code = compile(tree, function, 'exec')
-
-    defined = defined.union(dir(builtins))
-    defined.add('None')
-    for name in set(code.co_names).difference(defined): # for names that are not defined
-        add_symbols((name,), newglobals)
-    try:
-        with ProgramContext():
-            six.exec_(code, newglobals)
-    except util.DatalogError as e:
-        e.function = function
-        traceback = sys.exc_info()[2]
-        e.lineno = 1
-        while True:
-            if traceback.tb_frame.f_code.co_name == '<module>':
-                e.lineno = traceback.tb_lineno
-                break
-            elif traceback.tb_next:
-                traceback = traceback.tb_next 
-        e.message = e.value
-        e.value = "%s\n%s" % (e.value, lines[e.lineno-1])
-        six.reraise(*sys.exc_info())
-        
-class _NoCallFunction(object):
-    """ This class prevents a call to a datalog program created using the 'program' decorator """
-    def __call__(self):
-        raise TypeError("Datalog programs are not callable")
-
-def add_program(func):
-    """ A helper for decorator implementation   """
-    source_code = inspect.getsource(func)
-    lines = source_code.splitlines()
-    # drop the first 2 lines (@pydatalog and def _() )
-    if '@' in lines[0]: del lines[0]
-    if 'def' in lines[0]: del lines[0]
-    source_code = lines
-
-    try:
-        code = func.__code__
-    except:
-        raise TypeError("function or method argument expected")
-    newglobals = func.__globals__.copy() if PY3 else func.func_globals.copy()
-    func_name = func.__name__ if PY3 else func.func_name
-    defined = set(code.co_varnames).union(set(newglobals.keys())) # local variables and global variables
-
-    load(source_code, newglobals, defined, function=func_name)
-    return _NoCallFunction()
-
-def ask(code):
-    """ runs the query in the code string """
-    with ProgramContext():
-        tree = ast.parse(code, 'ask', 'eval')
-        tree = _transform_ast().visit(tree)
-        code = compile(tree, 'ask', 'eval')
-        newglobals = {}
-        add_symbols(code.co_names, newglobals)
-        parsed_code = eval(code, newglobals)
-        return Answer.make(parsed_code.ask())
-
-class Answer(object):
-    """ object returned by ask() """
-    def __init__(self, name, arity, answers):
-        self.name = name
-        self.arity = arity
-        self.answers = answers
-
-    @classmethod
-    def make(cls, answers):
-        if answers is True:
-            answer = Answer('_pyD_query', 0, True)
-        elif answers:
-            answer = Answer('_pyD_query', len(answers), answers)
-        else:
-            answer = None
-        if pyEngine.Auto_print: 
-            print(answers)
-        return answer        
-
-    def __eq__ (self, other):
-        return other == True if self.answers is True \
-            else other == set(self.answers) if self.answers \
-            else other is None
-    def __str__(self):
-        return 'True' if self.answers is True \
-            else str(set(self.answers)) if self.answers is not True \
-            else 'True'
-
-
 """                             Parser classes                                                   """
 
 class LazyList(UserList.UserList):
         if row[:self.to_add] == row[self.slice_for_each]:
             self._value = list(row[:self.to_add]) + [pyEngine.Const(self.count),]
             return self._value
+
         
+"""                             Parser methods                                                   """
+
+class ProgramContext(object):
+    """class to safely use ProgramMode within the "with" statement"""
+    def __enter__(self):
+        global ProgramMode
+        ProgramMode = True
+    def __exit__(self, exc_type, exc_value, traceback):
+        global ProgramMode
+        ProgramMode = False
+ 
+def add_symbols(names, variables):
+    """ add the names to the variables dictionary"""
+    for name in names:
+        variables[name] = Symbol(name)            
+    
+class _transform_ast(ast.NodeTransformer):
+    """ does some transformation of the Abstract Syntax Tree of the datalog program """
+    def visit_Call(self, node):
+        """rename builtins to allow customization"""
+        self.generic_visit(node)
+        if hasattr(node.func, 'id'):
+            node.func.id = 'sum_' if node.func.id == 'sum' else node.func.id
+            node.func.id = 'len_' if node.func.id == 'len' else node.func.id
+            node.func.id = 'min_' if node.func.id == 'min' else node.func.id
+            node.func.id = 'max_' if node.func.id == 'max' else node.func.id
+        return node
+    
+    def visit_Compare(self, node):
+        """ rename 'in' to allow customization of (X in (1,2))"""
+        self.generic_visit(node)
+        if 1 < len(node.comparators): 
+            raise util.DatalogError("Syntax error: please verify parenthesis around (in)equalities", node.lineno, None) 
+        if not isinstance(node.ops[0], (ast.In, ast.NotIn)): return node
+        var = node.left # X, an _ast.Name object
+        comparators = node.comparators[0] # (1,2), an _ast.Tuple object
+        newNode = ast.Call(
+                ast.Attribute(var, 'in_' if isinstance(node.ops[0], ast.In) else 'not_in_', var.ctx), # func
+                [comparators], # args
+                [], # keywords
+                None, # starargs
+                None # kwargs
+                )
+        return ast.fix_missing_locations(newNode)
+
+def load(code, newglobals=None, defined=None, function='load'):
+    """ code : a string or list of string 
+        newglobals : global variables for executing the code
+        defined : reserved symbols
+    """
+    newglobals, defined = newglobals or {}, defined or set([])
+    # remove indentation based on first non-blank line
+    lines = code.splitlines() if isinstance(code, six.string_types) else code
+    r = re.compile('^\s*')
+    for line in lines:
+        spaces = r.match(line).group()
+        if spaces and line != spaces:
+            break
+    code = '\n'.join([line.replace(spaces,'') for line in lines])
+    
+    tree = ast.parse(code, function, 'exec')
+    try:
+        tree = _transform_ast().visit(tree)
+    except util.DatalogError as e:
+        e.function = function
+        e.message = e.value
+        e.value = "%s\n%s" % (e.value, lines[e.lineno-1])
+        six.reraise(*sys.exc_info())
+    code = compile(tree, function, 'exec')
+
+    defined = defined.union(dir(builtins))
+    defined.add('None')
+    for name in set(code.co_names).difference(defined): # for names that are not defined
+        add_symbols((name,), newglobals)
+    try:
+        with ProgramContext():
+            six.exec_(code, newglobals)
+    except util.DatalogError as e:
+        e.function = function
+        traceback = sys.exc_info()[2]
+        e.lineno = 1
+        while True:
+            if traceback.tb_frame.f_code.co_name == '<module>':
+                e.lineno = traceback.tb_lineno
+                break
+            elif traceback.tb_next:
+                traceback = traceback.tb_next 
+        e.message = e.value
+        e.value = "%s\n%s" % (e.value, lines[e.lineno-1])
+        six.reraise(*sys.exc_info())
+        
+class _NoCallFunction(object):
+    """ This class prevents a call to a datalog program created using the 'program' decorator """
+    def __call__(self):
+        raise TypeError("Datalog programs are not callable")
+
+def add_program(func):
+    """ A helper for decorator implementation   """
+    source_code = inspect.getsource(func)
+    lines = source_code.splitlines()
+    # drop the first 2 lines (@pydatalog and def _() )
+    if '@' in lines[0]: del lines[0]
+    if 'def' in lines[0]: del lines[0]
+    source_code = lines
+
+    try:
+        code = func.__code__
+    except:
+        raise TypeError("function or method argument expected")
+    newglobals = func.__globals__.copy() if PY3 else func.func_globals.copy()
+    func_name = func.__name__ if PY3 else func.func_name
+    defined = set(code.co_varnames).union(set(newglobals.keys())) # local variables and global variables
+
+    load(source_code, newglobals, defined, function=func_name)
+    return _NoCallFunction()
+
+def ask(code):
+    """ runs the query in the code string """
+    with ProgramContext():
+        tree = ast.parse(code, 'ask', 'eval')
+        tree = _transform_ast().visit(tree)
+        code = compile(tree, 'ask', 'eval')
+        newglobals = {}
+        add_symbols(code.co_names, newglobals)
+        parsed_code = eval(code, newglobals)
+        return Answer.make(parsed_code.ask())
+
+class Answer(object):
+    """ object returned by ask() """
+    def __init__(self, name, arity, answers):
+        self.name = name
+        self.arity = arity
+        self.answers = answers
+
+    @classmethod
+    def make(cls, answers):
+        if answers is True:
+            answer = Answer('_pyD_query', 0, True)
+        elif answers:
+            answer = Answer('_pyD_query', len(answers), answers)
+        else:
+            answer = None
+        if pyEngine.Auto_print: 
+            print(answers)
+        return answer        
+
+    def __eq__ (self, other):
+        return other == True if self.answers is True \
+            else other == set(self.answers) if self.answers \
+            else other is None
+    def __str__(self):
+        return 'True' if self.answers is True \
+            else str(set(self.answers)) if self.answers is not True \
+            else 'True'
+
+