Commits

Kirill Mavreshko committed 823b283

ForeignKey and ManyToManyField (for legacy apps), iexact in filter, tests

Comments (0)

Files changed (6)

+- storing files in MongoDB using GridFS
+- special version of ForeignKey for embedding objects
+- automated ensure_index() before query and order
+- add "hint" method to the QuerySet

django_mongodb/djangopatch.py

     return value
 
 def AutoField_get_db_prep_value(self, value):
+    if isinstance(value, basestring):
+        value = ObjectId.url_decode(value)
     return value
 
 def apply_mongo_patch():
     if not _django_is_patched:
         from django.db.models import fields
         from django.db.models.sql import subqueries
-        from django_mongodb import query
+        from django_mongodb import query, related
         from django.core import management
+        # Patching fields
         fields.AutoField.to_python = AutoField_to_python
         fields.AutoField.get_db_prep_value = AutoField_get_db_prep_value
+        fields.related.create_many_related_manager = related.create_many_related_manager
         
+        # Patching ORM
         subqueries.UpdateQuery.execute_sql = query.UpdateQuery_execute_sql
         subqueries.DeleteQuery.execute_sql = query.DeleteQuery_execute_sql
         subqueries.InsertQuery.execute_sql = query.InsertQuery_execute_sql

django_mongodb/query.py

             field_name = lvalue[1]
             if field_name == pk_field.name:
                 field_name = "_id"
-                if isinstance(pk_field, models.AutoField):
-                    params = [isinstance(par, ObjectId) and par or ObjectId.url_decode(par) for par in params]
+                if isinstance(pk_field, models.AutoField):                            
+                    params = [isinstance(par, str) and ObjectId.url_decode(par) or par for par in params]
             if lookup_type == "exact":
                 if parent_negated:
                     res = {field_name: {"$ne": params[0]}}
                 else:
                     res = {field_name: params[0]}
+            elif lookup_type == "iexact":
+                par_re = re.compile("^%s$" % re.escape(params[0]), re.I)
+                if parent_negated:
+                    res = {field_name: {"$ne": par_re}}
+                else:
+                    res = {field_name: par_re}
             elif lookup_type == "gt":
                 res = {field_name:{parent_negated and "$lte" or "$gt":params[0]}}
             elif lookup_type == "gte":

django_mongodb/related.py

+#-*- coding:utf-8 -*-
+
+# Copyright (c) 2009, Kirill Mavreshko (kimavr@gmail.com)
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without 
+# modification, are permitted provided that the following conditions are met:
+
+#    * Redistributions of source code must retain the above copyright notice, 
+#      this list of conditions and the following disclaimer.
+#    * Redistributions in binary form must reproduce the above copyright notice,
+#      this list of conditions and the following disclaimer in the documentation
+#      and/or other materials provided with the distribution.
+#    * The name of the author may not be used to endorse or promote products
+#      derived from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 
+# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
+# POSSIBILITY OF SUCH DAMAGE.
+
+"""
+django_mongodb.related -- Work with related fields in MongoDB-style
+"""
+
+from django.db import connection, models
+from pymongo.objectid import ObjectId
+
+def clean_pk(obj):
+    "Get primary key of obj in the right form"
+    pk_val = obj._get_pk_val()
+    if isinstance(obj._meta.pk, models.AutoField) and isinstance(pk_val, basestring):
+        pk_val = ObjectId.url_decode(pk_val)
+    return pk_val
+
+def create_many_related_manager(superclass, through=False):
+    class ManyRelatedManager(superclass):
+        def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
+                join_table=None, source_col_name=None, target_col_name=None):
+            super(ManyRelatedManager, self).__init__()
+            if isinstance(instance._meta.pk, models.AutoField):
+                core_filters = dict([(k, isinstance(par, str) and ObjectId.url_decode(par) or par) for k,par in core_filters.items()])
+            self.core_filters = core_filters
+            self.model = model
+            self.symmetrical = symmetrical
+            self.instance = instance
+            self.join_table = join_table
+            self.source_col_name = source_col_name
+            self.target_col_name = target_col_name
+            self.through = through
+            self._pk_val = clean_pk(self.instance)
+            if self._pk_val is None:
+                raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)
+
+        def get_query_set(self):
+            mongo_con = connection.cursor()
+            if len(self.core_filters) > 1:
+                raise NotImplementedError("Too complex query in relation manager")
+            cursor = mongo_con[self.join_table].find({self.source_col_name:self.core_filters.values()[0]}, fields=[self.target_col_name])
+            ids = [v[self.target_col_name] for v in cursor]
+            qs = superclass.get_query_set(self)._next_is_sticky()
+            if ids:
+                return qs.filter(pk__in=ids)
+            else:
+                return qs.none()
+
+        # If the ManyToMany relation has an intermediary model,
+        # the add and remove methods do not exist.
+        if through is None:
+            def add(self, *objs):
+                self._add_items(self.source_col_name, self.target_col_name, *objs)
+
+                # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table
+                if self.symmetrical:
+                    self._add_items(self.target_col_name, self.source_col_name, *objs)
+            add.alters_data = True
+
+            def remove(self, *objs):
+                self._remove_items(self.source_col_name, self.target_col_name, *objs)
+
+                # If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table
+                if self.symmetrical:
+                    self._remove_items(self.target_col_name, self.source_col_name, *objs)
+            remove.alters_data = True
+
+        def clear(self):
+            self._clear_items(self.source_col_name)
+
+            # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table
+            if self.symmetrical:
+                self._clear_items(self.target_col_name)
+        clear.alters_data = True
+
+        def create(self, **kwargs):
+            # This check needs to be done here, since we can't later remove this
+            # from the method lookup table, as we do with add and remove.
+            if through is not None:
+                raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through
+            new_obj = super(ManyRelatedManager, self).create(**kwargs)
+            self.add(new_obj)
+            return new_obj
+        create.alters_data = True
+
+        def get_or_create(self, **kwargs):
+            obj, created = \
+                    super(ManyRelatedManager, self).get_or_create(**kwargs)
+            # We only need to add() if created because if we got an object back
+            # from get() then the relationship already exists.
+            if created:
+                self.add(obj)
+            return obj, created
+        get_or_create.alters_data = True
+
+        def _add_items(self, source_col_name, target_col_name, *objs):
+            # join_table: name of the m2m link table
+            # source_col_name: the PK colname in join_table for the source object
+            # target_col_name: the PK colname in join_table for the target object
+            # *objs - objects to add. Either object instances, or primary keys of object instances.
+
+            # If there aren't any objects, there is nothing to do.
+            if objs:
+                from django.db.models.base import Model
+                # Check that all the objects are of the right type
+                new_ids = set()
+                for obj in objs:
+                    if isinstance(obj, self.model):
+                        new_ids.add(clean_pk(obj))
+                    elif isinstance(obj, Model):
+                        raise TypeError, "'%s' instance expected" % self.model._meta.object_name
+                    else:
+                        new_ids.add(obj)
+                # Add the newly created or already existing objects to the join table.
+                # First find out which items are already added, to avoid adding them twice
+                mongo_con = connection.cursor()
+                collection = mongo_con[self.join_table]
+                cursor = collection.find(
+                    {source_col_name: self._pk_val, target_col_name:{"$in":list(new_ids)}}, 
+                    fields=[target_col_name])
+                existing_ids = set([doc[target_col_name] for doc in cursor])
+
+                # Add the ones that aren't there already
+                docs = [{source_col_name: self._pk_val, target_col_name: obj_id} for obj_id in (new_ids - existing_ids)]
+                collection.ensure_index([(source_col_name,1), (target_col_name, 1)])
+                collection.insert(docs)
+
+        def _remove_items(self, source_col_name, target_col_name, *objs):
+            # source_col_name: the PK colname in join_table for the source object
+            # target_col_name: the PK colname in join_table for the target object
+            # *objs - objects to remove
+
+            # If there aren't any objects, there is nothing to do.
+            if objs:
+                # Check that all the objects are of the right type
+                old_ids = set()
+                for obj in objs:
+                    if isinstance(obj, self.model):
+                        old_ids.add(clean_pk(obj))
+                    else:
+                        old_ids.add(obj)
+                # Remove the specified objects from the join table
+                mongo_con = connection.cursor()
+                mongo_con[self.join_table].remove({source_col_name:self._pk_val, target_col_name:{"$in":list(old_ids)}})
+
+        def _clear_items(self, source_col_name):
+            mongo_con = connection.cursor()
+            mongo_con[self.join_table].remove({source_col_name: self._pk_val})
+
+    return ManyRelatedManager

testproject/testapp/models.py

     text = models.CharField(max_length=255)
 
 class NoAutoPK_CRUDTestModel(models.Model):
-    name = models.CharField(max_length=50, primary_key=True)
+    name = models.IntegerField(primary_key=True)
     text = models.CharField(max_length=255)
     
+class Author(models.Model):
+    name = models.CharField(max_length=100)
+
+class Article(models.Model):
+    title = models.CharField(max_length=200)
+    author = models.ForeignKey(Author)
+
+class LargeArticle(models.Model):
+    title = models.CharField(max_length=200)
+    authors = models.ManyToManyField(Author)

testproject/testapp/tests.py

 from django.db import models, connection, DatabaseError
 from django.test import TestCase
 
-from testapp.models import CRUDTestModel, NoAutoPK_CRUDTestModel
+from testapp.models import CRUDTestModel, NoAutoPK_CRUDTestModel, Author, Article, LargeArticle
 
 class TestCRUD(TestCase):
 
         # Insert with skipped primary key in model without AutoField
         self.assertRaises(DatabaseError, lambda: NoAutoPK_CRUDTestModel(text="text").save())
         self.assertEqual(self.db[self.nopk_crud_col_name].find().count(), 0)
+        # Insert with predefined primary key
+        NoAutoPK_CRUDTestModel(text="text", name=12).save()
+        self.assertEqual(self.db[self.nopk_crud_col_name].find_one({'_id':12})['text'], "text")
+        self.assertEqual(NoAutoPK_CRUDTestModel.objects.get(name=12).pk, 12)
+
+
+class TestQueries(TestCase):
+    def setUp(self):
+        self.db = connection.cursor()
+        self.author_col = self.db[Author._meta.db_table]
+
+    def tearDown(self):
+        self.db.logout()
+        Author.objects.all().delete()
+
+    def test_queries(self):
+        author = Author.objects.create(name="simple name")
+        self.assertEqual(Author.objects.get(name__iexact="Simple Name"), author)
+
+
+class TestRelations(TestCase):
+    
+    def setUp(self):
+        self.db = connection.cursor()
+        self.author_col = self.db[Author._meta.db_table]
+        self.arts_col = self.db[Article._meta.db_table]
+        self.large_arts_col = self.db[LargeArticle._meta.db_table]
+
+    def tearDown(self):
+        Author.objects.all().delete()
+        self.db.logout()
+
+    def test_foreign_keys(self):
+        author1 = Author.objects.create(name="name1")
+        author2 = Author.objects.create(name="name2")
+        art1 = Article.objects.create(title="title1", author=author1)
+        art2 = Article.objects.create(title="title2", author=author1)
+        self.assertEqual(Article.objects.get(title="title1").author.pk, author1.pk)
+        self.assertEqual(Article.objects.get(title="title2").author.pk, author1.pk)
+        self.assertEqual(list(author1.article_set.all().order_by("title")), [art1, art2])
+        self.assertEqual(author1.article_set.all().count(), 2)
+        self.assertEqual(author2.article_set.all().count(), 0)
+        author1.delete()
+        self.assertEqual(list(Article.objects.filter(title__in=["title1", "title2"])), [])
+
+    def test_many_to_many(self):
+        author1 = Author.objects.create(name="name1")
+        author2 = Author.objects.create(name="name2")
+        art1 = LargeArticle.objects.create(title="title1")
+        art1.authors = [author1, author2]
+        art2 = LargeArticle.objects.create(title="title2")
+        self.assertEqual(list(art1.authors.all().order_by('name')), [author1, author2])
+        self.assertEqual(art1.authors.count(), 2)
+        author1.delete()
+        self.assertEqual(art1.authors.count(), 1)
+        self.assertEqual(art2.authors.count(), 0)