1. Kirill Mavreshko
  2. django-mongodb

Source

django-mongodb / django_mongodb / query.py

#-*- coding:utf-8 -*-

# Copyright (c) 2009, Kirill Mavreshko (kimavr@gmail.com)
# All rights reserved.

# Redistribution and use in source and binary forms, with or without 
# modification, are permitted provided that the following conditions are met:

#    * Redistributions of source code must retain the above copyright notice, 
#      this list of conditions and the following disclaimer.
#    * Redistributions in binary form must reproduce the above copyright notice,
#      this list of conditions and the following disclaimer in the documentation
#      and/or other materials provided with the distribution.
#    * The name of the author may not be used to endorse or promote products
#      derived from this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
# POSSIBILITY OF SUCH DAMAGE.

"""
django_mongodb.query -- Core of django_mongodb query engine
"""

import re
from copy import deepcopy

from pymongo.objectid import ObjectId
from pymongo import ASCENDING, DESCENDING

from django.db.models.sql.query import BaseQuery
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import Empty, EmptyResultSet
from django.db.models.sql.where import AND, OR
from django.utils.datastructures import SortedDict


class MongoDbQuery(BaseQuery):
    """
    Replacement for django.db.models.sql.query.BaseQuery that targets MongoDB.

    Instead of rendering SQL, the state collected by Django (where tree,
    ordering, slice marks, selected columns) is translated into pymongo
    ``find`` / ``find_one`` / ``count`` calls.
    """

    # Lookup type -> (operator for the plain condition,
    #                 operator for the logically negated condition).
    _RANGE_OPERATORS = {
        "gt": ("$gt", "$lte"),
        "gte": ("$gte", "$lt"),
        "lt": ("$lt", "$gte"),
        "lte": ("$lte", "$gt"),
    }

    def __init__(self, *args, **kwargs):
        # The AutoField / Insert / Update / Delete patches must be in place
        # before any query built from this class is executed.
        apply_mongo_patch()
        super(MongoDbQuery, self).__init__(*args, **kwargs)

    def where_to_mongo_query(self, node, parent_negated=False):
        """
        Convert django.db.models.sql.where.WhereNode tree
        to query dict for MongoDB "find" collection method.

        ``node`` is either a where-tree node (with ``children``, ``connector``
        and ``negated``) or a leaf tuple
        ``(lvalue, lookup_type, value_annot, params)``.
        ``parent_negated`` tells a leaf that its parent node is negated; it is
        not propagated through intermediate nodes.

        The tree is never modified, so the same query can be translated
        repeatedly with identical results.

        Raises DatabaseError for unsupported lookups and for
        ``NOT (... AND ...)`` with more than one child, and EmptyResultSet for
        an ``in`` lookup with no values.

        NOTE(review): a non-negated OR node is still merged like AND because
        no "$or" is emitted here -- confirm whether that should be rejected.
        """
        from django.db import DatabaseError
        if isinstance(node, tuple):
            return self._atom_to_mongo_query(node, parent_negated)

        # NOT (a OR b) is handled through De Morgan as (NOT a) AND (NOT b):
        # the children are merged as usual and each one receives the negation.
        # NOT (a AND b) would need an OR of negations, which is not supported.
        if node.negated and node.connector == AND and len(node.children) > 1:
            raise DatabaseError("Invalid condition - NOT (... AND ...)")

        res = {}
        for ch in node.children:
            self._merge_conditions(
                res, self.where_to_mongo_query(ch, node.negated))
        return res

    def _atom_to_mongo_query(self, atom, negated):
        """
        Translate one where-tree leaf into a single-field MongoDB condition.

        Returns an empty dict when the leaf puts no restriction on the result
        (a negated ``in`` lookup without values).
        """
        from django.db import DatabaseError
        lvalue, lookup_type, value_annot, params = atom
        field_name = lvalue[1]
        if field_name == self.model._meta.pk.name:
            field_name = "_id"

        if lookup_type == "exact":
            if negated:
                return {field_name: {"$ne": params[0]}}
            return {field_name: params[0]}

        if lookup_type == "in":
            if not value_annot:
                if negated:
                    # "NOT IN ()" excludes nothing.
                    return {}
                raise EmptyResultSet
            if negated:
                return {field_name: {"$nin": params}}
            return {field_name: {"$in": params}}

        if lookup_type in self._RANGE_OPERATORS:
            plain_op, negated_op = self._RANGE_OPERATORS[lookup_type]
            if negated:
                return {field_name: {negated_op: params[0]}}
            return {field_name: {plain_op: params[0]}}

        raise DatabaseError("Unsupported lookup type: %r" % lookup_type)

    def _merge_conditions(self, target, extra):
        """
        Merge the conditions of ``extra`` into ``target`` in place.

        Two operator dicts for the same field (e.g. "$gt" and "$lt") are
        combined so that neither bound is lost; in every other case the newer
        condition replaces the older one.
        """
        for field_name, condition in extra.items():
            existing = target.get(field_name)
            if isinstance(existing, dict) and isinstance(condition, dict):
                combined = dict(existing)
                combined.update(condition)
                target[field_name] = combined
            else:
                target[field_name] = condition

    def _obj_to_list(self, obj, out_cols, cols_map):
        """
        Convert SON instance from pymongo to list of "columns" values.

        This imitates a row of an SQL table so that the default Django
        QuerySet machinery can consume it. ``cols_map`` is the mapping built
        by ``_column_to_attr_map``. Attributes missing from ``obj`` become
        None; ObjectId values are returned in their url-encoded unicode form.
        """
        rval = []
        for col in out_cols:
            orig_field = cols_map[col]
            if isinstance(orig_field, tuple):
                # Extra select, e.g. (u'1', []): use the expression itself as
                # the value, as an integer when it looks like one.
                extra_expression, extra_params = orig_field
                try:
                    val = int(extra_expression)
                except ValueError:
                    val = extra_expression
            else:
                try:
                    val = obj[orig_field]
                except KeyError:
                    val = None
                else:
                    if isinstance(val, ObjectId):
                        val = unicode(val.url_encode())

            rval.append(val)
        return rval

    def _column_to_attr_map(self, out_cols):
        """
        For list of SQL table columns (out_cols), creates
        map to attributes of pymongo SON object.

        "<expr> AS <alias>" columns map to the ``extra_select`` entry of the
        alias; every other column maps to its last dot-separated part, with
        the primary key translated to "_id".
        """
        cols_map = {}
        pk_name = self.model._meta.pk.name
        for col in out_cols:
            if " AS " in col:
                extra_col = col.rsplit(" AS ", 1)[1].strip()
                cols_map[col] = self.extra_select[extra_col]
            else:
                # [-1] also accepts a column without a "table." prefix.
                val = col.rsplit('.', 1)[-1]
                if val == pk_name:
                    val = "_id"
                cols_map[col] = val
        return cols_map

    def _mongo_sort_opts(self):
        """
        Return list of pymongo collection "sort" params for sorting query results.

        Each entry is a ``(field, direction)`` pair. A leading "-" selects
        DESCENDING. The primary key (by name or through the "pk" alias) is
        sorted as "_id", consistent with how filters and columns are mapped.
        """
        if self.extra_order_by:
            ordering = self.extra_order_by
        elif not self.default_ordering:
            ordering = self.order_by
        else:
            ordering = self.order_by or self.model._meta.ordering

        pk_name = self.model._meta.pk.name
        sortopts = []
        for field in ordering:
            direction = ASCENDING
            if field.startswith("-"):
                field = field[1:]
                direction = DESCENDING
            if field in ("pk", pk_name):
                field = "_id"
            sortopts.append((field, direction))
        return sortopts

    def _mongo_query_limit_params(self):
        """
        Return dict of pymongo collection "find" method params for limiting query results.

        The dict may contain "limit" and "skip", derived from the slice marks
        ``low_mark`` / ``high_mark``.
        """
        extra = {}
        if self.high_mark is not None:
            extra['limit'] = self.high_mark - self.low_mark
        if self.low_mark:
            if self.high_mark is None:
                # Offset without an upper bound: use the backend's
                # "no limit" value when it defines one.
                val = self.connection.ops.no_limit_value()
                if val:
                    extra['limit'] = val
            extra['skip'] = self.low_mark
        return extra

    def execute_sql(self, result_type=MULTI):
        """
        Run the query against MongoDB; see docs for
        django.db.models.sql.query.BaseQuery.execute_sql.

        This is a generator of chunks, each chunk being a list of rows as
        produced by ``_obj_to_list``. For SINGLE at most one chunk holding one
        row is produced; for MULTI the chunks hold up to
        GET_ITERATOR_CHUNK_SIZE rows. Nothing is produced when the filter
        cannot match anything or when no document is found.
        """
        self.pre_sql_setup()
        with_col_aliases = False
        out_cols = self.get_columns(with_col_aliases)
        from_, f_params = self.get_from_clause()
        try:
            query = self.where_to_mongo_query(self.where) or None
        except EmptyResultSet:
            # The filter can never match (e.g. "in" with no values).
            return
        db = self.connection.cursor()
        # The first entry of the from clause is used as the collection name.
        collection = getattr(db, from_[0])

        cols_map = self._column_to_attr_map(out_cols)

        if result_type == SINGLE:
            obj = collection.find_one(query)
            if obj is None:
                return
            plain_list = self._obj_to_list(obj, out_cols, cols_map)
            if self.ordering_aliases:
                plain_list = plain_list[:-len(self.ordering_aliases)]
            yield [plain_list]
            # Do not fall through and run the MULTI query as well.
            return

        # The MULTI case.
        limit_params = self._mongo_query_limit_params()
        sort_opts = self._mongo_sort_opts()
        result = collection.find(query, **limit_params)
        if sort_opts:
            result = result.sort(sort_opts)

        chunk = []
        for item in result:
            chunk.append(self._obj_to_list(item, out_cols, cols_map))
            if len(chunk) >= GET_ITERATOR_CHUNK_SIZE:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

    def get_count(self):
        """
        Performs a count() query using the current filter constraints.

        Returns 0 when the filter cannot match anything. The slice marks are
        applied to the raw count so that a sliced query reports the size of
        the slice rather than of the whole result.
        """
        self.pre_sql_setup()
        from_, f_params = self.get_from_clause()
        try:
            query = self.where_to_mongo_query(self.where) or None
        except EmptyResultSet:
            return 0
        db = self.connection.cursor()
        collection = getattr(db, from_[0])
        number = collection.find(query).count()
        number = max(0, number - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)
        return number


def UpdateQuery_execute_sql(self, result_type=None):
    """
    MongoDB replacement for
    django.db.models.sql.subqueries.UpdateQuery.execute_sql.

    Applies a "$set" of every (name, value) pair in ``self.values`` to each
    document matched by the query's filter and returns the number of
    documents processed. Returns 0 when there is nothing to set or when the
    filter cannot match anything. ``result_type`` is accepted for signature
    compatibility and is not used.
    """
    self.pre_sql_setup()
    if not self.values:
        return 0
    table = self.tables[0]

    # None values are stored as None, exactly like any other value.
    doc = {}
    for name, val, placeholder in self.values:
        doc[name] = val

    try:
        query = self.where_to_mongo_query(self.where) or None
    except EmptyResultSet:
        # The filter can never match (e.g. "in" with no values).
        return 0

    db = self.connection.cursor()
    collection = getattr(db, table)
    update_data = {"$set": doc}
    rows = 0
    # Only "_id" is fetched; each returned document then serves as the
    # selector for its own update.
    for s in collection.find(query, fields=["_id"]):
        collection.update(s, update_data)
        rows += 1
    return rows

def InsertQuery_execute_sql(self, return_id=False):
    """
    MongoDB replacement for
    django.db.models.sql.subqueries.InsertQuery.execute_sql.

    Builds one document from ``self.columns`` / ``self.params`` (the primary
    key column is stored as "_id") and saves it into the model's collection.

    Returns the id reported by ``save``: the url-encoded unicode form for an
    ObjectId, the value itself otherwise. ``return_id`` is recorded on the
    query; the id is returned regardless of it.
    """
    self.return_id = return_id

    db = self.connection.cursor()
    collection = getattr(db, self.model._meta.db_table)
    pk_name = self.model._meta.pk.name

    doc = {}
    for column, value in zip(self.columns, self.params):
        if column == pk_name:
            column = "_id"
        doc[column] = value

    rval = collection.save(doc)
    # Only ObjectId has url_encode(); an explicitly supplied key of another
    # type (e.g. a string primary key) is handed back unchanged.
    if isinstance(rval, ObjectId):
        return unicode(rval.url_encode())
    return rval

def DeleteQuery_execute_sql(self, *args, **kwargs):
    """
    MongoDB replacement for
    django.db.models.sql.subqueries.DeleteQuery.execute_sql.

    Removes every document matched by the query's filter from the single
    collection named in ``self.tables``. Extra arguments are accepted and
    ignored. Returns None.
    """
    table_count = len(self.tables)
    assert table_count == 1, \
        "Can only delete from one table at a time."
    target = getattr(self.connection.cursor(), self.tables[0])
    selector = self.where_to_mongo_query(self.where)
    target.remove(selector)

def AutoField_to_python(self, value):
    """
    Replacement for AutoField.to_python: return ``value`` unchanged.

    No conversion is applied, so ObjectId instances and their string forms
    both pass through as they are.
    """
    return value

def AutoField_get_db_prep_value(self, value):
    """
    Replacement for AutoField.get_db_prep_value.

    A string value is decoded with ``ObjectId.url_decode``; any other value
    is returned unchanged.
    """
    if not isinstance(value, basestring):
        return value
    return ObjectId.url_decode(value)


_patch_info = {'applied':False}

def apply_mongo_patch():
    """
    Install the MongoDB-specific replacements on Django, once per process.

    Replaces ``to_python`` / ``get_db_prep_value`` on AutoField and
    ``execute_sql`` on UpdateQuery, DeleteQuery and InsertQuery with the
    functions defined in this module. Later calls do nothing.
    """
    if _patch_info['applied']:
        return

    # Imported here rather than at module level, as in the original code.
    from django.db.models import fields
    from django.db.models.sql import subqueries

    replacements = (
        (fields.AutoField, 'to_python', AutoField_to_python),
        (fields.AutoField, 'get_db_prep_value', AutoField_get_db_prep_value),
        (subqueries.UpdateQuery, 'execute_sql', UpdateQuery_execute_sql),
        (subqueries.DeleteQuery, 'execute_sql', DeleteQuery_execute_sql),
        (subqueries.InsertQuery, 'execute_sql', InsertQuery_execute_sql),
    )
    for owner, attr_name, replacement in replacements:
        setattr(owner, attr_name, replacement)
    _patch_info['applied'] = True