Commits

Simon Cross committed ade9aa8

(fijal, hodgestar) Unroll only one iteration of the loop.

  • Participants
  • Parent commits 6376527
  • Branches inline-simple-generators

Comments (0)

Files changed (2)

pypy/jit/metainterp/pyjitpl.py

         # for resume.py operation
         self.parent_resumedata_snapshot = None
         self.parent_resumedata_frame_info_list = None
+        # counter for unrolling inlined loops
+        self.unroll_iterations = 1
 
     @specialize.arg(3)
     def copy_constants(self, registers, constants, ConstClass):
             # close the loop.  We have to put the possibly-modified list
             # 'redboxes' back into the registers where it comes from.
             put_back_list_of_boxes3(self, jcposition, redboxes)
-        elif jitdriver_sd.warmstate.should_unroll_one_iteration(greenboxes):
-            return
         else:
+            if jitdriver_sd.warmstate.should_unroll_one_iteration(greenboxes):
+                if self.unroll_iterations > 0:
+                    self.unroll_iterations -= 1
+                    return
             # warning! careful here.  We have to return from the current
             # frame containing the jit_merge_point, and then use
             # do_recursive_call() to follow the recursive call.  This is

pypy/jit/metainterp/test/test_ajit.py

     def test_unroll_one_loop_iteration(self):
         def unroll(x):
             return x == 0
-        myjitdriver = JitDriver(greens = ['x'], reds = ['y'], should_unroll_one_iteration=unroll)
+        myjitdriver = JitDriver(greens = ['code'],
+                                reds = ['loops', 'inner_loops', 's'],
+                                should_unroll_one_iteration=unroll)
 
-        def f(x, y):
-            while y > 0:
-                myjitdriver.jit_merge_point(x=x, y=y)
-                if x == 0:
-                    return y
-                f(0, 4)
-                y -= 1
-            return 0
+        def f(code, loops, inner_loops):
+            s = 0
+            while loops > 0:
+                myjitdriver.jit_merge_point(code=code, loops=loops,
+                                            inner_loops=inner_loops, s=s)
+                if code == 1:
+                    s += f(0, inner_loops, 0)
+                loops -= 1
+                s += 1
+            return s
 
-        res = self.meta_interp(f, [1, 4], enable_opts="", inline=True)
+        res = self.meta_interp(f, [1, 4, 1], enable_opts="", inline=True)
+        assert res == f(1, 4, 1)
         self.check_history(call_assembler=0)
 
+        res = self.meta_interp(f, [1, 4, 2], enable_opts="", inline=True)
+        assert res == f(1, 4, 2)
+        self.check_history(call_assembler=1)
+
     def test_format(self):
         def f(n):
             return len("<%d>" % n)