Commits

Antoine Pitrou  committed d19d9ec

Fix some set algebra methods of WeakSet objects.

  • Participants
  • Parent commits 5a2bc0d
  • Branches 3.2

Comments (0)

Files changed (2)

File Lib/_weakrefset.py

         self.update(other)
         return self
 
-    # Helper functions for simple delegating methods.
-    def _apply(self, other, method):
-        if not isinstance(other, self.__class__):
-            other = self.__class__(other)
-        newdata = method(other.data)
-        newset = self.__class__()
-        newset.data = newdata
+    def difference(self, other):
+        newset = self.copy()
+        newset.difference_update(other)
         return newset
-
-    def difference(self, other):
-        return self._apply(other, self.data.difference)
     __sub__ = difference
 
     def difference_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        if self is other:
-            self.data.clear()
-        else:
-            self.data.difference_update(ref(item) for item in other)
+        self.__isub__(other)
     def __isub__(self, other):
         if self._pending_removals:
             self._commit_removals()
         return self
 
     def intersection(self, other):
-        return self._apply(other, self.data.intersection)
+        return self.__class__(item for item in other if item in self)
     __and__ = intersection
 
     def intersection_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        self.data.intersection_update(ref(item) for item in other)
+        self.__iand__(other)
     def __iand__(self, other):
         if self._pending_removals:
             self._commit_removals()
         return self.data == set(ref(item) for item in other)
 
     def symmetric_difference(self, other):
-        return self._apply(other, self.data.symmetric_difference)
+        newset = self.copy()
+        newset.symmetric_difference_update(other)
+        return newset
     __xor__ = symmetric_difference
 
     def symmetric_difference_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        if self is other:
-            self.data.clear()
-        else:
-            self.data.symmetric_difference_update(ref(item) for item in other)
+        self.__ixor__(other)
     def __ixor__(self, other):
         if self._pending_removals:
             self._commit_removals()
         if self is other:
             self.data.clear()
         else:
-            self.data.symmetric_difference_update(ref(item) for item in other)
+            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
         return self
 
     def union(self, other):
-        return self._apply(other, self.data.union)
+        return self.__class__(e for s in (self, other) for e in s)
     __or__ = union
 
     def isdisjoint(self, other):

File Lib/test/test_weakset.py

             x = WeakSet(self.items + self.items2)
             c = C(self.items2)
             self.assertEqual(self.s.union(c), x)
+            del c
+        self.assertEqual(len(u), len(self.items) + len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(u), len(self.items) + len(self.items2))
 
     def test_or(self):
         i = self.s.union(self.items2)
         self.assertEqual(self.s | frozenset(self.items2), i)
 
     def test_intersection(self):
-        i = self.s.intersection(self.items2)
+        s = WeakSet(self.letters)
+        i = s.intersection(self.items2)
         for c in self.letters:
-            self.assertEqual(c in i, c in self.d and c in self.items2)
-        self.assertEqual(self.s, WeakSet(self.items))
+            self.assertEqual(c in i, c in self.items2 and c in self.letters)
+        self.assertEqual(s, WeakSet(self.letters))
         self.assertEqual(type(i), WeakSet)
         for C in set, frozenset, dict.fromkeys, list, tuple:
             x = WeakSet([])
-            self.assertEqual(self.s.intersection(C(self.items2)), x)
+            self.assertEqual(i.intersection(C(self.items)), x)
+        self.assertEqual(len(i), len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(i), len(self.items2))
 
     def test_isdisjoint(self):
         self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
         self.assertEqual(self.s, WeakSet(self.items))
         self.assertEqual(type(i), WeakSet)
         self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
+        self.assertEqual(len(i), len(self.items) + len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(i), len(self.items) + len(self.items2))
 
     def test_xor(self):
         i = self.s.symmetric_difference(self.items2)