Commits

Indra Talip  committed 1a7fdc3

Fix race condition calling defaultdict#__missing__ for derived classes.

Moves the test for deferring to the subclasses __missing__ to the
CacheLoader#load in order to have atomic behaviour with regards to
calling __missing__.

Adds a test to ThreadSafetyTestCase for subclasses of defaultdict.

  • Participants
  • Parent commits 9c9fe8e

Comments (0)

Files changed (2)

File Lib/test/test_defaultdict_jy.py

         for t in threads:
             self.assertFalse(t.isAlive())
 
+    class Counter(object):
+        def __init__(self, initial=0):
+            self.atomic = AtomicInteger(initial)
+             # waiting is important here to ensure that
+             # defaultdict factories can step on each other
+            time.sleep(0.001)
+
+        def decrementAndGet(self):
+            return self.atomic.decrementAndGet()
+
+        def incrementAndGet(self):
+            return self.atomic.incrementAndGet()
+
+        def get(self):
+            return self.atomic.get()
+
+        def __repr__(self):
+            return "Counter<%s>" % (self.atomic.get())
+
     def test_inc_dec(self):
+        counters = defaultdict(ThreadSafetyTestCase.Counter)
+        size = 17
 
-        class Counter(object):
-            def __init__(self):
-                self.atomic = AtomicInteger()
-                 # waiting is important here to ensure that
-                 # defaultdict factories can step on each other
-                time.sleep(0.001)
-
-            def decrementAndGet(self):
-                return self.atomic.decrementAndGet()
-
-            def incrementAndGet(self):
-                return self.atomic.incrementAndGet()
-
-            def get(self):
-                return self.atomic.get()
-
-            def __repr__(self):
-                return "Counter<%s>" % (self.atomic.get())
-
-        counters = defaultdict(Counter)
-        size = 17
-        
         def tester():
             for i in xrange(1000):
                 j = (i + randint(0, size)) % size
                 counters[j].incrementAndGet()
 
         self.run_threads(tester, 20)
-        
+
         for i in xrange(size):
             self.assertEqual(counters[i].get(), 0, counters)
 
+    def test_derived_inc_dec(self):
+        class DerivedDefaultDict(defaultdict):
+            def __missing__(self, key):
+                if self.default_factory is None:
+                    raise KeyError("Invalid key '{0}' and no default factory was set")
+
+                val = self.default_factory(key)
+
+                self[key] = val
+                return val
+
+        counters = DerivedDefaultDict(lambda key: ThreadSafetyTestCase.Counter(key))
+        size = 17
+
+        def tester():
+            for i in xrange(1000):
+                j = (i + randint(0, size)) % size
+                counters[j].decrementAndGet()
+                time.sleep(0.0001)
+                counters[j].incrementAndGet()
+
+        self.run_threads(tester, 20)
+
+        for i in xrange(size):
+            self.assertEqual(counters[i].get(), i, counters)
+
 class GetVariantsTestCase(unittest.TestCase):
 
     #http://bugs.jython.org/issue2133
         self.assertEquals(d.items(), [("vivify", [])]) 
 
 
-class KeyDefaultDict(defaultdict):
-    """defaultdict to pass the requested key to factory function."""
-    def __missing__(self, key):
-        if self.default_factory is None:
-            raise KeyError("Invalid key '{0}' and no default factory was set")
-        else:
-            val = self.default_factory(key)
-
-        self[key] = val
-        return val
-
-    @classmethod
-    def double(cls, k):
-        return k + k
 
 class OverrideMissingTestCase(unittest.TestCase):
+    class KeyDefaultDict(defaultdict):
+        """defaultdict to pass the requested key to factory function."""
+        def __missing__(self, key):
+            if self.default_factory is None:
+                raise KeyError("Invalid key '{0}' and no default factory was set")
+            else:
+                val = self.default_factory(key)
+
+            self[key] = val
+            return val
+
+        @classmethod
+        def double(cls, k):
+            return k + k
+
+    def setUp(self):
+        self.kdd = OverrideMissingTestCase.KeyDefaultDict(OverrideMissingTestCase.KeyDefaultDict.double)
+
     def test_dont_call_derived_missing(self):
-        kdd = KeyDefaultDict(KeyDefaultDict.double)
-        kdd[3] = 5
-        self.assertEquals(kdd[3], 5)
+        self.kdd[3] = 5
+        self.assertEquals(self.kdd[3], 5)
 
     #http://bugs.jython.org/issue2088
     def test_override_missing(self):
-
-        kdd = KeyDefaultDict(KeyDefaultDict.double)
         # line below causes KeyError in Jython, ignoring overridden __missing__ method
-        self.assertEquals(kdd[3], 6)
-        self.assertEquals(kdd['ab'], 'abab')
+        self.assertEquals(self.kdd[3], 6)
+        self.assertEquals(self.kdd['ab'], 'abab')
 
 
 def test_main():

File src/org/python/modules/_collections/PyDefaultDict.java

         backingMap = CacheBuilder.newBuilder().build(
                 new CacheLoader<PyObject, PyObject>() {
                     public PyObject load(PyObject key) {
+                        PyType self_type = getType();
+                        if (self_type != TYPE) {
+                            // Is a subclass. If it exists call the subclasses __missing__.
+                            // Otherwise PyDefaultDic.defaultdict___missing__() will
+                            // be invoked.
+                            return PyDefaultDict.this.invoke("__missing__", key);
+                        }
+
+                        // in-lined __missing__
                         if (defaultFactory == Py.None) {
                             throw Py.KeyError(key);
                         }
     @ExposedMethod(doc = BuiltinDocs.dict___getitem___doc)
     protected final PyObject defaultdict___getitem__(PyObject key) {
         try {
-            PyType type = getType();
-            if (!getMap().containsKey(key) && type != TYPE) {
-                // is a subclass. if it exists call the subclasses __missing__
-                PyObject missing = type.lookup("__missing__");
-                if (missing != null) {
-                    return missing.__get__(this, type).__call__(key);
-                }
-            }
             return backingMap.get(key);
         } catch (Exception ex) {
             throw Py.KeyError(key);