Commits

Alex Gaynor  committed 7dd24d6 Merge

Merged virtualref-virtualizable

This improves the generated code in cases where a virtualref points to the standard virtualizable and we force it.

  • Participants
  • Parent commits 8463b9c, f254155

Comments (0)

Files changed (4)

File rpython/jit/codewriter/jtransform.py

         return SpaceOperation('libffi_save_result_%s' % kind, op.args[1:], None)
 
     def rewrite_op_jit_force_virtual(self, op):
-        return self._do_builtin_call(op)
+        op0 = SpaceOperation('-live-', [], None)
+        op1 = self._do_builtin_call(op)
+        if isinstance(op1, list):
+            return [op0] + op1
+        else:
+            return [op0, op1]
 
     def rewrite_op_jit_is_virtual(self, op):
-        raise Exception, (
-            "'vref.virtual' should not be used from jit-visible code")
+        raise Exception("'vref.virtual' should not be used from jit-visible code")
 
     def rewrite_op_jit_force_virtualizable(self, op):
         # this one is for virtualizables

File rpython/jit/codewriter/test/test_flatten.py

         self.encoding_test(f, [], """
             new_with_vtable <Descr> -> %r0
             virtual_ref %r0 -> %r1
+            -live-
             residual_call_r_r $<* fn jit_force_virtual>, R[%r1], <Descr> -> %r2
             ref_return %r2
         """, transform=True, cc=FakeCallControlWithVRefInfo())

File rpython/jit/metainterp/pyjitpl.py

-import py, sys
+import sys
+
+import py
+
+from rpython.jit.codewriter import heaptracker
+from rpython.jit.codewriter.effectinfo import EffectInfo
+from rpython.jit.codewriter.jitcode import JitCode, SwitchDictDescr
+from rpython.jit.metainterp import history, compile, resume, executor
+from rpython.jit.metainterp.heapcache import HeapCache
+from rpython.jit.metainterp.history import (Const, ConstInt, ConstPtr,
+    ConstFloat, Box, TargetToken)
+from rpython.jit.metainterp.jitexc import JitException, get_llexception
+from rpython.jit.metainterp.jitprof import EmptyProfiler
+from rpython.jit.metainterp.logger import Logger
+from rpython.jit.metainterp.optimizeopt.util import args_dict_box
+from rpython.jit.metainterp.resoperation import rop
+from rpython.rlib import nonconst, rstack
+from rpython.rlib.debug import debug_start, debug_stop, debug_print, make_sure_not_resized
+from rpython.rlib.jit import Counters
+from rpython.rlib.objectmodel import we_are_translated, specialize
+from rpython.rlib.unroll import unrolling_iterable
 from rpython.rtyper.lltypesystem import lltype, rclass
-from rpython.rlib.objectmodel import we_are_translated
-from rpython.rlib.unroll import unrolling_iterable
-from rpython.rlib.debug import debug_start, debug_stop, debug_print
-from rpython.rlib.debug import make_sure_not_resized
-from rpython.rlib import nonconst, rstack
 
-from rpython.jit.metainterp import history, compile, resume
-from rpython.jit.metainterp.history import Const, ConstInt, ConstPtr, ConstFloat
-from rpython.jit.metainterp.history import Box, TargetToken
-from rpython.jit.metainterp.resoperation import rop
-from rpython.jit.metainterp import executor
-from rpython.jit.metainterp.logger import Logger
-from rpython.jit.metainterp.jitprof import EmptyProfiler
-from rpython.rlib.jit import Counters
-from rpython.jit.metainterp.jitexc import JitException, get_llexception
-from rpython.jit.metainterp.heapcache import HeapCache
-from rpython.rlib.objectmodel import specialize
-from rpython.jit.codewriter.jitcode import JitCode, SwitchDictDescr
-from rpython.jit.codewriter import heaptracker
-from rpython.jit.metainterp.optimizeopt.util import args_dict_box
+
 
 # ____________________________________________________________
 
     opimpl_inline_call_irf_f = _opimpl_inline_call3
     opimpl_inline_call_irf_v = _opimpl_inline_call3
 
-    @arguments("box", "boxes", "descr")
-    def _opimpl_residual_call1(self, funcbox, argboxes, calldescr):
-        return self.do_residual_or_indirect_call(funcbox, argboxes, calldescr)
-    @arguments("box", "boxes2", "descr")
-    def _opimpl_residual_call2(self, funcbox, argboxes, calldescr):
-        return self.do_residual_or_indirect_call(funcbox, argboxes, calldescr)
-    @arguments("box", "boxes3", "descr")
-    def _opimpl_residual_call3(self, funcbox, argboxes, calldescr):
-        return self.do_residual_or_indirect_call(funcbox, argboxes, calldescr)
+    @arguments("box", "boxes", "descr", "orgpc")
+    def _opimpl_residual_call1(self, funcbox, argboxes, calldescr, pc):
+        return self.do_residual_or_indirect_call(funcbox, argboxes, calldescr, pc)
+
+    @arguments("box", "boxes2", "descr", "orgpc")
+    def _opimpl_residual_call2(self, funcbox, argboxes, calldescr, pc):
+        return self.do_residual_or_indirect_call(funcbox, argboxes, calldescr, pc)
+
+    @arguments("box", "boxes3", "descr", "orgpc")
+    def _opimpl_residual_call3(self, funcbox, argboxes, calldescr, pc):
+        return self.do_residual_or_indirect_call(funcbox, argboxes, calldescr, pc)
 
     opimpl_residual_call_r_i = _opimpl_residual_call1
     opimpl_residual_call_r_r = _opimpl_residual_call1
     opimpl_residual_call_irf_f = _opimpl_residual_call3
     opimpl_residual_call_irf_v = _opimpl_residual_call3
 
-    @arguments("int", "boxes3", "boxes3")
-    def _opimpl_recursive_call(self, jdindex, greenboxes, redboxes):
+    @arguments("int", "boxes3", "boxes3", "orgpc")
+    def _opimpl_recursive_call(self, jdindex, greenboxes, redboxes, pc):
         targetjitdriver_sd = self.metainterp.staticdata.jitdrivers_sd[jdindex]
         allboxes = greenboxes + redboxes
         warmrunnerstate = targetjitdriver_sd.warmstate
             # that assembler that we call is still correct
             self.verify_green_args(targetjitdriver_sd, greenboxes)
         #
-        return self.do_recursive_call(targetjitdriver_sd, allboxes,
+        return self.do_recursive_call(targetjitdriver_sd, allboxes, pc,
                                       assembler_call)
 
-    def do_recursive_call(self, targetjitdriver_sd, allboxes,
+    def do_recursive_call(self, targetjitdriver_sd, allboxes, pc,
                           assembler_call=False):
         portal_code = targetjitdriver_sd.mainjitcode
         k = targetjitdriver_sd.portal_runner_adr
         funcbox = ConstInt(heaptracker.adr2int(k))
-        return self.do_residual_call(funcbox, allboxes, portal_code.calldescr,
+        return self.do_residual_call(funcbox, allboxes, portal_code.calldescr, pc,
                                      assembler_call=assembler_call,
                                      assembler_call_jd=targetjitdriver_sd)
 
             return box     # no promotion needed, already a Const
         else:
             constbox = box.constbox()
-            resbox = self.do_residual_call(funcbox, [box, constbox], descr)
+            resbox = self.do_residual_call(funcbox, [box, constbox], descr, orgpc)
             promoted_box = resbox.constbox()
             # This is GUARD_VALUE because GUARD_TRUE assumes the existance
             # of a label when computing resumepc
             except ChangeFrame:
                 pass
             frame = self.metainterp.framestack[-1]
-            frame.do_recursive_call(jitdriver_sd, greenboxes + redboxes,
+            frame.do_recursive_call(jitdriver_sd, greenboxes + redboxes, orgpc,
                                     assembler_call=True)
             raise ChangeFrame
 
             self.metainterp.assert_no_exception()
         return resbox
 
-    def do_residual_call(self, funcbox, argboxes, descr,
+    def do_residual_call(self, funcbox, argboxes, descr, pc,
                          assembler_call=False,
                          assembler_call_jd=None):
         # First build allboxes: it may need some reordering from the
                 effectinfo.check_forces_virtual_or_virtualizable()):
             # residual calls require attention to keep virtualizables in-sync
             self.metainterp.clear_exception()
+            if effectinfo.oopspecindex == EffectInfo.OS_JIT_FORCE_VIRTUAL:
+                resbox = self._do_jit_force_virtual(allboxes, descr, pc)
+                if resbox is not None:
+                    return resbox
             self.metainterp.vable_and_vrefs_before_residual_call()
             resbox = self.metainterp.execute_and_record_varargs(
                 rop.CALL_MAY_FORCE, allboxes, descr=descr)
             pure = effectinfo.check_is_elidable()
             return self.execute_varargs(rop.CALL, allboxes, descr, exc, pure)
 
-    def do_residual_or_indirect_call(self, funcbox, argboxes, calldescr):
+    def _do_jit_force_virtual(self, allboxes, descr, pc):
+        assert len(allboxes) == 2
+        if (self.metainterp.jitdriver_sd.virtualizable_info is None and
+            self.metainterp.jitdriver_sd.greenfield_info is None):
+            # can occur in case of multiple JITs
+            return None
+        vref_box = allboxes[1]
+        standard_box = self.metainterp.virtualizable_boxes[-1]
+        if standard_box is vref_box:
+            return vref_box
+        if self.metainterp.heapcache.is_nonstandard_virtualizable(vref_box):
+            return None
+        eqbox = self.metainterp.execute_and_record(rop.PTR_EQ, None, vref_box, standard_box)
+        eqbox = self.implement_guard_value(eqbox, pc)
+        isstandard = eqbox.getint()
+        if isstandard:
+            return standard_box
+        else:
+            return None
+
+    def do_residual_or_indirect_call(self, funcbox, argboxes, calldescr, pc):
         """The 'residual_call' operation is emitted in two cases:
         when we have to generate a residual CALL operation, but also
         to handle an indirect_call that may need to be inlined."""
                 # we should follow calls to this graph
                 return self.metainterp.perform_call(jitcode, argboxes)
         # but we should not follow calls to that graph
-        return self.do_residual_call(funcbox, argboxes, calldescr)
+        return self.do_residual_call(funcbox, argboxes, calldescr, pc)
 
 # ____________________________________________________________
 

File rpython/jit/metainterp/test/test_virtualizable.py

 import py
+
+from rpython.jit.codewriter import heaptracker
+from rpython.jit.codewriter.policy import StopAtXPolicy
+from rpython.jit.metainterp.optimizeopt.test.test_util import LLtypeMixin
+from rpython.jit.metainterp.test.support import LLJitMixin, OOJitMixin
+from rpython.jit.metainterp.warmspot import get_translator
+from rpython.rlib.jit import JitDriver, hint, dont_look_inside, promote, virtual_ref
+from rpython.rlib.rarithmetic import intmask
+from rpython.rtyper.annlowlevel import hlstr
 from rpython.rtyper.extregistry import ExtRegistryEntry
 from rpython.rtyper.lltypesystem import lltype, lloperation, rclass, llmemory
-from rpython.rtyper.annlowlevel import llhelper
-from rpython.rtyper.rclass import IR_IMMUTABLE, IR_IMMUTABLE_ARRAY
-from rpython.jit.codewriter.policy import StopAtXPolicy
-from rpython.jit.codewriter import heaptracker
-from rpython.rlib.jit import JitDriver, hint, dont_look_inside, promote
-from rpython.rlib.rarithmetic import intmask
-from rpython.jit.metainterp.test.support import LLJitMixin, OOJitMixin
-from rpython.rtyper.rclass import FieldListAccessor
-from rpython.jit.metainterp.warmspot import get_stats, get_translator
-from rpython.jit.metainterp import history
-from rpython.jit.metainterp.optimizeopt.test.test_util import LLtypeMixin
+from rpython.rtyper.rclass import IR_IMMUTABLE, IR_IMMUTABLE_ARRAY, FieldListAccessor
+
 
 def promote_virtualizable(*args):
     pass
+
+
 class Entry(ExtRegistryEntry):
     "Annotation and rtyping of LLOp instances, which are callable."
     _about_ = promote_virtualizable
         ('inst_node', lltype.Ptr(LLtypeMixin.NODE)),
         hints = {'virtualizable2_accessor': FieldListAccessor()})
     XY._hints['virtualizable2_accessor'].initialize(
-        XY, {'inst_x' : IR_IMMUTABLE, 'inst_node' : IR_IMMUTABLE})
+        XY, {'inst_x': IR_IMMUTABLE, 'inst_node': IR_IMMUTABLE})
 
     xy_vtable = lltype.malloc(rclass.OBJECT_VTABLE, immortal=True)
     heaptracker.set_testing_vtable_for_gcstruct(XY, xy_vtable, 'XY')
                 x = xy.inst_x
                 xy.inst_x = x + 1
                 n -= 1
-            promote_virtualizable(xy, 'inst_x')                
+            promote_virtualizable(xy, 'inst_x')
             return xy.inst_x
         res = self.meta_interp(f, [20])
         assert res == 30
                     x = xy.inst_x
                     xy.inst_x = x + 10
                 n -= 1
-            promote_virtualizable(xy, 'inst_x')                
+            promote_virtualizable(xy, 'inst_x')
             return xy.inst_x
         assert f(5) == 185
         res = self.meta_interp(f, [5])
                 x = xy.inst_x
                 if n <= 10:
                     x += 1000
-                promote_virtualizable(xy, 'inst_x')                    
+                promote_virtualizable(xy, 'inst_x')
                 xy.inst_x = x + 1
                 n -= 1
-            promote_virtualizable(xy, 'inst_x')                
+            promote_virtualizable(xy, 'inst_x')
             return xy.inst_x
         res = self.meta_interp(f, [18])
         assert res == 10118
                 xy.inst_x = x + 1
                 m = (m+1) & 3     # the loop gets unrolled 4 times
                 n -= 1
-            promote_virtualizable(xy, 'inst_x')                
+            promote_virtualizable(xy, 'inst_x')
             return xy.inst_x
         def f(n):
             res = 0
                 promote_virtualizable(xy, 'inst_x')
                 xy.inst_x = value + 100      # virtualized away
                 n -= 1
-            promote_virtualizable(xy, 'inst_x')                
+            promote_virtualizable(xy, 'inst_x')
             return xy.inst_x
         res = self.meta_interp(f, [20])
         assert res == 134
         ('inst_l2', lltype.Ptr(lltype.GcArray(lltype.Signed))),
         hints = {'virtualizable2_accessor': FieldListAccessor()})
     XY2._hints['virtualizable2_accessor'].initialize(
-        XY2, {'inst_x' : IR_IMMUTABLE,
-              'inst_l1' : IR_IMMUTABLE_ARRAY, 'inst_l2' : IR_IMMUTABLE_ARRAY})
+        XY2, {'inst_x': IR_IMMUTABLE,
+              'inst_l1': IR_IMMUTABLE_ARRAY, 'inst_l2': IR_IMMUTABLE_ARRAY})
 
     xy2_vtable = lltype.malloc(rclass.OBJECT_VTABLE, immortal=True)
     heaptracker.set_testing_vtable_for_gcstruct(XY2, xy2_vtable, 'XY2')
             while n > 0:
                 myjitdriver.can_enter_jit(xy2=xy2, n=n)
                 myjitdriver.jit_merge_point(xy2=xy2, n=n)
-                promote_virtualizable(xy2, 'inst_l1')                
+                promote_virtualizable(xy2, 'inst_l1')
                 promote_virtualizable(xy2, 'inst_l2')
                 xy2.inst_l1[2] += xy2.inst_l2[0]
                 n -= 1
-            promote_virtualizable(xy2, 'inst_l1')                
+            promote_virtualizable(xy2, 'inst_l1')
             return xy2.inst_l1[2]
         res = self.meta_interp(f, [16])
         assert res == 3001 + 16 * 80
                 myjitdriver.can_enter_jit(xy2=xy2, n=n)
                 myjitdriver.jit_merge_point(xy2=xy2, n=n)
                 promote_virtualizable(xy2, 'inst_l1')
-                promote_virtualizable(xy2, 'inst_l2')                
+                promote_virtualizable(xy2, 'inst_l2')
                 xy2.inst_l1[1] += len(xy2.inst_l2)
                 n -= 1
         def f(n):
                 promote_virtualizable(xy2, 'inst_l2')
                 xy2.inst_l2[0] = value + 100      # virtualized away
                 n -= 1
-            promote_virtualizable(xy2, 'inst_l2')                
+            promote_virtualizable(xy2, 'inst_l2')
             return xy2.inst_l2[0]
         expected = f(20)
         res = self.meta_interp(f, [20], enable_opts='')
                 myjitdriver.can_enter_jit(xy2=xy2, n=n)
                 myjitdriver.jit_merge_point(xy2=xy2, n=n)
                 parent = xy2.parent
-                promote_virtualizable(parent, 'inst_x')                
-                promote_virtualizable(parent, 'inst_l2')                
+                promote_virtualizable(parent, 'inst_x')
+                promote_virtualizable(parent, 'inst_l2')
                 parent.inst_l2[0] += parent.inst_x
                 n -= 1
         def f(n):
     # ------------------------------
 
 
-class ImplicitVirtualizableTests:
-
+class ImplicitVirtualizableTests(object):
     def test_simple_implicit(self):
         myjitdriver = JitDriver(greens = [], reds = ['frame'],
                                 virtualizables = ['frame'])
             def __init__(self, l, s):
                 self.l = l
                 self.s = s
-        
+
         def f(n, a):
-            frame = Frame([a,a+1,a+2,a+3], 0)
+            frame = Frame([a, a+1, a+2, a+3], 0)
             x = 0
             while n > 0:
                 myjitdriver.can_enter_jit(frame=frame, n=n, x=x)
 
         def f(n):
             BaseFrame([])     # hack to force 'x' to be in BaseFrame
-            frame = Frame([1,2,3])
+            frame = Frame([1, 2, 3])
             z = 0
             while n > 0:
                 jitdriver.can_enter_jit(frame=frame, n=n, z=z)
     def test_external_read(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_with_exception(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
 
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_sometimes(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_sometimes_with_virtuals(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class Y:
             pass
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_sometimes_changing_virtuals(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class Y:
             pass
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_sometimes_with_exception(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class FooBarError(Exception):
             pass
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_sometimes_dont_compile_guard(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_read_sometimes_recursive(self):
         jitdriver = JitDriver(greens = [], reds = ['rec', 'frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_external_write_sometimes(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_bridge_forces(self):
         jitdriver = JitDriver(greens = [], reds = ['frame'],
                               virtualizables = ['frame'])
-        
+
         class Frame(object):
             _virtualizable2_ = ['x', 'y']
+
         class SomewhereElse:
             pass
         somewhere_else = SomewhereElse()
     def test_promote_index_in_virtualizable_list(self):
         jitdriver = JitDriver(greens = [], reds = ['n', 'frame'],
                               virtualizables = ['frame'])
+
         class Frame(object):
             _virtualizable2_ = ['stackpos', 'stack[*]']
 
         assert direct_calls(f_graph) == ['__init__',
                                          'force_virtualizable_if_necessary',
                                          'll_portal_runner']
-        assert direct_calls(portal_graph)==['force_virtualizable_if_necessary',
-                                            'maybe_enter_jit']
+        assert direct_calls(portal_graph) == ['force_virtualizable_if_necessary',
+                                              'maybe_enter_jit']
         assert direct_calls(init_graph) == []
 
     def test_virtual_child_frame(self):
         somewhere_else = SomewhereElse()
 
         def jump_back(frame, fail):
-            myjitdriver.can_enter_jit(frame=frame, fail=fail)            
+            myjitdriver.can_enter_jit(frame=frame, fail=fail)
 
         def f(n, fail):
             frame = Frame(n, 0)
                 f.x -= 1
                 result += indirection(f)
             return result
+
         def indirection(arg):
             return interp(arg)
+
         def run_interp(n):
             f = hint(Frame(n), access_directly=True)
             return interp(f)
         assert res == run_interp(4)
 
     def test_guard_failure_in_inlined_function(self):
-        from rpython.rtyper.annlowlevel import hlstr
+
         class Frame(object):
             _virtualizable2_ = ['n', 'next']
 
                     assert 0
                 pc += 1
             return frame.n
+
         def main(n):
             frame = Frame(n)
             return f("c-l", frame)
-        print main(100)
         res = self.meta_interp(main, [100], inline=True, enable_opts='')
+        assert res == main(100)
 
     def test_stuff_from_backend_test(self):
         class Thing(object):
         driver = JitDriver(greens = ['codeno'], reds = ['i', 'frame'],
                            virtualizables = ['frame'],
                            get_printable_location = lambda codeno: str(codeno))
+
         class SomewhereElse(object):
             pass
 
         print hex(res)
         assert res == main(0)
 
+    def test_force_virtualref_to_virtualizable(self):
+        jitdriver = JitDriver(
+            greens=[],
+            reds=['i', 'n', 'f', 'f_ref'],
+            virtualizables=['f']
+        )
+
+        class Frame(object):
+            _virtualizable2_ = ['x']
+
+        def main(n):
+            f = Frame()
+            f.x = 1
+            f_ref = virtual_ref(f)
+            i = 0
+            while i < n:
+                jitdriver.jit_merge_point(f=f, i=i, f_ref=f_ref, n=n)
+                i += f_ref().x
+            return i
+
+        res = self.meta_interp(main, [10])
+        assert res == main(10)
+        self.check_resops({
+            "getfield_gc": 1, "int_lt": 2, "ptr_eq": 1, "guard_true": 3,
+            "int_add": 2, "jump": 1
+        })
+
 
 class TestOOtype(#ExplicitVirtualizableTests,
                  ImplicitVirtualizableTests,
                  OOJitMixin):
     pass
 
-        
+
 class TestLLtype(ExplicitVirtualizableTests,
                  ImplicitVirtualizableTests,
                  LLJitMixin):