django-mongodb / django_mongodb / query.py

#-*- coding:utf-8 -*-

# Copyright (c) 2009, Kirill Mavreshko (kimavr@gmail.com)
# All rights reserved.

# Redistribution and use in source and binary forms, with or without 
# modification, are permitted provided that the following conditions are met:

#    * Redistributions of source code must retain the above copyright notice, 
#      this list of conditions and the following disclaimer.
#    * Redistributions in binary form must reproduce the above copyright notice,
#      this list of conditions and the following disclaimer in the documentation
#      and/or other materials provided with the distribution.
#    * The name of the author may not be used to endorse or promote products
#      derived from this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, 
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
# POSSIBILITY OF SUCH DAMAGE.

"""
django_mongodb.query -- Core of django_mongodb query engine
"""

import re, sys
from copy import deepcopy

from pymongo.objectid import ObjectId
from pymongo import ASCENDING, DESCENDING

from django.db import DatabaseError
from django.db.models.sql.query import BaseQuery
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import Empty, EmptyResultSet
from django.db.models.sql.where import AND, OR
from django.utils.datastructures import SortedDict

class MongoDbQuery(BaseQuery):
    """
    Alternative of django.db.models.query.sql.BaseQuery for MongoDB.

    Translates the parts of a Django query this backend supports (filters,
    ordering, slicing, column selection, counting) into pymongo calls.
    """

    # Comparison lookup type -> (operator, operator used when negated).
    _COMPARISON_OPERATORS = {
        "gt": ("$gt", "$lte"),
        "gte": ("$gte", "$lt"),
        "lt": ("$lt", "$gte"),
        "lte": ("$lte", "$gt"),
    }

    def __init__(self, *args, **kwargs):
        # Apply django_mongodb's runtime patches before BaseQuery initialises.
        from django_mongodb import djangopatch
        djangopatch.apply_mongo_patch()
        super(MongoDbQuery, self).__init__(*args, **kwargs)

    def where_to_mongo_query(self, node, parent_negated=False):
        """
        Convert a django.db.models.sql.where.WhereNode tree to a query dict
        for the MongoDB "find" collection method.

        ``node`` is either a WhereNode or a leaf tuple
        ``(lvalue, lookup_type, value_annotation, params)``.
        ``parent_negated`` is True when ``node`` has to be matched in negated
        form (an odd number of enclosing nodes are negated).

        Only conjunctions can be expressed. NOT is pushed down to the leaves
        with De Morgan's laws, so ``NOT (A OR B)`` becomes
        ``NOT A AND NOT B``. A real disjunction of several conditions
        (``A OR B`` or ``NOT (A AND B)``) raises DatabaseError.

        Raises EmptyResultSet when the condition can never match
        (e.g. ``field__in=[]``).
        """
        if isinstance(node, tuple):
            return self._leaf_to_mongo_query(node, parent_negated)

        # Combine the inherited negation with this node's own flag instead of
        # mutating the node, so the query's where tree is left untouched.
        negated = bool(parent_negated) != bool(node.negated)
        if len(node.children) > 1:
            # TODO: Need better condition validator
            if negated and node.connector == AND:
                raise DatabaseError("Invalid condition - NOT (... AND ...)")
            if not negated and node.connector == OR:
                raise DatabaseError("Unsupported condition - (... OR ...)")

        # The node is now a plain AND, a NOT (... OR ...) that is an AND of
        # negated children, or a single-child node: AND all children together.
        res = {}
        for child in node.children:
            self._merge_mongo_conditions(
                res, self.where_to_mongo_query(child, negated))
        return res

    def _leaf_to_mongo_query(self, leaf, negated):
        """
        Convert one WhereNode leaf tuple to a one-field MongoDB query dict,
        in negated form when ``negated`` is true.
        """
        from django.db import models
        lvalue, lookup_type, value_annot, params = leaf
        field_name = lvalue[1]
        pk_field = self.model._meta.pk
        if field_name == pk_field.name:
            field_name = "_id"
            if isinstance(pk_field, models.AutoField):
                # Auto primary keys reach Django as url-encoded ObjectId
                # strings (see _obj_to_list), usually unicode; decode them.
                params = [ObjectId.url_decode(par) if isinstance(par, basestring) else par
                          for par in params]

        if lookup_type == "exact":
            if negated:
                return {field_name: {"$ne": params[0]}}
            return {field_name: params[0]}
        if lookup_type == "iexact":
            par_re = re.compile("^%s$" % re.escape(params[0]), re.I)
            if negated:
                # MongoDB refuses a regex as the argument of $ne; $not takes one.
                return {field_name: {"$not": par_re}}
            return {field_name: par_re}
        if lookup_type in self._COMPARISON_OPERATORS:
            operator, negated_operator = self._COMPARISON_OPERATORS[lookup_type]
            return {field_name: {negated_operator if negated else operator: params[0]}}
        if lookup_type == "in":
            if not value_annot:
                if negated:
                    # NOT (x IN ()) is always true: no constraint at all.
                    return {}
                raise EmptyResultSet
            return {field_name: {"$nin" if negated else "$in": params}}
        raise DatabaseError("Unsupported lookup type: %r" % lookup_type)

    @staticmethod
    def _is_operator_dict(value):
        """Return True for a non-empty dict of MongoDB operators, e.g. {"$gt": 1}."""
        return isinstance(value, dict) and bool(value) and all(
            isinstance(key, basestring) and key.startswith("$") for key in value)

    def _merge_mongo_conditions(self, res, conditions):
        """
        AND the query dict ``conditions`` into the query dict ``res`` in place.

        A plain dict.update() would let a later condition on the same field
        silently replace an earlier one (``age__gt=1, age__lt=9``). Operator
        dicts on the same field are combined instead, e.g.
        {"$gt": 1} + {"$lt": 9}. Constraints that cannot be combined raise
        DatabaseError rather than producing wrong results.
        """
        for field_name, cond in conditions.items():
            if field_name not in res:
                res[field_name] = cond
                continue
            existing = res[field_name]
            if existing == cond:
                continue  # the same constraint given twice
            if (self._is_operator_dict(existing) and self._is_operator_dict(cond)
                    and not set(existing) & set(cond)):
                merged = dict(existing)
                merged.update(cond)
                res[field_name] = merged
            else:
                raise DatabaseError(
                    "Unsupported condition - conflicting constraints on %r" % field_name)

    def _obj_to_list(self, obj, out_cols, cols_map):
        """
        Convert a SON instance from pymongo to a list of "column" values.

        The result imitates a row of an SQL table, in ``out_cols`` order, so
        the default Django QuerySet can understand it. Missing attributes
        become None. ObjectId values become url-encoded unicode strings.
        Extra-select columns map to an ``(expression, params)`` tuple; the
        expression itself is used as the value (as an int when possible).
        """
        rval = []
        for col in out_cols:
            orig_field = cols_map[col]
            if isinstance(orig_field, tuple):  # extra select
                extra_expression, extra_params = orig_field
                try:
                    val = int(extra_expression)
                except ValueError:
                    val = extra_expression
            else:
                try:
                    val = obj[orig_field]
                except KeyError:
                    val = None
                else:
                    if isinstance(val, ObjectId):
                        val = unicode(val.url_encode())
            rval.append(val)
        return rval

    def _column_to_attr_map(self, out_cols):
        """
        Map each SQL column in ``out_cols`` to the attribute of the pymongo
        SON object it is read from.

        ``"... AS alias"`` columns map to the extra_select entry for ``alias``.
        Other columns map to the part after the last ".", with the primary
        key renamed to "_id".
        """
        cols_map = {}
        pk_name = self.model._meta.pk.name
        for col in out_cols:
            if " AS " in col:
                extra_col = col.rsplit(" AS ", 1)[1].strip()
                cols_map[col] = self.extra_select[extra_col]
            else:
                val = col.rsplit('.', 1)[1]
                if val == pk_name:
                    val = "_id"
                cols_map[col] = val
        return cols_map

    def _mongo_sort_opts(self):
        """
        Return the list of ``(field, direction)`` params for the pymongo
        cursor "sort" method.

        Uses extra_order_by if set, otherwise order_by. The model's Meta
        ordering is the fallback when default_ordering is enabled. "pk" and
        the primary key's name are sorted on "_id", where the key is stored.
        """
        if self.extra_order_by:
            ordering = self.extra_order_by
        elif not self.default_ordering:
            ordering = self.order_by
        else:
            ordering = self.order_by or self.model._meta.ordering

        pk_name = self.model._meta.pk.name
        sortopts = []
        for field in ordering:
            if field.startswith("-"):
                name, direction = field[1:], DESCENDING
            else:
                name, direction = field, ASCENDING
            if name in ("pk", pk_name):
                name = "_id"
            sortopts.append((name, direction))
        return sortopts

    def _mongo_query_limit_params(self):
        """
        Return the dict of pymongo collection "find" method params
        (``limit``/``skip``) implementing the query's slice.

        An empty slice (high_mark == low_mark) would produce limit 0. MongoDB
        treats that as "no limit", so execute_sql handles that case itself.
        """
        extra = {}
        if self.high_mark is not None:
            extra['limit'] = self.high_mark - self.low_mark
        if self.low_mark:
            if self.high_mark is None:
                val = self.connection.ops.no_limit_value()
                if val:
                    extra['limit'] = val
            extra['skip'] = self.low_mark
        return extra

    def execute_sql(self, result_type=MULTI):
        """
        Run the query against MongoDB.

        See the docs for django.db.models.sql.query.BaseQuery. This version is
        a generator:
          * for SINGLE it yields at most one item, a list holding one row;
          * for MULTI it yields lists of at most GET_ITERATOR_CHUNK_SIZE rows.

        Nothing is yielded when the filter can never match (EmptyResultSet),
        when SINGLE finds no document, or when the slice is empty.
        """
        self.pre_sql_setup()
        out_cols = self.get_columns(False)
        from_ = self.get_from_clause()[0]
        try:
            query = self.where_to_mongo_query(self.where) or None
        except EmptyResultSet:
            return
        collection = getattr(self.connection.cursor(), from_[0])
        cols_map = self._column_to_attr_map(out_cols)

        if result_type == SINGLE:
            obj = collection.find_one(query)
            if obj is None:
                return
            plain_list = self._obj_to_list(obj, out_cols, cols_map)
            if self.ordering_aliases:
                plain_list = plain_list[:-len(self.ordering_aliases)]
            yield [plain_list]
            # Stop here, otherwise the generator would run the MULTI query too.
            return

        # The MULTI case.
        if self.high_mark is not None and self.high_mark <= self.low_mark:
            # Empty slice such as qs[5:5]. It must not reach find(), because
            # MongoDB reads limit(0) as "no limit".
            return
        limit_params = self._mongo_query_limit_params()
        sort_opts = self._mongo_sort_opts()
        result = collection.find(query, **limit_params)
        if sort_opts:
            result = result.sort(sort_opts)

        chunk = []
        for item in result:
            chunk.append(self._obj_to_list(item, out_cols, cols_map))
            if len(chunk) == GET_ITERATOR_CHUNK_SIZE:
                yield chunk
                chunk = []
        if chunk:
            yield chunk

    def get_count(self):
        """
        Perform a count() query using the current filter constraints.

        As in Django's BaseQuery.get_count, the query's slice (low_mark /
        high_mark) is applied to the raw count. A filter that can never match
        counts as 0.
        """
        self.pre_sql_setup()
        from_ = self.get_from_clause()[0]
        try:
            query = self.where_to_mongo_query(self.where) or None
        except EmptyResultSet:
            return 0
        collection = getattr(self.connection.cursor(), from_[0])
        number = collection.find(query).count()
        number = max(0, number - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)
        return number


def UpdateQuery_execute_sql(self, result_type=None):
    """
    MongoDB version of django.db.models.sql.subqueries.UpdateQuery.execute_sql.

    Applies ``{"$set": ...}`` with the values in ``self.values`` to every
    document matching the where clause, one document at a time. Returns the
    number of documents updated; 0 when there is nothing to set or the filter
    can never match. ``result_type`` is accepted for signature compatibility
    and is not used.
    """
    self.pre_sql_setup()
    if not self.values:
        return 0
    table = self.tables[0]
    # Each entry of self.values is a (name, value, placeholder) triple; the
    # SQL placeholder is meaningless for MongoDB and is ignored.
    doc = dict((name, val) for name, val, placeholder in self.values)

    try:
        query = self.where_to_mongo_query(self.where) or None
    except EmptyResultSet:
        # The filter can never match (e.g. ``pk__in=[]``): nothing to update.
        return 0
    collection = getattr(self.connection.cursor(), table)
    update_data = {"$set": doc}
    rows = 0
    for matched in collection.find(query, fields=["_id"]):
        # ``matched`` only carries the _id, so it selects exactly this document.
        collection.update(matched, update_data)
        rows += 1
    return rows

def InsertQuery_execute_sql(self, return_id=False):
    """
    MongoDB version of django.db.models.sql.subqueries.InsertQuery.execute_sql.

    Builds a document from ``self.columns``/``self.params``, storing the
    primary key column as "_id", and saves it. Returns the value returned by
    ``collection.save``; an ObjectId is returned url-encoded.

    Raises DatabaseError when a non-auto primary key has no value (None or
    missing). Any other value is accepted, including 0.
    """
    from django.db import models
    self.return_id = return_id
    collection = getattr(self.connection.cursor(), self.model._meta.db_table)
    pk_field = self.model._meta.pk
    pk_name = pk_field.name
    cols = ["_id" if c == pk_name else c for c in self.columns]
    doc = dict(zip(cols, self.params))
    is_auto_pk = isinstance(pk_field, models.AutoField)
    if not is_auto_pk and doc.get("_id") is None:
        raise DatabaseError("You must provide value for %r primary key" % pk_name)
    if is_auto_pk and isinstance(doc.get("_id"), basestring):
        # An explicit auto pk arrives in the url-encoded form this backend
        # hands out. Store it as an ObjectId so that pk lookups, which
        # decode strings the same way, can find the document again.
        doc["_id"] = ObjectId.url_decode(doc["_id"])
    rval = collection.save(doc)
    if isinstance(rval, ObjectId):
        rval = rval.url_encode()
    return rval

def DeleteQuery_execute_sql(self, *args, **kwargs):
    """
    MongoDB version of django.db.models.sql.subqueries.DeleteQuery.execute_sql.

    Removes every document of the single table that matches the where clause.
    An empty where clause yields the query ``{}``, which matches every
    document. A filter that can never match (e.g. an empty ``pk__in`` batch)
    deletes nothing, as in Django's SQL backends.
    """
    assert len(self.tables) == 1, \
        "Can only delete from one table at a time."
    collection = getattr(self.connection.cursor(), self.tables[0])
    try:
        query = self.where_to_mongo_query(self.where)
    except EmptyResultSet:
        return
    collection.remove(query)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.