Commits

Carl Friedrich Bolz committed 6d60b4e

add function tests and make them work, carefully introduce the distinction between mutable and immutable shaped callables

Comments (0)

Files changed (5)

prolog/builtin/atomconstruction.py

 from prolog.interpreter import helper, term, error
 from prolog.interpreter import continuation
 from prolog.builtin.register import expose_builtin
-from prolog.interpreter.term import specialized_term_classes
 from prolog.interpreter.term import Callable
 import re
 

prolog/builtin/numberchars.py

 from prolog.interpreter import term, error
 from prolog.builtin.register import expose_builtin
 from prolog.interpreter.term import Callable
-from prolog.interpreter.term import specialized_term_classes
 from pypy.rlib.rarithmetic import ovfcheck
 from pypy.rlib.rbigint import rbigint
 from prolog.interpreter.signature import Signature

prolog/interpreter/shape.py

 
     def resolve(self, storage, index):
         storage = storage[index:index + self.num_storage_vars()]
+        # XXX
         return ShapedCallable(self, storage)
 
     def resolve_at(self, i, storage):
 
 # _____________________________________________________________________
 
-class ShapedCallable(term.MutableCallable):
+class ShapedCallableBase(term.Callable):
+    def get_shape(self):
+        raise NotImplementedError("abstract base class")
+
+    def get_storage(self, i):
+        raise NotImplementedError("abstract base class")
+
+    def size_storage(self):
+        raise NotImplementedError("abstract base class")
+
+    def new(self, shape, storage):
+        raise NotImplementedError("abstract base class")
+
+class ShapedCallableMixin:
     TYPE_STANDARD_ORDER = term.Term.TYPE_STANDARD_ORDER
+    _mixin_ = True
 
     def __init__(self, shape, storage):
         assert isinstance(shape, SharingShape)
         self.storage = storage
         assert shape.num_storage_vars() == len(storage)
 
+    def get_shape(self):
+        return self.shape
+
+    def get_storage(self, i):
+        return self.storage[i]
+
+    def size_storage(self):
+        return len(self.storage)
+
     # _____________________________________________________________________
     # callable interface
 
 
     @objectmodel.specialize.arg(3)
     def basic_unify(self, other, heap, occurs_check=False):
-        if (isinstance(other, ShapedCallable) and
-                self.shape is other.shape):
-            for i in range(len(self.storage)):
-                self.storage[i].unify(other.storage[i], heap, occurs_check)
+        if (isinstance(other, ShapedCallableBase) and
+                self.shape is other.get_shape()):
+            for i in range(self.size_storage()):
+                self.get_storage(i).unify(other.get_storage(i), heap, occurs_check)
             return
         return term.Callable.basic_unify(self, other, heap, occurs_check)
 
     @jit.unroll_safe
     def copy_and_basic_unify(self, other, heap, env):
-        if (isinstance(other, ShapedCallable) and
-                self.shape is other.shape):
-            for i in range(len(self.storage)):
-                self.storage[i].unify_and_standardize_apart(other.storage[i], heap, env)
+        if (isinstance(other, ShapedCallableBase) and
+                self.shape is other.get_shape()):
+            for i in range(self.size_storage()):
+                self.get_storage(i).unify_and_standardize_apart(
+                        other.get_storage(i), heap, env)
             return
         return term.Callable.copy_and_basic_unify(self, other, heap, env)
 
 
     def copy_standardize_apart(self, heap, env):
         storage = [None] * len(self.storage)
-        result = ShapedCallable(self.shape, storage)
+        result = ShapedCallableMutable(self.shape, storage)
         newinstance = False
         needmutable = False
         i = 0
             arg = self.storage[i]
             cloned = arg.copy_standardize_apart_as_child_of(heap, env, result, i)
             newinstance = newinstance | (isinstance(arg, term.NumberedVar) or cloned is not arg)
+            needmutable = needmutable | isinstance(cloned, term.VarInTerm)
             storage[i] = cloned
         if newinstance:
+            if not needmutable:
+                return result._make_immutable()
             return result
         else:
             return self
             i += 1
         if newinstance:
             # XXX what about the variable shunting in Callable.build?
-            return ShapedCallable(self.shape, args)
+            return self.new(self.shape, args)
         else:
             return self
 
     # shape-specific interface
 
     def _replace_child(self, i, obj, new_shape):
-        assert isinstance(obj, ShapedCallable)
-        self.storage = self.storage[:i] + obj.storage + self.storage[i + 1:]
-        assert len(self.storage) == new_shape.num_storage_vars()
+        assert isinstance(obj, ShapedCallableBase)
+        assert obj.size_storage() + self.size_storage() - 1 == new_shape.num_storage_vars()
+        objstorage = [obj.get_storage(j) for j in range(obj.size_storage())]
+        self.storage = self.storage[:i] + objstorage + self.storage[i + 1:]
         self.shape = new_shape
 
     def replace_child(self, i, obj):
-        if isinstance(obj, ShapedCallable):
-            new_shape = self.shape.get_transition(i, obj.shape)
+        if isinstance(obj, ShapedCallableBase):
+            new_shape = self.shape.get_transition(i, obj.get_shape())
             if new_shape is not None:
+                old_length = len(self.storage)
                 self._replace_child(i, obj, new_shape)
-                return True
-        return False
+                # XXX whew, subtle logic here
+                for i in range(obj.size_storage()):
+                    old_child = obj.get_storage(i)
+                    if isinstance(old_child, term.VarInTerm):
+                        newi = i + old_length - 1
+                        heap = old_child.created_after_choice_point
+                        new_child = heap.newvar_in_term(self, newi)
+                        self.storage[newi] = new_child
+                        old_child.parent_or_binding = new_child
+                        old_child.bound = True
+                        obj.storage[i] = new_child
+                        self = self._make_mutable()
+                return self
+        return None
 
     @staticmethod
-    def build(shape, children):
-        result = ShapedCallable(shape, children)
+    def build(shape, storage):
+        if isinstance(shape, WrapShape):
+            assert not storage
+            return shape.w_obj
+        result = ShapedCallable(shape, storage)
         i = 0
         while i < len(result.storage):
             child = result.storage[i]
-            if not result.replace_child(i, child):
+            newresult = result.replace_child(i, child)
+            if not newresult:
                 i += 1
+            else:
+                result = newresult
         assert result.shape.num_storage_vars() == len(result.storage)
         return result
 
+class ShapedCallableMutable(ShapedCallableMixin, ShapedCallableBase):
+    def _make_immutable(self):
+        return ShapedCallable(self.shape, self.storage)
+
+    def _make_mutable(self):
+        return self
+
+    def new(self, shape, storage):
+        return ShapedCallableMutable(shape, storage)
+
+
+class ShapedCallable(ShapedCallableMixin, ShapedCallableBase):
+    _immutable_fields_ = ["shape", "storage[*]"]
+
+    def _make_mutable(self):
+        return ShapedCallableMutable(self.shape, self.storage)
+
+    def new(self, shape, storage):
+        return ShapedCallable(shape, storage)
 
 
 # _____________________________________________________________________
                 if obj is None:
                     obj = env[index] = heap.newvar()
             storage[i] = obj
-        return self.shape.resolve(storage, 0)
+        return ShapedCallable.build(self.shape, storage)
 
 # _____________________________________________________________________
 

prolog/interpreter/term.py

         raise NotImplementedError("abstract base class")
 
     def init(self, parent):
-        from prolog.interpreter.shape import ShapedCallable
-        assert isinstance(parent, ShapedCallable)
+        from prolog.interpreter.shape import ShapedCallableMutable
+        assert isinstance(parent, ShapedCallableMutable)
         self.parent_or_binding = parent
         self.bound = False
 
             var = self.created_after_choice_point.newvar()
             var.setvalue(value, heap)
             value = var
-        self._setvalue_in_parent(value)
+        self._setvalue_in_parent(value, heap)
         self.bound = True
         self.parent_or_binding = value
 
-    def _setvalue_in_parent(self, value):
+    def _setvalue_in_parent(self, value, heap):
         raise NotImplementedError("abstract base class")
 
     def __repr__(self):
         def __init__(self, parent):
             self.init(parent)
 
-        def _setvalue_in_parent(self, value):
-            from prolog.interpreter.shape import ShapedCallable
+        def _setvalue_in_parent(self, value, heap):
+            from prolog.interpreter.shape import ShapedCallableMutable
             obj = self.parent_or_binding
-            assert isinstance(obj, ShapedCallable)
+            assert isinstance(obj, ShapedCallableMutable)
             if not obj.replace_child(index, value):
                 obj.storage[index] = value
     VarInTermN.__name__ = "VarInTerm%s" % index
                 assert signature.numargs == len(args)
             assert isinstance(signature, Signature)
 
-            cls = Callable._find_specialized_class(term_name, len(args))
-            if cls is not None:
-                return cls(term_name, args, signature)
-            cls = Callable._find_specialized_class('Term', len(args))
-            if cls is not None:
-                return cls(term_name, args, signature)
             return Term(term_name, args, signature)
 
-    @staticmethod
-    @jit.elidable
-    def _find_specialized_class(term_name, numargs):
-        return specialized_term_classes.get((term_name, numargs), None)
-
     def __repr__(self):
         return "%s(%s, %r)" % (self.__class__.__name__, self.name(),
                                self.arguments())
                 return False
         return True
 
-class MutableCallable(Callable):
-    def set_argument_at(self, i, arg):
-        raise NotImplementedError
-
-
 class Atom(Callable):
     TYPE_STANDARD_ORDER = 1
     __slots__ = ('_name', '_signature')
     if c != 0:
         return c
     return obj1.cmp_standard_order(obj2, heap)
-
-def generate_class(cname, fname, n_args, immutable=True):
-    from pypy.rlib.unroll import unrolling_iterable
-    arg_iter = unrolling_iterable(range(n_args))
-    parent = callables['Abstract', n_args]
-    if not immutable:
-        parent = parent.mutable_version
-    assert parent is not None
-    signature = Signature.getsignature(fname, n_args)
-
-    class specific_class(parent):
-        if n_args == 0:
-            TYPE_STANDARD_ORDER = Atom.TYPE_STANDARD_ORDER
-        else:
-            TYPE_STANDARD_ORDER = Term.TYPE_STANDARD_ORDER
-        
-        def __init__(self, term_name, args, signature):
-            parent._init_values(self, args)
-            assert self.name() == term_name
-            assert args is None or len(args) == n_args
-                
-        def name(self):
-            return fname
-        
-        def signature(self):
-            return signature
-
-        def _make_new(self, name, signature):
-            cls = specific_class
-            return cls(name, None, signature)
-
-        if immutable:
-            def _make_new_mutable(self, name, signature):
-                cls = mutable_version
-                return cls(name, None, signature)
-        else:
-            _make_new_mutable = _make_new
-    if immutable:
-        mutable_version = specific_class.mutable_version = generate_class(
-                cname, fname, n_args, False)
-    specific_class.__name__ = cname + "Mutable" * (not immutable)
-    return specific_class
-
-def generate_abstract_class(n_args, immutable=True):
-    from pypy.rlib.unroll import unrolling_iterable
-    arg_iter = unrolling_iterable(range(n_args))
-    if immutable:
-        base = Callable
-    else:
-        base = MutableCallable
-    class abstract_callable(base):
-
-        if immutable:
-            _immutable_fields_ = ["val_%d" % x for x in arg_iter]
-
-        def __init__(self, term_name, args, signature):
-            raise NotImplementedError
-
-        def _init_values(self, args):
-            if args is None:
-                return
-            for x in arg_iter:
-                setattr(self, 'val_%d' % x, args[x])
-
-        def _make_new(self, name, signature, mutable=False):
-            raise NotImplementedError("abstract base class")
-        _make_new_mutable = _make_new
-
-        def arguments(self):
-            result = [None] * n_args
-            for x in arg_iter:
-                result[x] = getattr(self, 'val_%d' % x)
-            return result
-        
-        def argument_at(self, i):
-            for x in arg_iter:
-                if x == i:
-                    return getattr(self, 'val_%d' % x)
-            raise IndexError
-
-        if not immutable:
-            def set_argument_at(self, i, arg):
-                for x in arg_iter:
-                    if x == i:
-                        setattr(self, 'val_%d' % x, arg)
-                        return
-                raise IndexError
-
-        def argument_count(self):
-            return n_args
-
-        def quick_unify_check(self, other):
-            other = other.dereference(None)
-            if isinstance(other, Var):
-                return True
-            if not isinstance(other, Callable):
-                return False
-            if not self.signature().eq(other.signature()):
-                return False
-            if not isinstance(other, abstract_callable):
-                return Callable.quick_unify_check(self, other)
-            for x in arg_iter:
-                a = getattr(self, 'val_%d' % x)
-                b = getattr(other, 'val_%d' % x)
-                if not a.quick_unify_check(b):
-                    return False
-            return True
-
-        def copy_and_basic_unify(self, other, heap, env):
-            if not isinstance(other, abstract_callable):
-                return Callable.copy_and_basic_unify(self, other, heap, env)
-            if self.signature().eq(other.signature()):
-                for x in arg_iter:
-                    a = getattr(self, 'val_%d' % x)
-                    b = getattr(other, 'val_%d' % x)
-                    a.unify_and_standardize_apart(b, heap, env)
-            else:
-                raise UnificationFailed
-
-        def copy_standardize_apart(self, heap, env):
-            result = self._make_new_mutable(self.name(), self.signature())
-            newinstance = False
-            needmutable = False
-            i = 0
-            for i in arg_iter:
-                arg = getattr(self, 'val_%d' % i)
-                cloned = arg.copy_standardize_apart_as_child_of(heap, env, result, i)
-                newinstance = newinstance | (cloned is not arg)
-                needmutable = needmutable | isinstance(arg, VarInTerm)
-                setattr(result, 'val_%d' % i, cloned)
-                i += 1
-            if newinstance:
-                # XXX what about the variable shunting in Callable.build
-                return result
-            else:
-                return self
-
-        @specialize.arg(3)
-        @jit.dont_look_inside
-        def basic_unify(self, other, heap, occurs_check=False):
-            if not isinstance(other, abstract_callable):
-                return Callable.basic_unify(self, other, heap, occurs_check)
-            if self.signature().eq(other.signature()):
-                for x in arg_iter:
-                    a = getattr(self, 'val_%d' % x)
-                    b = getattr(other, 'val_%d' % x)
-                    a.unify(b, heap, occurs_check)
-            else:
-                raise UnificationFailed
-
-        @specialize.arg(1)
-        def _copy_term(self, copy_individual, heap, *extraargs):
-            result = self._make_new(self.name(), self.signature())
-            newinstance = False
-            i = 0
-            for i in arg_iter:
-                arg = getattr(self, 'val_%d' % i)
-                cloned = copy_individual(arg, i, heap, *extraargs)
-                newinstance = newinstance | (cloned is not arg)
-                setattr(result, 'val_%d' % i, cloned)
-                i += 1
-            if newinstance:
-                # XXX what about the variable shunting in Callable.build
-                return result
-            else:
-                return self
-    if immutable:
-        abstract_callable.mutable_version = generate_abstract_class(n_args, immutable=False)
-    else:
-        abstract_callable.mutable_version = abstract_callable
-
-    abstract_callable.__name__ = 'Abstract'+str(n_args) + "Mutable" * (not immutable)
-    return abstract_callable
-
-def generate_generic_class(n_args, immutable=True):
-    parent = callables['Abstract', n_args]
-    assert parent is not None
-    if not immutable:
-        parent = parent.mutable_version
-
-    class generic_callable(parent):
-        _immutable_fields_ = ["_signature"]
-        TYPE_STANDARD_ORDER = Term.TYPE_STANDARD_ORDER
-        
-        def __init__(self, term_name, args, signature):
-            parent._init_values(self, args)
-            self._signature = signature
-            assert args is None or len(args) == n_args
-            assert self.name() == term_name
-
-        def _make_new(self, name, signature, mutable=False):
-            cls = generic_callable
-            return cls(name, None, signature)
-
-        if immutable:
-            def _make_new_mutable(self, name, signature, mutable=False):
-                cls = mutable_version
-                return cls(name, None, signature)
-        else:
-            _make_new_mutable = _make_new
-
-        def signature(self):
-            return self._signature
-    if immutable:
-        mutable_version = generic_callable.mutable_version = generate_generic_class(n_args, False)
-    generic_callable.__name__ = 'Generic'+str(n_args) + "Mutable" * (not immutable)
-    return generic_callable
-
-
-specialized_term_classes = {}
-callables = {}
-
-for numargs in range(1, OPTIMIZED_TERM_SIZE_MAX):
-    callables['Abstract', numargs] = generate_abstract_class(numargs)
-
-classes = [('Cons', '.', 2), ('Or', ';', 2), ('And', ',', 2)]
-for cname, fname, numargs in classes:
-    specialized_term_classes[fname, numargs] = generate_class(
-                                                        cname, fname, numargs)
-
-for numargs in range(1, 10):
-    assert ('Term', numargs) not in specialized_term_classes
-    specialized_term_classes['Term', numargs] = generate_generic_class(numargs)

prolog/interpreter/test/test_shape.py

 
 def test_shaped_callable_replace_child():
     sig = signature.Signature.getsignature(".", 2)
-    build = shape.SharingShape.build
+    build = shape.SharingShape
     X = shape.InStorageShape.build()
     s1 = build(sig, [X, X])
     a = term.Callable.build("a")
     c1._replace_child(0, c2, newshape)
     assert c1.storage == [b, nil, a]
 
+def test_replace_child_fixup_varinterm():
+    from prolog.interpreter.heap import Heap
+    h = Heap()
+    sig = signature.Signature.getsignature(".", 2)
+    build = shape.SharingShape
+    X = shape.InStorageShape.build()
+    s1 = build(sig, [X, X])
+    a = term.Callable.build("a")
+    b = term.Callable.build("b")
+    nil = term.Callable.build("[]")
+    c1 = shape.ShapedCallableMutable(s1, [a, None])
+
+    c2 = shape.ShapedCallableMutable(s1, [b, None])
+    var2 = h.newvar_in_term(c2, 1)
+    c2.storage[1] = var2
+
+    s1.get_transition(1, s1)
+    res = c1.replace_child(1, c2)
+    assert res
+    assert c1.storage[2].parent_or_binding is c1
+
+
 def test_depth():
     sig = signature.Signature.getsignature(".", 2)
     b = shape.SharingShape.build
     c1.unify(c2, h)
     assert X.binding is a
 
+def test_functional_test():
+    from prolog.interpreter.continuation import Engine
+    from prolog.interpreter.test.tool import assert_true, get_engine
+    e = get_engine("""
+        append([], L, L).
+        append([H|T], L, [H|R]) :- append(T, L, R).
+        reverse([], L, L).
+        reverse([H|T], L, O) :-
+            reverse(T, [H | L], O).
+    """)
 
+    for i in range(10):
+        env = assert_true("append([1, 2, 3, 4, 5], [2, 3, 4, 5, 6], X).", e)
+    res = env['X']
+    l = []
+    while res.name() == ".":
+        l.append(res.argument_at(0).num)
+        res = res.argument_at(1)
+    assert l == [1, 2, 3, 4, 5, 2, 3, 4, 5, 6]
+    res = env['X']
+    assert len(res.storage) > 5
+
+    for i in range(10):
+        env = assert_true("reverse([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [], X).", e)
+    res = env['X']
+    l = []
+    while res.name() == ".":
+        l.append(res.argument_at(0).num)
+        res = res.argument_at(1)
+    assert l == [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+    res = env['X']
+    assert len(res.storage) > 5