Commits

Mike Bayer committed 02b0d86

added recursion check to merge

Comments (0)

Files changed (3)

lib/sqlalchemy/orm/properties.py

                     return s
                 return getattr(obj, self.name)
         setattr(self.parent.class_, self.key, SynonymProp())
-    def merge(self, session, source, dest):
+    def merge(self, session, source, dest, _recursive):
         pass
         
 class ColumnProperty(StrategizedProperty):
         setattr(object, self.key, value)
     def get_history(self, obj, passive=False):
         return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive)
-    def merge(self, session, source, dest):
+    def merge(self, session, source, dest, _recursive):
         setattr(dest, self.key, getattr(source, self.key, None))
     def compare(self, value):
         return self.columns[0] == value
     def __str__(self):
         return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper)
 
-    def merge(self, session, source, dest):
-        if not "merge" in self.cascade:
+    def merge(self, session, source, dest, _recursive):
+        if not "merge" in self.cascade or source in _recursive:
             return
-        childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
-        if childlist is None:
-            return
-        if self.uselist:
-            # sets a blank list according to the correct list class
-            dest_list = getattr(self.parent.class_, self.key).initialize(dest)
-            for current in list(childlist):
-                dest_list.append(session.merge(current))
-        else:
-            setattr(dest, self.key, session.merge(current))
-        
+        _recursive.add(source)
+        try:
+            childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
+            if childlist is None:
+                return
+            if self.uselist:
+                # sets a blank list according to the correct list class
+                dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+                for current in list(childlist):
+                    dest_list.append(session.merge(current, _recursive=_recursive))
+            else:
+                current = list(childlist)[0]
+                if current is not None:
+                    setattr(dest, self.key, session.merge(current, _recursive=_recursive))
+        finally:
+            _recursive.remove(source)
+            
     def cascade_iterator(self, type, object, recursive, halt_on=None):
         if not type in self.cascade:
             return

lib/sqlalchemy/orm/session.py

         for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
             self.uow.register_deleted(c)
 
-    def merge(self, object, entity_name=None):
+    def merge(self, object, entity_name=None, _recursive=None):
         """copy the state of the given object onto the persistent object with the same identifier. 
         
         If there is no persistent instance currently associated with the session, it will be loaded. 
         a newly persistent instance. The given instance does not become associated with the session. 
         This operation cascades to associated instances if the association is mapped with cascade="merge".
         """
+        if _recursive is None:
+            _recursive = util.Set()
         mapper = _object_mapper(object)
         key = getattr(object, '_instance_key', None)
         if key is None:
             else:
                 merged = self.get(mapper.class_, key[1])
         for prop in mapper.props.values():
-            prop.merge(self, object, merged)
+            prop.merge(self, object, merged, _recursive)
         if key is None:
             self.save(merged)
         return merged

test/orm/merge.py

     def test_saved_cascade(self):
         """test merge of a persistent entity with two child persistent entities."""
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses))
+            'addresses':relation(mapper(Address, addresses), backref='user')
         })
         sess = create_session()
         
         
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses)),
-            'orders':relation(Order)
+            'orders':relation(Order, backref='customer')
         })
         
         sess = create_session()
         u.orders[0].items[1].item_name = 'item 2 modified'
         sess2.merge(u)
         assert u2.orders[0].items[1].item_name == 'item 2 modified'
+
+        sess2 = create_session()
+        o2 = sess2.query(Order).get(o.order_id)
+        o.customer.user_name = 'also fred'
+        sess2.merge(o)
+        assert o2.customer.user_name == 'also fred'
         
         
 if __name__ == "__main__":