Commits

Anonymous committed 892a218

Fixed 'add' parameter of embedded models in lists

add? should depend on the *embedded* model's state, not the containing model's.

Comments (0)

Files changed (2)

djangotoolbox/fields.py

 
     model = property(lambda self:self._model, _set_model)
 
-    def pre_save(self, model_instance, add):
-        embedded_instance = super(EmbeddedModelField, self).pre_save(model_instance, add)
+    def pre_save(self, model_instance, _):
+        embedded_instance = getattr(model_instance, self.attname)
         if embedded_instance is None:
             return None, None
 
 
         values = []
         for field in embedded_instance._meta.fields:
+            add = not embedded_instance._entity_exists
             value = field.pre_save(embedded_instance, add)
             if field.primary_key and value is None:
                 # exclude unset pks ({"id" : None})

djangotoolbox/tests.py

         simple = EmbeddedModelField('EmbeddedModel', null=True)
         simple_untyped = EmbeddedModelField(null=True)
         typed_list = ListField(EmbeddedModelField('SetModel'))
+        typed_list2 = ListField(EmbeddedModelField('EmbeddedModel'))
         untyped_list = ListField(EmbeddedModelField())
         untyped_dict = DictField(EmbeddedModelField())
         ordered_list = ListField(EmbeddedModelField(), ordering=lambda obj: obj.index)
         self.assertRaises(DatabaseError, lambda: list(ExtendedModelProxy.objects.all()))
 
 class EmbeddedModelFieldTest(TestCase):
+    def assertEqualDatetime(self, d1, d2):
+        """ Compares d1 and d2, ignoring microseconds """
+        self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0))
+
     def _simple_instance(self):
         EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint='5'))
         return EmbeddedModelFieldModel.objects.get()
         instance = EmbeddedModelFieldModel.objects.get()
         self.assertEqual(instance.simple.id, instance.id)
 
-    def test_pre_save(self):
-        # Make sure field.pre_save is called
-        instance = self._simple_instance()
-        self.assertNotEqual(instance.simple.auto_now, None)
-        self.assertNotEqual(instance.simple.auto_now_add, None)
-        auto_now = instance.simple.auto_now
-        auto_now_add = instance.simple.auto_now_add
+    def test_pre_save(self, field='simple'):
+        # Make sure field.pre_save is called for embedded objects
+        from time import sleep
+        instance = EmbeddedModelFieldModel.objects.create(**{field: EmbeddedModel()})
+        auto_now = getattr(instance, field).auto_now
+        auto_now_add = getattr(instance, field).auto_now_add
+        self.assertNotEqual(auto_now, None)
+        self.assertNotEqual(auto_now_add, None)
         instance.save()
         instance = EmbeddedModelFieldModel.objects.get()
         # auto_now_add shouldn't have changed now, but auto_now should.
-        self.assertEqual(instance.simple.auto_now_add, auto_now_add)
-        self.assertGreater(instance.simple.auto_now, auto_now)
+        self.assertEqualDatetime(getattr(instance, field).auto_now_add, auto_now_add)
+        self.assertGreater(getattr(instance, field).auto_now, auto_now)
+
+    def test_pre_save_untyped(self):
+        self.test_pre_save(field='simple_untyped')
+
+    def test_pre_save_list(self):
+        # Also make sure auto_now{,add} works for embedded object *lists*.
+        EmbeddedModelFieldModel.objects.create(typed_list2=[EmbeddedModel()])
+        instance = EmbeddedModelFieldModel.objects.get()
+
+        auto_now = instance.typed_list2[0].auto_now
+        auto_now_add = instance.typed_list2[0].auto_now_add
+        self.assertNotEqual(auto_now, None)
+        self.assertNotEqual(auto_now_add, None)
+
+        instance.typed_list2.append(EmbeddedModel())
+        instance.save()
+        instance = EmbeddedModelFieldModel.objects.get()
+
+        self.assertEqualDatetime(instance.typed_list2[0].auto_now_add, auto_now_add)
+        self.assertGreater(instance.typed_list2[0].auto_now, auto_now)
+        self.assertNotEqual(instance.typed_list2[1].auto_now, None)
+        self.assertNotEqual(instance.typed_list2[1].auto_now_add, None)
 
     def test_error_messages(self):
         for kwargs, expected in (