Source

django-mongodb / django_mongodb / related.py

#-*- coding:utf-8 -*-

# Copyright (c) 2009, Kirill Mavreshko (kimavr@gmail.com)
# All rights reserved.

# Redistribution and use in source and binary forms, with or without 
# modification, are permitted provided that the following conditions are met:

#    * Redistributions of source code must retain the above copyright notice, 
#      this list of conditions and the following disclaimer.
#    * Redistributions in binary form must reproduce the above copyright notice,
#      this list of conditions and the following disclaimer in the documentation
#      and/or other materials provided with the distribution.
#    * The name of the author may not be used to endorse or promote products
#      derived from this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
# POSSIBILITY OF SUCH DAMAGE.

"""
django_mongodb.related -- support for Django related fields (many-to-many relations) stored in MongoDB
"""

from django.db import connection, models
from pymongo.objectid import ObjectId

def clean_pk(obj):
    """Return the primary key of *obj* in the form used for MongoDB queries.

    When the model's pk field is a Django ``AutoField`` and the pk value is a
    string, the value is decoded with ``ObjectId.url_decode``; in every other
    case the pk value is returned unchanged.
    """
    raw_pk = obj._get_pk_val()
    is_auto_pk = isinstance(obj._meta.pk, models.AutoField)
    if not (is_auto_pk and isinstance(raw_pk, basestring)):
        return raw_pk
    return ObjectId.url_decode(raw_pk)

def create_many_related_manager(superclass, through=False):
    """Create a many-to-many related manager class backed by MongoDB.

    The returned ``ManyRelatedManager`` subclasses *superclass* and keeps each
    link as a document ``{source_col_name: <source pk>, target_col_name:
    <target pk>}`` in the MongoDB collection named by ``join_table``.

    superclass -- the manager class to extend.
    through -- the intermediary model of the relation, or None when there is
        none.  ``add()``/``remove()`` are only defined, and ``create()`` is
        only allowed, when it is None.
        NOTE(review): the default is False rather than None, so a call that
        omits *through* gets no add()/remove() and a create() that always
        raises; callers presumably always pass it explicitly -- confirm.

    Returns the ``ManyRelatedManager`` class.
    """
    class ManyRelatedManager(superclass):
        """Manager for the objects linked to one ``instance`` via ``join_table``."""

        def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
                join_table=None, source_col_name=None, target_col_name=None):
            super(ManyRelatedManager, self).__init__()
            if isinstance(instance._meta.pk, models.AutoField):
                # String pk values are decoded into ObjectIds.  Test for
                # basestring (not just str) so unicode values are decoded too,
                # consistently with clean_pk().
                decoded_filters = {}
                for key, value in core_filters.items():
                    if isinstance(value, basestring):
                        value = ObjectId.url_decode(value)
                    decoded_filters[key] = value
                core_filters = decoded_filters
            self.core_filters = core_filters
            self.model = model
            self.symmetrical = symmetrical
            self.instance = instance
            self.join_table = join_table
            self.source_col_name = source_col_name
            self.target_col_name = target_col_name
            self.through = through
            self._pk_val = clean_pk(self.instance)
            if self._pk_val is None:
                raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)

        def get_query_set(self):
            """Return a queryset of the objects linked to ``instance``.

            Target pks are read from ``join_table`` for the single value held
            in ``core_filters``; an empty queryset is returned when there are
            none.

            Raises NotImplementedError if ``core_filters`` has more than one
            entry.
            """
            mongo_con = connection.cursor()
            if len(self.core_filters) > 1:
                raise NotImplementedError("Too complex query in relation manager")
            # NOTE(review): an empty core_filters still fails here with IndexError.
            source_pk = list(self.core_filters.values())[0]
            cursor = mongo_con[self.join_table].find({self.source_col_name: source_pk}, fields=[self.target_col_name])
            ids = [doc[self.target_col_name] for doc in cursor]
            qs = superclass.get_query_set(self)._next_is_sticky()
            if ids:
                return qs.filter(pk__in=ids)
            else:
                return qs.none()

        # If the ManyToMany relation has an intermediary model,
        # the add and remove methods do not exist.
        if through is None:
            def add(self, *objs):
                """Link *objs* (instances or pks) to ``instance``."""
                self._add_items(self.source_col_name, self.target_col_name, *objs)

                # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table
                if self.symmetrical:
                    self._add_items(self.target_col_name, self.source_col_name, *objs)
            add.alters_data = True

            def remove(self, *objs):
                """Unlink *objs* (instances or pks) from ``instance``."""
                self._remove_items(self.source_col_name, self.target_col_name, *objs)

                # If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table
                if self.symmetrical:
                    self._remove_items(self.target_col_name, self.source_col_name, *objs)
            remove.alters_data = True

        def clear(self):
            """Remove every link from ``instance`` (and mirror links if symmetrical)."""
            self._clear_items(self.source_col_name)

            # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table
            if self.symmetrical:
                self._clear_items(self.target_col_name)
        clear.alters_data = True

        def create(self, **kwargs):
            """Create a new related object via the parent manager and link it.

            Raises AttributeError when the relation has an intermediary model.
            """
            # This check needs to be done here, since we can't later remove this
            # from the method lookup table, as we do with add and remove.
            if through is not None:
                raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through)
            new_obj = super(ManyRelatedManager, self).create(**kwargs)
            self.add(new_obj)
            return new_obj
        create.alters_data = True

        def get_or_create(self, **kwargs):
            """Fetch or create an object via the parent manager; link it only if created."""
            obj, created = \
                    super(ManyRelatedManager, self).get_or_create(**kwargs)
            # We only need to add() if created because if we got an object back
            # from get() then the relationship already exists.
            if created:
                self.add(obj)
            return obj, created
        get_or_create.alters_data = True

        def _add_items(self, source_col_name, target_col_name, *objs):
            """Insert join documents linking ``self._pk_val`` to each of *objs*.

            source_col_name -- key in join_table holding the source object's pk.
            target_col_name -- key in join_table holding the target object's pk.
            objs -- instances of ``self.model`` or raw primary keys.

            Raises TypeError for a model instance of the wrong type.
            Links that already exist are not inserted again.
            """
            # If there aren't any objects, there is nothing to do.
            if objs:
                from django.db.models.base import Model
                # Check that all the objects are of the right type
                new_ids = set()
                for obj in objs:
                    if isinstance(obj, self.model):
                        new_ids.add(clean_pk(obj))
                    elif isinstance(obj, Model):
                        raise TypeError("'%s' instance expected" % self.model._meta.object_name)
                    else:
                        # NOTE(review): raw pks are stored as given (no
                        # ObjectId decoding); confirm callers pass decoded pks.
                        new_ids.add(obj)
                # Add the newly created or already existing objects to the join table.
                # First find out which items are already added, to avoid adding them twice
                mongo_con = connection.cursor()
                collection = mongo_con[self.join_table]
                cursor = collection.find(
                    {source_col_name: self._pk_val, target_col_name: {"$in": list(new_ids)}},
                    fields=[target_col_name])
                existing_ids = set(doc[target_col_name] for doc in cursor)

                # Add the ones that aren't there already
                docs = [{source_col_name: self._pk_val, target_col_name: obj_id} for obj_id in (new_ids - existing_ids)]
                collection.ensure_index([(source_col_name, 1), (target_col_name, 1)])
                # Every requested link may already exist; PyMongo refuses an
                # empty bulk insert, so only insert when there is work to do.
                if docs:
                    collection.insert(docs)

        def _remove_items(self, source_col_name, target_col_name, *objs):
            """Delete join documents linking ``self._pk_val`` to each of *objs*.

            source_col_name -- key in join_table holding the source object's pk.
            target_col_name -- key in join_table holding the target object's pk.
            objs -- instances of ``self.model`` or raw primary keys.
            """
            # If there aren't any objects, there is nothing to do.
            if objs:
                old_ids = set()
                for obj in objs:
                    if isinstance(obj, self.model):
                        old_ids.add(clean_pk(obj))
                    else:
                        old_ids.add(obj)
                # Remove the specified objects from the join table
                mongo_con = connection.cursor()
                mongo_con[self.join_table].remove({source_col_name: self._pk_val, target_col_name: {"$in": list(old_ids)}})

        def _clear_items(self, source_col_name):
            """Delete every join document whose *source_col_name* equals ``self._pk_val``."""
            mongo_con = connection.cursor()
            mongo_con[self.join_table].remove({source_col_name: self._pk_val})

    return ManyRelatedManager