Commits

David Schneider committed 15cab37

Refactored cmp_standard order to instance methods and
added specialized class for cons and support code to use it

Comments (0)

Files changed (4)

prolog/interpreter/helper.py

 def unwrap_list(prolog_list):
     result = []
     curr = prolog_list
-    while isinstance(curr, term.Term):
-        if not curr.name()== ".":
-            error.throw_type_error("list", prolog_list)
+    while isinstance(curr, term.Callable) and curr.signature() == './2':
         result.append(curr.argument_at(0))
         curr = curr.argument_at(1)
-    if isinstance(curr, term.Atom) and curr.name()== "[]":
+    if isinstance(curr, term.Callable) and curr.name()== "[]":
         return result
     error.throw_type_error("list", prolog_list)
 

prolog/interpreter/term.py

     def eval_arithmetic(self, engine):
         error.throw_type_error("evaluable", self)
 
+    def cmp_standard_order(self, other, heap):
+        raise NotImplementedError("abstract base class")
+        
 class Var(PrologObject):
     TYPE_STANDARD_ORDER = 0
 
 
         return self.eval_arithmetic(engine)
 
+    def cmp_standard_order(self, other, heap):
+        assert isinstance(other, Var)
+        return rcmp(compute_unique_id(self), compute_unique_id(other))
+        
+
 
 class NumberedVar(PrologObject):
     _immutable_ = True
 
     def enumerate_vars(self, memo):
         return self
-
+    
 class Callable(NonVar):
     _immutable_ = True
     __slots__ = ()
         raise NotImplementedError("abstract base")
         
     def get_prolog_signature(self):
-        raise NotImplementedError("abstract base")
-        
+        return Term("/", [Callable.build(self.name()),
+                                                Number(self.argument_count())])        
     def arguments(self):
         raise NotImplementedError("abstract base")
         
                 self.argument_at(i).unify(other.argument_at(i), heap, occurs_check)
         else:
             raise UnificationFailed
+    def copy_and_basic_unify(self, other, heap, env):
+        if (isinstance(other, Callable) and 
+            self.signature() == other.signature()):
+            return self._copy_term(_term_unify_and_standardize_apart,
+                                   other, heap, env)
+        else:
+            raise UnificationFailed
+
+    def copy(self, heap, memo):
+        return self._copy_term(_term_copy, heap, memo)
+
+    def copy_standardize_apart(self, heap, env):
+        return self._copy_term(_term_copy_standardize_apart, heap, env)
+
+    def enumerate_vars(self, memo):
+        return self._copy_term(_term_enumerate_vars, memo)
+
+    def getvalue(self, heap):
+        return self._copy_term(_term_getvalue, heap)
+
+    @specialize.arg(1)
+    @jit.unroll_safe
+    def _copy_term(self, copy_individual, *extraargs):
+        args = [None] * self.argument_count()
+        newinstance = False
+        i = 0
+        while i < self.argument_count():
+            arg = self.argument_at(i)
+            cloned = copy_individual(arg, i, *extraargs)
+            newinstance = newinstance or cloned is not arg
+            args[i] = cloned
+            i += 1
+        if newinstance:
+            # XXX construct the right class directly
+            return Callable.build(self.name(), args, self.signature())
+        else:
+            return self
+    
+    def contains_var(self, var, heap):
+        for arg in self.arguments():
+            if arg.contains_var(var, heap):
+                return True
+        return False
+        
+    def cmp_standard_order(self, other, heap):
+        assert isinstance(other, Callable)
+        c = rcmp(self.argument_count(), other.argument_count())
+        if c != 0:
+            return c
+        c = rcmp(self.name(), other.name())
+        if c != 0:
+            return c
+        for i in range(self.argument_count()):
+            a1 = self.argument_at(i).dereference(heap)
+            a2 = other.argument_at(i).dereference(heap)
+            c = cmp_standard_order(a1, a2, heap)
+            if c != 0:
+                return c
+        return 0
     
     @staticmethod
     def build(term_name, args=None, signature=None):
         if len(args) == 0:
             return Atom.newatom(term_name)
         else:
+            cls = specialized_term_classes.get((term_name, len(args)), None)
+            if cls is not None:
+                return cls(args)
             return Term(term_name, args, signature)
         
 class Atom(Callable):
     def __repr__(self):
         return "Atom(%r)" % (self.name(),)
 
-
-    def copy_and_basic_unify(self, other, heap, env):
-        if isinstance(other, Atom) and (self is other or
-                                        other.name() == self.name()):
-            return self
-        else:
-            raise UnificationFailed
-
-    def get_prolog_signature(self):
-        return Term("/", [self, NUMBER_0])
-
     @staticmethod
     def newatom(name):
         result = Atom.cache.get(name, None)
     def eval_arithmetic(self, engine):
         return self
 
-NUMBER_0 = Number(0)
-
+    def cmp_standard_order(self, other, heap):
+        # XXX looks a bit terrible
+        if isinstance(other, Number):
+            return rcmp(self.num, other.num)
+        elif isinstance(other, Float):
+            return rcmp(self.num, other.floatval)
+        assert 0
+        
 class Float(NonVar):
     TYPE_STANDARD_ORDER = 2
     _immutable_ = True
     def eval_arithmetic(self, engine):
         from prolog.interpreter.arithmetic import norm_float
         return norm_float(self)
-
+        
+    def cmp_standard_order(self, other, heap):
+        # XXX looks a bit terrible
+        if isinstance(other, Number):
+            return rcmp(self.floatval, other.num)
+        elif isinstance(other, Float):
+            return rcmp(self.floatval, other.floatval)
+        assert 0
+        
 Float.e = Float(math.e)
 Float.pi = Float(math.pi)
 
 
     def __str__(self):
         return "%s(%s)" % (self.name(), ", ".join([str(a) for a in self.arguments()]))
-
-    def copy_and_basic_unify(self, other, heap, env):
-        if (isinstance(other, Term) and
-                self.signature() == other.signature()):
-            return self._copy_term(_term_unify_and_standardize_apart,
-                                   other, heap, env)
-        else:
-            raise UnificationFailed
-
-    def copy(self, heap, memo):
-        return self._copy_term(_term_copy, heap, memo)
-
-    def copy_standardize_apart(self, heap, env):
-        return self._copy_term(_term_copy_standardize_apart, heap, env)
-
-    def enumerate_vars(self, memo):
-        return self._copy_term(_term_enumerate_vars, memo)
-
-    def getvalue(self, heap):
-        return self._copy_term(_term_getvalue, heap)
-
-    @specialize.arg(1)
-    @jit.unroll_safe
-    def _copy_term(self, copy_individual, *extraargs):
-        args = [None] * self.argument_count()
-        newinstance = False
-        i = 0
-        while i < self.argument_count():
-            arg = self.argument_at(i)
-            cloned = copy_individual(arg, i, *extraargs)
-            newinstance = newinstance or cloned is not arg
-            args[i] = cloned
-            i += 1
-        if newinstance:
-            return Term(self.name(), args, self.signature())
-        else:
-            return self
-
-    def get_prolog_signature(self):
-        return Term("/", [Callable.build(self.name()),
-                                                Number(self.argument_count())])
-    
-    def contains_var(self, var, heap):
-        for arg in self.arguments():
-            if arg.contains_var(var, heap):
-                return True
-        return False
-        
+            
     def eval_arithmetic(self, engine):
         from prolog.interpreter.arithmetic import get_arithmetic_function
 
     c = rcmp(obj1.TYPE_STANDARD_ORDER, obj2.TYPE_STANDARD_ORDER)
     if c != 0:
         return c
-    if isinstance(obj1, Var):
-        assert isinstance(obj2, Var)
-        return rcmp(compute_unique_id(obj1), compute_unique_id(obj2))
-    if isinstance(obj1, Atom):
-        assert isinstance(obj2, Atom)
-        return rcmp(obj1.name(), obj2.name())
-    if isinstance(obj1, Term):
-        assert isinstance(obj2, Term)
-        c = rcmp(obj1.argument_count(), obj2.argument_count())
-        if c != 0:
-            return c
-        c = rcmp(obj1.name(), obj2.name())
-        if c != 0:
-            return c
-        for i in range(obj1.argument_count()):
-            a1 = obj1.argument_at(i).dereference(heap)
-            a2 = obj2.argument_at(i).dereference(heap)
-            c = cmp_standard_order(a1, a2, heap)
-            if c != 0:
-                return c
-        return 0
-    # XXX hum
-    if isinstance(obj1, Number):
-        if isinstance(obj2, Number):
-            return rcmp(obj1.num, obj2.num)
-        elif isinstance(obj2, Float):
-            return rcmp(obj1.num, obj2.floatval)
-    if isinstance(obj1, Float):
-        if isinstance(obj2, Number):
-            return rcmp(obj1.floatval, obj2.num)
-        elif isinstance(obj2, Float):
-            return rcmp(obj1.floatval, obj2.floatval)
-    assert 0
+    return obj1.cmp_standard_order(obj2, heap)
+
+def generate_class(cname, fname, n_args):
+    from pypy.rlib.unroll import unrolling_iterable
+    arg_iter = unrolling_iterable(range(n_args))
+    signature = fname + '/' + str(n_args)
+    class cls(Callable):
+        if n_args == 0:
+            TYPE_STANDARD_ORDER = Atom.TYPE_STANDARD_ORDER
+        else:
+            TYPE_STANDARD_ORDER = Term.TYPE_STANDARD_ORDER
+            
+        def __init__(self, args):
+            assert len(args) == n_args
+            for x in arg_iter:
+                setattr(self, 'val_%d' % x, args[x])
+                
+        def name(self):
+            return fname
+        
+        def signature(self):
+            return signature
+        
+        def arguments(self):
+            result = [None] * n_args
+            for x in arg_iter:
+                result[x] = getattr(self, 'val_%d' % x)
+            return result
+
+        def argument_at(self, i):
+            for x in arg_iter:
+                if x == i:
+                    return getattr(self, 'val_%d' % x)
+
+        def argument_count(self):
+            return n_args
+            
+    cls.__name__ = cname
+    return cls
+    
+specialized_term_classes = {}
+classes = [('Cons', '.', 2)]
+for cname, fname, numargs in classes:
+    specialized_term_classes[fname, numargs] = generate_class(cname, fname, numargs)

prolog/interpreter/test/test_callable_interface.py

 from prolog.interpreter.parsing import parse_file, TermBuilder
-from prolog.interpreter.term import Atom, Number, Term, Callable
+from prolog.interpreter.term import Atom, Number, Term, Callable, specialized_term_classes
 import py
 
 def parse(inp):
 def test_callable_factory_for_term():
     r = Callable.build('foo', [1, 2])
     assert isinstance(r, Term)
-    assert r.signature() == 'foo/2'
+    assert r.signature() == 'foo/2'
+    
+def test_callable_factory_for_cons():
+    r = Callable.build('.', [1, Callable.build('[]')])
+    assert isinstance(r, specialized_term_classes['.', 2])
+    assert r.signature() == './2'
+    assert r.name() == '.'
+    assert r.argument_count() == 2
+    assert r.arguments() == [1, Callable.build('[]')]
+    assert r.argument_at(0) == 1
+    assert r.argument_at(1) == Callable.build('[]')

prolog/interpreter/test/test_function.py

         return self.name() + '/123'
     def name(self):
         return 'C'
+    def argument_count(self):
+        return 0
     __repr__ = __str__
 def test_copy():