1. idank
  2. sqlalchemy

Commits

jason kirtland  committed d4f4399

Fixed in-place set mutation operator support [ticket:920]

  • Participants
  • Parent commits 6a48b88
  • Branches default

Comments (0)

Files changed (5)

File CHANGES

View file
 0.4.3
 -----
 - orm
-    - added very rudimentary yielding iterator behavior to Query.  Call
-      query.yield_per(<number of rows>) and evaluate the Query in an 
+    - Added very rudimentary yielding iterator behavior to Query.  Call
+      query.yield_per(<number of rows>) and evaluate the Query in an
       iterative context; every collection of N rows will be packaged up
-      and yielded.  Use this method with extreme caution since it does 
+      and yielded.  Use this method with extreme caution since it does
       not attempt to reconcile eagerly loaded collections across
       result batch boundaries, nor will it behave nicely if the same
-      instance occurs in more than one batch.  This means that an eagerly 
+      instance occurs in more than one batch.  This means that an eagerly
       loaded collection will get cleared out if it's referenced in more than
       one batch, and in all cases attributes will be overwritten on instances
       that occur in more than one batch.
 
+   - Fixed in-place set mutation operators for set collections and association
+     proxied sets. [ticket:920]
+
 - dialects
-
-    - PostgreSQL
-       - Fixed the missing call to subtype result processor for the PGArray
-         type. [ticket:913]
+    - Fixed the missing call to subtype result processor for the PGArray
+      type. [ticket:913]
 
 0.4.2
 -----

File lib/sqlalchemy/ext/associationproxy.py

View file
                 self._scalar_set(target, values)
         else:
             proxy = self.__get__(obj, None)
-            proxy.clear()
-            self._set(proxy, values)
+            if proxy is not values:
+                proxy.clear()
+                self._set(proxy, values)
 
     def __delete__(self, obj):
         delattr(obj, self.key)
         for value in other:
             self.add(value)
 
-    __ior__ = update
+    def __ior__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        for value in other:
+            self.add(value)
+        return self
 
     def _set(self):
         return util.Set(iter(self))
         for value in other:
             self.discard(value)
 
-    __isub__ = difference_update
+    def __isub__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        for value in other:
+            self.discard(value)
+        return self
 
     def intersection(self, other):
         return util.Set(self).intersection(other)
         for value in add:
             self.add(value)
 
-    __iand__ = intersection_update
+    def __iand__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        want, have = self.intersection(other), util.Set(self)
+
+        remove, add = have - want, want - have
+
+        for value in remove:
+            self.remove(value)
+        for value in add:
+            self.add(value)
+        return self
 
     def symmetric_difference(self, other):
         return util.Set(self).symmetric_difference(other)
         for value in add:
             self.add(value)
 
-    __ixor__ = symmetric_difference_update
+    def __ixor__(self, other):
+        if util.duck_type_collection(other) is not set:
+            return NotImplemented
+        want, have = self.symmetric_difference(other), util.Set(self)
+
+        remove, add = have - want, want - have
+
+        for value in remove:
+            self.remove(value)
+        for value in add:
+            self.add(value)
+        return self
 
     def issubset(self, other):
         return util.Set(self).issubset(other)

File lib/sqlalchemy/orm/collections.py

View file
                     self.add(item)
         _tidy(update)
         return update
-    __ior__ = update
+
+    def __ior__(fn):
+        def __ior__(self, value):
+            if sautil.duck_type_collection(value) is not set:
+                return NotImplemented
+            for item in value:
+                if item not in self:
+                    self.add(item)
+            return self
+        _tidy(__ior__)
+        return __ior__
 
     def difference_update(fn):
         def difference_update(self, value):
                 self.discard(item)
         _tidy(difference_update)
         return difference_update
-    __isub__ = difference_update
+
+    def __isub__(fn):
+        def __isub__(self, value):
+            if sautil.duck_type_collection(value) is not set:
+                return NotImplemented
+            for item in value:
+                self.discard(item)
+            return self
+        _tidy(__isub__)
+        return __isub__
 
     def intersection_update(fn):
         def intersection_update(self, other):
                 self.add(item)
         _tidy(intersection_update)
         return intersection_update
-    __iand__ = intersection_update
+
+    def __iand__(fn):
+        def __iand__(self, other):
+            if sautil.duck_type_collection(other) is not set:
+                return NotImplemented
+            want, have = self.intersection(other), sautil.Set(self)
+            remove, add = have - want, want - have
+
+            for item in remove:
+                self.remove(item)
+            for item in add:
+                self.add(item)
+            return self
+        _tidy(__iand__)
+        return __iand__
 
     def symmetric_difference_update(fn):
         def symmetric_difference_update(self, other):
                 self.add(item)
         _tidy(symmetric_difference_update)
         return symmetric_difference_update
-    __ixor__ = symmetric_difference_update
+
+    def __ixor__(fn):
+        def __ixor__(self, other):
+            if sautil.duck_type_collection(other) is not set:
+                return NotImplemented
+            want, have = self.symmetric_difference(other), sautil.Set(self)
+            remove, add = have - want, want - have
+
+            for item in remove:
+                self.remove(item)
+            for item in add:
+                self.add(item)
+            return self
+        _tidy(__ixor__)
+        return __ixor__
 
     l = locals().copy()
     l.pop('_tidy')

File test/ext/associationproxy.py

View file
                         print 'got', repr(p.children)
                         raise
 
+        # in-place mutations
+        for op in ('|=', '-=', '&=', '^='):
+            for base in (['a', 'b', 'c'], []):
+                for other in (set(['a','b','c']), set(['a','b','c','d']),
+                              set(['a']), set(['a','b']),
+                              set(['c','d']), set(['e', 'f', 'g']),
+                              set()):
+                    p = Parent('p')
+                    p.children = base[:]
+                    control = set(base[:])
+
+                    exec "p.children %s other" % op
+                    exec "control %s other" % op
+
+                    try:
+                        self.assert_(p.children == control)
+                    except:
+                        print 'Test %s %s %s:' % (set(base), op, other)
+                        print 'want', repr(control)
+                        print 'got', repr(p.children)
+                        raise
+
+                    p = self.roundtrip(p)
+
+                    try:
+                        self.assert_(p.children == control)
+                    except:
+                        print 'Test %s %s %s:' % (base, op, other)
+                        print 'want', repr(control)
+                        print 'got', repr(p.children)
+                        raise
+
 
 class CustomSetTest(SetTest):
     def __init__(self, *args, **kw):

File test/orm/collection.py

View file
 
         adapter.append_with_event(e1)
         assert_eq()
-        
+
         adapter.append_without_event(e2)
         assert_ne()
         canary.data.add(e2)
         assert_eq()
-        
+
         adapter.remove_without_event(e2)
         assert_ne()
         canary.data.remove(e2)
     def _test_list(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
-        
+
         canary = Canary()
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'attr', True, extension=canary,
             self.assert_(set(direct) == canary.data)
             self.assert_(set(adapter) == canary.data)
             self.assert_(direct == control)
-        
+
         # assume append() is available for list tests
         e = creator()
         direct.append(e)
             e = creator()
             direct.append(e)
             control.append(e)
-            
+
             e = creator()
             direct[0] = e
             control[0] = e
             e = creator()
             direct.append(e)
             control.append(e)
-            
+
             direct.remove(e)
             control.remove(e)
             assert_eq()
             direct[1::2] = values
             control[1::2] = values
             assert_eq()
-            
+
         if hasattr(direct, '__delslice__'):
             for i in range(1, 4):
                 e = creator()
                 control.append(e)
 
             del direct[-1:]
-            del control[-1:] 
+            del control[-1:]
             assert_eq()
 
             del direct[1:2]
                 return self.data == other
             def __repr__(self):
                 return 'ListLike(%s)' % repr(self.data)
-            
+
         self._test_adapter(ListLike)
         self._test_list(ListLike)
         self._test_list_bulk(ListLike)
                 return self.data == other
             def __repr__(self):
                 return 'ListIsh(%s)' % repr(self.data)
-            
+
         self._test_adapter(ListIsh)
         self._test_list(ListIsh)
         self._test_list_bulk(ListIsh)
             for item in list(direct):
                 direct.remove(item)
             control.clear()
-        
+
         # assume add() is available for list tests
         addall(creator())
 
             direct.discard(e)
             self.assert_(e not in canary.removed)
             assert_eq()
-            
+
         if hasattr(direct, 'update'):
+            zap()
             e = creator()
             addall(e)
-            
+
             values = set([e, creator(), creator()])
 
             direct.update(values)
             control.update(values)
             assert_eq()
 
+        if hasattr(direct, '__ior__'):
+            zap()
+            e = creator()
+            addall(e)
+
+            values = set([e, creator(), creator()])
+
+            direct |= values
+            control |= values
+            assert_eq()
+
+            try:
+                direct |= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
         if hasattr(direct, 'clear'):
             addall(creator(), creator())
             direct.clear()
 
         if hasattr(direct, 'difference_update'):
             zap()
+            e = creator()
             addall(creator(), creator())
             values = set([creator()])
 
             control.difference_update(values)
             assert_eq()
 
+        if hasattr(direct, '__isub__'):
+            zap()
+            e = creator()
+            addall(creator(), creator())
+            values = set([creator()])
+
+            direct -= values
+            control -= values
+            assert_eq()
+            values.update(set([e, creator()]))
+            direct -= values
+            control -= values
+            assert_eq()
+
+            try:
+                direct -= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
         if hasattr(direct, 'intersection_update'):
             zap()
             e = creator()
             control.intersection_update(values)
             assert_eq()
 
+        if hasattr(direct, '__iand__'):
+            zap()
+            e = creator()
+            addall(e, creator(), creator())
+            values = set(control)
+
+            direct &= values
+            control &= values
+            assert_eq()
+
+            values.update(set([e, creator()]))
+            direct &= values
+            control &= values
+            assert_eq()
+
+            try:
+                direct &= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
         if hasattr(direct, 'symmetric_difference_update'):
             zap()
             e = creator()
             control.symmetric_difference_update(values)
             assert_eq()
 
+        if hasattr(direct, '__ixor__'):
+            zap()
+            e = creator()
+            addall(e, creator(), creator())
+
+            values = set([e, creator()])
+            direct ^= values
+            control ^= values
+            assert_eq()
+
+            e = creator()
+            addall(e)
+            values = set([e])
+            direct ^= values
+            control ^= values
+            assert_eq()
+
+            values = set()
+            direct ^= values
+            control ^= values
+            assert_eq()
+
+            try:
+                direct ^= [e, creator()]
+                assert False
+            except TypeError:
+                assert True
+
     def _test_set_bulk(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
         self.assert_(obj.attr == set([e2]))
         self.assert_(e1 in canary.removed)
         self.assert_(e2 in canary.added)
- 
+
         e3 = creator()
         real_set = set([e3])
         obj.attr = real_set
         self.assert_(obj.attr == set([e3]))
         self.assert_(e2 in canary.removed)
         self.assert_(e3 in canary.added)
-       
+
         e4 = creator()
         try:
             obj.attr = [e4]
             for item in list(adapter):
                 direct.remove(item)
             control.clear()
-        
+
         # assume an 'set' method is available for tests
         addall(creator())
 
             direct.clear()
             control.clear()
             assert_eq()
-            
+
             direct.clear()
             control.clear()
             assert_eq()
             zap()
             e = creator()
             addall(e)
-            
+
             direct.popitem()
             control.popitem()
             assert_eq()
     def _test_object(self, typecallable, creator=entity_maker):
         class Foo(object):
             pass
-        
+
         canary = Canary()
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'attr', True, extension=canary,
         direct.zark(e)
         control.remove(e)
         assert_eq()
-        
+
         e = creator()
         direct.maybe_zark(e)
         control.discard(e)
             @collection.removes_return()
             def pop(self, key):
                 return self.data.pop()
-            
+
             @collection.iterator
             def __iter__(self):
                 return iter(self.data)
         col1.append(e3)
         self.assert_(e3 not in canary.data)
         self.assert_(collections.collection_adapter(col1) is None)
-        
+
         obj.attr[0] = e3
         self.assert_(e3 in canary.data)
 
 class DictHelpersTest(ORMTest):
     def define_tables(self, metadata):
         global parents, children, Parent, Child
-        
+
         parents = Table('parents', metadata,
                         Column('id', Integer, primary_key=True),
                         Column('label', String))
             'children': relation(Child, collection_class=collection_class,
                                  cascade="all, delete-orphan")
             })
-        
+
         p = Parent()
         p.children['foo'] = Child('foo', 'value')
         p.children['bar'] = Child('bar', 'value')
 
         collections.collection_adapter(p.children).append_with_event(
             Child('foo', 'newvalue'))
-        
+
         session.flush()
         session.clear()
-        
+
         p = session.query(Parent).get(pid)
-        
+
         self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
         self.assert_(p.children['foo'].id != cid)
-        
+
         self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
         session.flush()
         session.clear()
 
         collections.collection_adapter(p.children).remove_with_event(
             p.children['foo'])
-        
+
         self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
         session.flush()
         session.clear()
 
         p = session.query(Parent).get(pid)
         self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
-        
+
 
     def _test_composite_mapped(self, collection_class):
         mapper(Child, children)
             'children': relation(Child, collection_class=collection_class,
                                  cascade="all, delete-orphan")
             })
-        
+
         p = Parent()
         p.children[('foo', '1')] = Child('foo', '1', 'value 1')
         p.children[('foo', '2')] = Child('foo', '2', 'value 2')
         session.flush()
         pid = p.id
         session.clear()
-        
+
         p = session.query(Parent).get(pid)
 
         self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
 
         collections.collection_adapter(p.children).append_with_event(
             Child('foo', '1', 'newvalue'))
-        
+
         session.flush()
         session.clear()
-        
+
         p = session.query(Parent).get(pid)
-        
+
         self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
         self.assert_(p.children[('foo', '1')].id != cid)
-        
+
         self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
-        
+
     def test_mapped_collection(self):
         collection_class = collections.mapped_collection(lambda c: c.a)
         self._test_scalar_mapped(collection_class)