Commits

Lukas Diekmann committed 90dde1a

Added fallback tests for Module- and StringDictStrategy and implemented corresponding methods

  • Participants
  • Parent commits 50273dd
  • Branches dict-strategies

Comments (0)

Files changed (3)

File pypy/objspace/std/celldict.py

                 cell.w_value = w_default
             return cell.w_value
         else:
-            return self._as_rdict().impl_fallback_setdefault(w_key, w_default)
+            self.switch_to_object_strategy(w_dict)
+            return w_dict.setdefault(w_key, w_default)
 
     def delitem(self, w_dict, w_key):
         space = self.space
         elif _is_sane_hash(space, w_key_type):
             raise KeyError
         else:
-            self._as_rdict().impl_fallback_delitem(w_key)
+            self.switch_to_object_strategy(w_dict)
+            w_dict.delitem(w_key)
 
     def length(self, w_dict):
         # inefficient, but do we care?

File pypy/objspace/std/dictmultiobject.py

                                    instance=False, classofinstance=None,
                                    from_strdict_shared=None, strdict=False):
         if from_strdict_shared is not None:
-            assert w_type is None
             assert not module and not instance and classofinstance is None
             strategy = space.fromcache(StringDictStrategy)
             storage = strategy.cast_to_void_star(from_strdict_shared)
-            w_type = space.w_dict #XXX is this right?
-            w_self = space.allocate_instance(W_DictMultiObject, w_type)
-            W_DictMultiObject.__init__(w_self, space, strategy, storage)
+            w_self = W_DictMultiObject(space, strategy, storage)
             return w_self
 
         if space.config.objspace.std.withcelldict and module:
         elif _is_sane_hash(space, w_key_type):
             raise KeyError
         else:
-            self._as_rdict().impl_fallback_delitem(w_key)
+            self.switch_to_object_strategy(w_dict)
+            return w_dict.delitem(w_key)
 
     def length(self, w_dict):
         return len(self.cast_from_void_star(w_dict.dstorage))
         elif _is_sane_hash(space, w_lookup_type):
             return None
         else:
-            return self._as_rdict().impl_fallback_getitem(w_key)
+            self.switch_to_object_strategy(w_dict)
+            return w_dict.getitem(w_key)
 
     def iter(self, w_dict):
         return StrIteratorImplementation(self.space, w_dict)

File pypy/objspace/std/test/test_dictmultiobject.py

         d = type(__builtins__)("abc").__dict__
         raises(KeyError, "d['def']")
 
-    def test_fallback_getitem(self):
+    def test_fallback_evil_key(self):
         class F(object):
             def __hash__(self):
                 return hash("s")
         assert d["s"] == 12
         assert d[F()] == d["s"]
 
-    #XXX tests for fallbacks setdefault, delitem
+        d = type(__builtins__)("abc").__dict__
+        x = d.setdefault("s", 12)
+        assert x == 12
+        x = d.setdefault(F(), 12)
+        assert x == 12
+
+        d = type(__builtins__)("abc").__dict__
+        x = d.setdefault(F(), 12)
+        assert x == 12
+
+        d = type(__builtins__)("abc").__dict__
+        d["s"] = 12
+        del d[F()]
+
+        assert "s" not in d
+        assert F() not in d
+
 
 class FakeString(str):
     hash_count = 0
 
     w_StopIteration = StopIteration
     w_None = None
+    w_NoneType = type(None, None)
+    w_int = int
+    w_bool = bool
+    w_float = float
     StringObjectCls = FakeString
     w_dict = W_DictMultiObject
     iter = iter
         if on_pypy:
             assert key.hash_count == 2
 
+    def test_fallback_evil_key(self):
+        class F(object):
+            def __hash__(self):
+                return hash("s")
+            def __eq__(self, other):
+                return other == "s"
+
+        d = self.get_impl()
+        d.setitem("s", 12)
+        assert d.getitem("s") == 12
+        assert d.getitem(F()) == d.getitem("s")
+
+        d = self.get_impl()
+        x = d.setdefault("s", 12)
+        assert x == 12
+        x = d.setdefault(F(), 12)
+        assert x == 12
+
+        d = self.get_impl()
+        x = d.setdefault(F(), 12)
+        assert x == 12
+
+        d = self.get_impl()
+        d.setitem("s", 12)
+        d.delitem(F())
+
+        assert "s" not in d.keys()
+        assert F() not in d.keys()
+
 class TestStrDictImplementation(BaseTestRDictImplementation):
     StrategyClass = StringDictStrategy
     #ImplementionClass = StrDictImplementation
         assert self.impl.getitem(s) == 1000
         assert s.unwrapped
 
-    #XXX add tests for fallback getitem, delitem
-
 ## class TestMeasuringDictImplementation(BaseTestRDictImplementation):
 ##     ImplementionClass = MeasuringDictImplementation
 ##     DevolvedClass = MeasuringDictImplementation