Source

djangotoolbox / djangotoolbox / db / basecompiler.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
from datetime import date, time, datetime
from django.conf import settings
from django.db.models.fields import NOT_PROVIDED
from django.db.models.sql import aggregates as sqlaggregates
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.constants import LOOKUP_SEP, MULTI, SINGLE
from django.db.models.sql.where import AND, OR
from django.db.utils import DatabaseError, IntegrityError
from django.utils.tree import Node
import random

# In-memory emulation of filter lookups. Every predicate receives the value
# stored on the entity (x) and the value given in the filter (y).
def _emulated_exact(x, y):
    # A list/tuple stored on the entity matches if it contains the value
    if isinstance(x, (list, tuple)):
        return y in x
    return x == y


def _emulated_iexact(x, y):
    return x.lower() == y.lower()


def _emulated_startswith(x, y):
    return x.startswith(y)


def _emulated_istartswith(x, y):
    return x.lower().startswith(y.lower())


def _emulated_isnull(x, y):
    # y is the boolean given to the "isnull" lookup
    if y:
        return x is None
    return x is not None


def _emulated_in(x, y):
    return x in y


def _emulated_lt(x, y):
    return x < y


def _emulated_lte(x, y):
    return x <= y


def _emulated_gt(x, y):
    return x > y


def _emulated_gte(x, y):
    return x >= y


# Maps a lookup type to the predicate emulating it
EMULATED_OPS = {
    'exact': _emulated_exact,
    'iexact': _emulated_iexact,
    'startswith': _emulated_startswith,
    'istartswith': _emulated_istartswith,
    'isnull': _emulated_isnull,
    'in': _emulated_in,
    'lt': _emulated_lt,
    'lte': _emulated_lte,
    'gt': _emulated_gt,
    'gte': _emulated_gte,
}

class NonrelQuery(object):
    """
    Base class for the query objects of non-relational backends.

    Backends subclass this and implement fetch(), count(), delete(),
    order_by() and add_filter(). The remaining methods are reusable helpers:
    add_filters() walks a Where tree and hands every leaf to add_filter(),
    while _matches_filters() and _order_in_memory() emulate filtering and
    ordering on entities that are accessed like dictionaries keyed by column
    name.
    """

    # ----------------------------------------------
    # Public API
    # ----------------------------------------------
    def __init__(self, compiler, fields):
        """
        :param compiler: The compiler creating this query; its ``connection``
                         and ``query`` attributes are stored on the instance
        :param fields: The fields handed over by the compiler
        """
        self.fields = fields
        self.compiler = compiler
        self.connection = compiler.connection
        self.query = self.compiler.query
        # Whether the subtree currently visited by add_filters() is negated
        self._negated = False

    def fetch(self, low_mark=0, high_mark=None):
        """Backend hook; the base implementation always raises."""
        raise NotImplementedError('Not implemented')

    def count(self, limit=None):
        """Backend hook; the base implementation always raises."""
        raise NotImplementedError('Not implemented')

    def delete(self):
        """Backend hook; the base implementation always raises."""
        raise NotImplementedError('Not implemented')

    def order_by(self, ordering):
        """Backend hook; the base implementation always raises."""
        raise NotImplementedError('Not implemented')

    # Used by add_filters()
    def add_filter(self, column, lookup_type, negated, db_type, value):
        """
        Backend hook called by add_filters() once per leaf of the Where tree.
        The base implementation always raises.
        """
        raise NotImplementedError('Not implemented')

    # This is just a default implementation. You might want to override this
    # in case your backend supports OR queries
    def add_filters(self, filters):
        """
        Traverses the given Where tree and adds the filters to this query.

        Only AND-connected filters are accepted, except below a negation,
        where the children must be connected via OR (or be a single child)
        so that "not (a OR b)" can be applied as "(not a) AND (not b)".

        :raises DatabaseError: if the tree can't be expressed that way
        """
        # Entering a negated node flips the negation state; it is flipped
        # back at the end of this call
        if filters.negated:
            self._negated = not self._negated

        if not self._negated and filters.connector != AND:
            raise DatabaseError('Only AND filters are supported')

        # Remove unneeded children from tree
        children = self._get_children(filters.children)

        if self._negated and filters.connector != OR and len(children) > 1:
            raise DatabaseError("When negating a whole filter subgroup "
                                "(e.g., a Q object) the subgroup filters must "
                                "be connected via OR, so the non-relational "
                                "backend can convert them like this: "
                                '"not (a OR b) => (not a) AND (not b)".')

        for child in children:
            if isinstance(child, Node):
                self.add_filters(child)
                continue

            column, lookup_type, db_type, value = self._decode_child(child)
            self.add_filter(column, lookup_type, self._negated, db_type, value)

        if filters.negated:
            self._negated = not self._negated

    # ----------------------------------------------
    # Internal API for reuse by subclasses
    # ----------------------------------------------
    def _decode_child(self, child):
        """
        Unpacks a leaf of the Where tree.

        :param child: A (constraint, lookup_type, annotation, value) tuple
        :returns: (column, lookup_type, db_type, normalized value)
        :raises DatabaseError: if the constraint's alias is set and differs
                               from the table of the queried model
        """
        constraint, lookup_type, annotation, value = child
        packed, value = constraint.process(lookup_type, value, self.connection)
        alias, column, db_type = packed
        if alias and alias != self.query.model._meta.db_table:
            raise DatabaseError("This database doesn't support JOINs "
                                "and multi-table inheritance.")
        value = self._normalize_lookup_value(value, annotation, lookup_type)
        return column, lookup_type, db_type, value

    def _normalize_lookup_value(self, value, annotation, lookup_type):
        """
        Converts a processed lookup value into a plain value.

        Single-element lists are unwrapped ("isnull" uses the annotation
        instead), unicode/str subclasses are converted to the base type and
        one character is stripped from the wildcard side(s) of
        startswith/endswith/contains values.

        :raises DatabaseError: if a lookup other than "in", "range" or
                               "year" comes with more than one value
        """
        # Django fields always return a list (see Field.get_db_prep_lookup)
        # except if get_db_prep_lookup got overridden by a subclass
        if lookup_type not in ('in', 'range', 'year') and isinstance(value, (tuple, list)):
            if len(value) > 1:
                raise DatabaseError('Filter lookup type was: %s. Expected the '
                                'filters value not to be a list. Only "in"-filters '
                                'can be used with lists.'
                                % lookup_type)
            elif lookup_type == 'isnull':
                value = annotation
            else:
                value = value[0]

        if isinstance(value, unicode):
            value = unicode(value)
        elif isinstance(value, str):
            value = str(value)

        # NOTE(review): presumably the stripped characters are the "%"
        # wildcards added by Field.get_db_prep_lookup - confirm
        if lookup_type in ('startswith', 'istartswith'):
            value = value[:-1]
        elif lookup_type in ('endswith', 'iendswith'):
            value = value[1:]
        elif lookup_type in ('contains', 'icontains'):
            value = value[1:-1]

        return value

    def _get_children(self, children):
        """
        Returns the children without the "isnull" leaves whose constraint
        has no field.
        """
        # Filter out nodes that were automatically added by sql.Query, but are
        # not necessary with emulated negation handling code
        result = []
        for child in children:
            if isinstance(child, tuple):
                constraint = child[0]
                lookup_type = child[1]
                if lookup_type == 'isnull' and constraint.field is None:
                    continue
            result.append(child)
        return result

    def _matches_filters(self, entity, filters):
        """
        Checks in memory whether the entity satisfies the given Where tree.

        Leaves are decoded with _decode_child(), so the lookup value is
        normalized exactly like the values handed to add_filter() (e.g. the
        last character of "startswith" values is stripped before comparing).

        :param entity: Mapping from column name to the stored value
        :param filters: Where tree (or subtree) to evaluate
        :returns: True if the entity matches, otherwise False
        :raises DatabaseError: if a leaf refers to another table or needs a
                               lookup type that has no in-memory emulation
        """
        # Filters without rules match everything
        if not filters.children:
            return True

        # AND starts out True and fails on the first mismatch; OR starts out
        # False and succeeds on the first match
        result = filters.connector == AND

        for child in filters.children:
            if isinstance(child, Node):
                submatch = self._matches_filters(entity, child)
            else:
                column, lookup_type, db_type, value = self._decode_child(child)
                entity_value = entity[column]

                if (entity_value is None and
                        isinstance(value, (datetime, date, time))):
                    # A missing date/time value only matches "less than"
                    submatch = lookup_type in ('lt', 'lte')
                elif (entity_value is None and lookup_type in (
                        'startswith', 'contains', 'endswith', 'iexact',
                        'istartswith', 'icontains', 'iendswith')):
                    # String lookups never match a missing value
                    submatch = False
                else:
                    try:
                        matches = EMULATED_OPS[lookup_type]
                    except KeyError:
                        raise DatabaseError("Lookup type %r can't be emulated "
                                            "in memory." % lookup_type)
                    submatch = matches(entity_value, value)

            if filters.connector == OR and submatch:
                result = True
                break
            elif filters.connector == AND and not submatch:
                result = False
                break

        if filters.negated:
            return not result
        return result

    def _order_in_memory(self, lhs, rhs):
        """
        cmp-style comparison of two entities according to the compiler's
        ordering: returns a negative number, zero or a positive number.

        :raises DatabaseError: if an ordering contains LOOKUP_SEP
        """
        for order in self.compiler._get_ordering():
            if LOOKUP_SEP in order:
                raise DatabaseError("JOINs in ordering not supported (%s)" % order)
            if order == '?':
                result = random.choice([1, 0, -1])
            else:
                column = order.lstrip('-')
                result = cmp(lhs.get(column), rhs.get(column))
                if order.startswith('-'):
                    result *= -1
            # The first ordering that tells the entities apart decides
            if result != 0:
                return result
        return 0

    def convert_value_from_db(self, db_type, value):
        """Delegates to the compiler's convert_value_from_db()."""
        return self.compiler.convert_value_from_db(db_type, value)

    def convert_value_for_db(self, db_type, value):
        """Delegates to the compiler's convert_value_for_db()."""
        return self.compiler.convert_value_for_db(db_type, value)

class NonrelCompiler(SQLCompiler):
    """
    Base class for non-relational compilers. Provides in-memory filter matching
    and ordering. Entities are assumed to be dictionaries where the keys are
    column names.

    Subclasses are expected to provide ``query_class`` (instantiated by
    build_query()) as well as convert_value_from_db() and
    convert_value_for_db(); none of them is defined in this class.
    """

    # ----------------------------------------------
    # Public API
    # ----------------------------------------------
    def results_iter(self):
        """
        Returns an iterator over the results from executing this query.
        Every result is a list with one value per loaded field.
        """
        self.check_query()
        fields = self.get_fields()
        low_mark = self.query.low_mark
        high_mark = self.query.high_mark
        for entity in self.build_query(fields).fetch(low_mark, high_mark):
            yield self._make_result(entity, fields)

    def has_results(self):
        """
        Returns True if at least one entity matches the query.
        """
        # get_count() yields a number; callers of has_results() expect a flag
        return bool(self.get_count(check_exists=True))

    def execute_sql(self, result_type=MULTI):
        """
        Handles aggregate/count queries

        Only a single Count aggregate over "*" or the primary key column is
        supported. The count is returned as ``[count]`` for SINGLE and as
        ``[[count]]`` for MULTI.

        :raises NotImplementedError: for every other query
        """
        unsupported = 'The database backend only supports count() queries'
        aggregates = list(self.query.aggregate_select.values())
        # Simulate a count()
        if aggregates:
            # Explicit checks instead of asserts: asserts are removed when
            # Python runs with -O, which would turn any aggregate into a
            # silent count
            if len(aggregates) != 1:
                raise NotImplementedError(unsupported)
            aggregate = aggregates[0]
            if not isinstance(aggregate, sqlaggregates.Count):
                raise NotImplementedError(unsupported)
            meta = self.query.get_meta()
            if not (aggregate.col == '*' or
                    aggregate.col == (meta.db_table, meta.pk.column)):
                raise NotImplementedError(unsupported)
            count = self.get_count()
            if result_type is SINGLE:
                return [count]
            elif result_type is MULTI:
                return [[count]]
        raise NotImplementedError(unsupported)

    # ----------------------------------------------
    # Additional NonrelCompiler API
    # ----------------------------------------------
    def _make_result(self, entity, fields):
        """
        Converts an entity into the list of values for the given fields.

        Columns missing from the entity fall back to the field's default;
        all other values are passed through convert_value_from_db().

        :raises IntegrityError: if a non-nullable field ends up as None
        """
        result = []
        for field in fields:
            value = entity.get(field.column, NOT_PROVIDED)
            if value is NOT_PROVIDED:
                value = field.get_default()
            else:
                value = self.convert_value_from_db(field.db_type(connection=self.connection), value)
            if value is None and not field.null:
                raise IntegrityError("Non-nullable field %s can't be None!" % field.name)
            result.append(value)
        return result

    def check_query(self):
        """
        Rejects queries that use more than one referenced alias, distinct,
        extra or having.

        :raises DatabaseError: if the query uses one of these features
        """
        if (len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1
                or self.query.distinct or self.query.extra or self.query.having):
            raise DatabaseError('This query is not supported by the database.')

    def get_count(self, check_exists=False):
        """
        Counts matches using the current filter constraints.

        :param check_exists: If true, the count is limited to 1 because the
                             caller only needs to know whether anything
                             matches; otherwise the query's high mark is
                             used as the limit
        """
        if check_exists:
            high_mark = 1
        else:
            high_mark = self.query.high_mark
        return self.build_query().count(high_mark)

    def build_query(self, fields=None):
        """
        Creates the backend query object with the filters and ordering of
        self.query applied.

        :param fields: Fields to load; defaults to get_fields()
        """
        if fields is None:
            fields = self.get_fields()
        query = self.query_class(self, fields)
        query.add_filters(self.query.where)
        query.order_by(self._get_ordering())

        # This at least satisfies the most basic unit tests
        if settings.DEBUG:
            self.connection.queries.append({'sql': repr(query)})
        return query

    def get_fields(self):
        """
        Returns the fields which should get loaded from the backend by self.query

        :raises DatabaseError: if fields of another model (multi-table
                               inheritance) would have to be loaded
        """
        # We only set this up here because
        # related_select_fields isn't populated until
        # execute_sql() has been called.
        if self.query.select_fields:
            fields = self.query.select_fields + self.query.related_select_fields
        else:
            fields = self.query.model._meta.fields
        # If the field was deferred, exclude it from being passed
        # into `resolve_columns` because it wasn't selected.
        only_load = self.deferred_to_columns()
        if only_load:
            db_table = self.query.model._meta.db_table
            # Ignore empty entries of other tables
            only_load = dict((k, v) for k, v in only_load.items()
                             if v or k == db_table)
            if len(only_load) > 1:
                raise DatabaseError('Multi-table inheritance is not supported '
                                    'by non-relational DBs.' + repr(only_load))
            fields = [f for f in fields if db_table in only_load and
                      f.column in only_load[db_table]]

        # Fields of a proxy model belong to the model it proxies
        query_model = self.query.model
        if query_model._meta.proxy:
            query_model = query_model._meta.proxy_for_model

        for field in fields:
            if field.model._meta != query_model._meta:
                raise DatabaseError('Multi-table inheritance is not supported '
                                    'by non-relational DBs.')
        return fields

    def _get_ordering(self):
        """
        Returns the ordering of self.query as a list of field names, where a
        leading "-" marks descending order.

        A leading "+" is removed, "pk" is replaced with the name of the
        primary key field and all directions are reversed if the query
        doesn't use standard ordering.

        :raises DatabaseError: for orderings that span tables and for
                               random ordering ("?")
        """
        if not self.query.default_ordering:
            ordering = self.query.order_by
        else:
            ordering = self.query.order_by or self.query.get_meta().ordering
        result = []
        for order in ordering:
            if LOOKUP_SEP in order:
                raise DatabaseError("Ordering can't span tables on non-relational backends (%s)" % order)
            if order == '?':
                raise DatabaseError("Randomized ordering isn't supported by the backend")

            order = order.lstrip('+')

            descending = order.startswith('-')
            name = order.lstrip('-')
            if name == 'pk':
                name = self.query.get_meta().pk.name
                order = '-' + name if descending else name

            if self.query.standard_ordering:
                result.append(order)
            else:
                # Reversed query: flip the direction of every ordering
                if descending:
                    result.append(name)
                else:
                    result.append('-' + name)
        return result

class NonrelInsertCompiler(object):
    """
    Insert logic for non-relational compilers: collects the values of
    self.query into a dictionary and passes it to insert(), which has to be
    implemented by the backend.
    """

    def execute_sql(self, return_id=False):
        """
        Builds the {column: value} dictionary for the new entity and returns
        the result of insert().

        :raises IntegrityError: if a non-nullable field is set to None
        """
        document = {}
        pairs = zip(self.query.columns, self.query.values)
        for column, (field, raw_value) in pairs:
            if field is None:
                # Without a field there is nothing to check or convert
                document[column] = raw_value
                continue
            if not field.null and raw_value is None:
                raise IntegrityError("You can't set %s (a non-nullable "
                                    "field) to None!" % field.name)
            document[column] = self.convert_value_for_db(
                field.db_type(connection=self.connection), raw_value)
        return self.insert(document, return_id=return_id)

    def insert(self, values, return_id):
        """
        Backend hook; the base implementation always raises.

        :param values: Dictionary mapping column names to the values
                       prepared by execute_sql()
        :param return_id: Whether to return the id of the newly created entity
        """
        raise NotImplementedError

class NonrelUpdateCompiler(object):
    """
    Update logic for non-relational compilers: prepares the new values of
    self.query and passes them to update(), which has to be implemented by
    the backend.
    """

    def execute_sql(self, result_type):
        """
        Prepares every (field, _, value) entry of self.query.values for the
        database and returns the result of update().
        """
        prepared = []
        for field, _, new_value in self.query.values:
            # Values that know how to prepare themselves take precedence
            # over the field's own save preparation
            if not hasattr(new_value, 'prepare_database_save'):
                new_value = field.get_db_prep_save(new_value,
                                                   connection=self.connection)
            else:
                new_value = new_value.prepare_database_save(field)
            db_type = field.db_type(connection=self.connection)
            new_value = self.convert_value_for_db(db_type, new_value)
            prepared.append((field, new_value))
        return self.update(prepared)

    def update(self, values):
        """
        Backend hook; the base implementation always raises.

        :param values: A list of (field, new-value) pairs
        """
        raise NotImplementedError

class NonrelDeleteCompiler(object):
    """
    Delete logic for non-relational compilers.
    """

    def execute_sql(self, result_type=MULTI):
        """
        Builds a query that only loads the primary key field and calls its
        delete() method. Nothing is returned.
        """
        pk_field = self.query.get_meta().pk
        query = self.build_query([pk_field])
        query.delete()