1. Pypy
  2. Untitled project
  3. pypy

Commits

Carl Friedrich Bolz  committed ba5a9e3

remove some oopspecs in rdict to make the JIT trace the hash functions in
dicts. this makes it necessary to hide some interior field manipulation in a
helper function.

  • Participants
  • Parent commits 9f98d44
  • Branches default

Comments (0)

Files changed (3)

File pypy/jit/codewriter/support.py

View file
  • Ignore whitespace
         return ll_rdict.ll_newdict(DICT)
     _ll_0_newdict.need_result_type = True
 
-    _ll_2_dict_getitem = ll_rdict.ll_dict_getitem
-    _ll_3_dict_setitem = ll_rdict.ll_dict_setitem
     _ll_2_dict_delitem = ll_rdict.ll_dict_delitem
-    _ll_3_dict_setdefault = ll_rdict.ll_setdefault
-    _ll_2_dict_contains = ll_rdict.ll_contains
-    _ll_3_dict_get = ll_rdict.ll_get
     _ll_1_dict_copy = ll_rdict.ll_copy
     _ll_1_dict_clear = ll_rdict.ll_clear
     _ll_2_dict_update = ll_rdict.ll_update

File pypy/jit/metainterp/test/test_dict.py

View file
  • Ignore whitespace
 import py
 from pypy.jit.metainterp.test.test_basic import LLJitMixin, OOJitMixin
 from pypy.rlib.jit import JitDriver
+from pypy.rlib import objectmodel
 
 class DictTests:
 
             res = self.meta_interp(f, [10], listops=True)
             assert res == expected
 
+    def test_dict_trace_hash(self):
+        myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+        def key(x):
+            return x % 2
+        def eq(x, y):
+            return (x % 2) == (y % 2)
+
+        def f(n):
+            dct = objectmodel.r_dict(eq, key)
+            total = n
+            while total:
+                myjitdriver.jit_merge_point(total=total, dct=dct)
+                if total not in dct:
+                    dct[total] = []
+                dct[total].append(total)
+                total -= 1
+            return len(dct[0])
+
+        res1 = f(100)
+        res2 = self.meta_interp(f, [100], listops=True)
+        assert res1 == res2
+        self.check_loops(int_mod=1) # the hash was traced
+
+    def test_dict_setdefault(self):
+        myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+        def f(n):
+            dct = {}
+            total = n
+            while total:
+                myjitdriver.jit_merge_point(total=total, dct=dct)
+                dct.setdefault(total % 2, []).append(total)
+                total -= 1
+            return len(dct[0])
+
+        assert f(100) == 50
+        res = self.meta_interp(f, [100], listops=True)
+        assert res == 50
+        self.check_loops(new=0, new_with_vtable=0)
+
+    def test_dict_as_counter(self):
+        myjitdriver = JitDriver(greens = [], reds = ['total', 'dct'])
+        def key(x):
+            return x % 2
+        def eq(x, y):
+            return (x % 2) == (y % 2)
+
+        def f(n):
+            dct = objectmodel.r_dict(eq, key)
+            total = n
+            while total:
+                myjitdriver.jit_merge_point(total=total, dct=dct)
+                dct[total] = dct.get(total, 0) + 1
+                total -= 1
+            return dct[0]
+
+        assert f(100) == 50
+        res = self.meta_interp(f, [100], listops=True)
+        assert res == 50
+        self.check_loops(int_mod=1)
+
 
 class TestOOtype(DictTests, OOJitMixin):
     pass

File pypy/rpython/lltypesystem/rdict.py

View file
  • Ignore whitespace
 from pypy.rlib.rarithmetic import r_uint, intmask, LONG_BIT
 from pypy.rlib.objectmodel import hlinvoke
 from pypy.rpython import robject
-from pypy.rlib import objectmodel
+from pypy.rlib import objectmodel, jit
 from pypy.rpython import rmodel
 
 HIGHEST_BIT = intmask(1 << (LONG_BIT - 1))
     ENTRIES = lltype.typeOf(entries).TO
     return ENTRIES.fasthashfn(entries[i].key)
 
+@jit.dont_look_inside
+def ll_get_value(d, i):
+    return d.entries[i].value
+
 def ll_keyhash_custom(d, key):
     DICT = lltype.typeOf(d).TO
     return hlinvoke(DICT.r_rdict_hashfn, d.fnkeyhash, key)
 def ll_dict_getitem(d, key):
     i = ll_dict_lookup(d, key, d.keyhash(key))
     if not i & HIGHEST_BIT:
-        return d.entries[i].value
+        return ll_get_value(d, i)
     else:
         raise KeyError
-ll_dict_getitem.oopspec = 'dict.getitem(d, key)'
 
 def ll_dict_setitem(d, key, value):
     hash = d.keyhash(key)
     i = ll_dict_lookup(d, key, hash)
     return _ll_dict_setitem_lookup_done(d, key, value, hash, i)
-ll_dict_setitem.oopspec = 'dict.setitem(d, key, value)'
 
+@jit.dont_look_inside
 def _ll_dict_setitem_lookup_done(d, key, value, hash, i):
     valid = (i & HIGHEST_BIT) == 0
     i = i & MASK
 
 def ll_get(dict, key, default):
     i = ll_dict_lookup(dict, key, dict.keyhash(key))
-    entries = dict.entries
     if not i & HIGHEST_BIT:
-        return entries[i].value
+        return ll_get_value(dict, i)
     else:
         return default
-ll_get.oopspec = 'dict.get(dict, key, default)'
 
 def ll_setdefault(dict, key, default):
     hash = dict.keyhash(key)
     i = ll_dict_lookup(dict, key, hash)
-    entries = dict.entries
     if not i & HIGHEST_BIT:
-        return entries[i].value
+        return ll_get_value(dict, i)
     else:
         _ll_dict_setitem_lookup_done(dict, key, default, hash, i)
         return default
-ll_setdefault.oopspec = 'dict.setdefault(dict, key, default)'
 
 def ll_copy(dict):
     DICT = lltype.typeOf(dict).TO
 def ll_contains(d, key):
     i = ll_dict_lookup(d, key, d.keyhash(key))
     return not i & HIGHEST_BIT
-ll_contains.oopspec = 'dict.contains(d, key)'
 
 POPITEMINDEX = lltype.Struct('PopItemIndex', ('nextindex', lltype.Signed))
 global_popitem_index = lltype.malloc(POPITEMINDEX, zero=True, immortal=True)