Commits

Branko Vukelic  committed 7c165a8

Fixed #2: Customizable context object names

  • Participants
  • Parent commits fff724b

Comments (0)

Files changed (2)

File related/views.py

     existing_request_pk_key = 'pk'
     existing_request_slug_key = 'slug'
     existing_redirect_url = None
+    existing_form_name = None
 
     def get_existing_redirect_url(self, existing_object):
         if not self.existing_redirect_url:
         form_kwargs.update(kwargs)
         return form_class(*args, **form_kwargs)
 
+    def get_existing_form_name(self):
+        return self.existing_form_name
+
     def get_existing_from_form(self):
         form = self.get_existing_form()
         if form.is_valid():
                 form = self.get_existing_form()
 
             context['existing_form'] = form
+            context_form_name = self.get_existing_form_name()
+            if context_form_name:
+                context[context_form_name] = form
 
         context.update(
             super(GetExistingMixin, self).get_context_data(*arg, **kwarg)
     related_slug_field = 'slug'
     related_slug_url_kwarg = 'slug'
     related_object_gone_message = '<h2>Database record is missing</h2>'
+    related_object_name = None
     integrity_error_message = 'Such record already exists'
 
     def get_related_404_url(self):
     def related_object_gone(self):
         return HttpResponseGone(self.get_related_object_gone_message())
 
+    def get_related_object_name(self):
+        return self.related_object_name
+
     def form_valid(self, form):
         try:
             related_object = self.get_related_object()
         context = super(
             CreateWithRelatedMixin, self
         ).get_context_data(*args, **kwargs)
-        context['related_object'] = self.get_related_object()
+
+        related_object = self.get_related_object()
+
+        context['related_object'] = related_object
+
+        related_object_name = self.get_related_object_name()
+        if related_object_name:
+            context[related_object_name] = related_object
+
         return context
 

File tests/tests.py

 
         self.assertTrue(ctx.has_key('existing_form'))
 
+    def test_get_context_name(self):
+        class Super(object):
+            def get_context_data(self):
+                return dict(foo='bar')
+
+        class View(GetExistingMixin, Super):
+            pass
+
+        view = View()
+
+        view.request = Mock()
+        view.request.get = dict(fam='dam')
+        view.get_existing_form_class = Mock()
+        view.get_existing_form = Mock()
+        view.existing_form_name = 'foo_form'
+
+        ctx = view.get_context_data()
+
+        self.assertTrue(ctx.has_key('existing_form'))
+        self.assertTrue(ctx.has_key('foo_form'))
+
     def test_setting_initial_when_GET_is_empty(self):
         class Super(object):
             def get_context_data(self):
         self.assertEqual(ctx['related_object'],
                          view.get_related_object.return_value)
 
+    def test_related_object_name(self):
+        class Super(object):
+            def get_context_data(self):
+                return dict()
+
+        class View(CreateWithRelatedMixin, Super):
+            pass
+
+        view = View()
+        view.get_related_object = Mock()
+        view.related_object_name = 'foo_object'
+
+        ctx = view.get_context_data()
+
+        self.assertEqual(ctx['related_object'],
+                         view.get_related_object.return_value)
+        self.assertEqual(ctx['foo_object'],
+                         view.get_related_object.return_value)
+
     def test_form_valid(self):
         self.view.get_related_field = Mock()
         self.view.get_related_field.return_value = 'foo'