Commits

Waldemar Kornewald  committed e50d253

added __in filter support and exclude(pk__in=...) and exclude(pk=...) as a special case (needed for model validation when using unique_together)

  • Participants
  • Parent commits 8eeccc7

Comments (0)

Files changed (2)

File db/compiler.py

 
 from functools import wraps
 
-from google.appengine.api.datastore import Entity, Query, Put, Get, Delete, Key
+from google.appengine.api.datastore import Entity, Query, MultiQuery, \
+    Put, Get, Delete, Key
 from google.appengine.api.datastore_errors import Error as GAEError
 from google.appengine.api.datastore_types import Text, Category, Email, Link, \
     PhoneNumber, PostalAddress, Text, Blob, ByteString, GeoPt, IM, Key, \
 
     # The following operators are supported with special code below:
     'isnull': None,
+    'in': None,
     'startswith': None,
     'range': None,
     'year': None,
-
-    # TODO: support these filters
-    # in
 }
 
 NEGATION_MAP = {
         super(GAEQuery, self).__init__(compiler, fields)
         self.inequality_field = None
         self.pk_filters = None
+        self.excluded_pks = ()
+        self.ordering = ()
+        self.gae_ordering = []
         pks_only = False
         if len(fields) == 1 and fields[0].primary_key:
             pks_only = True
-        db_table = self.query.get_meta().db_table
-        self.gae_query = Query(db_table, keys_only=pks_only)
+        self.db_table = self.query.get_meta().db_table
+        self.pks_only = pks_only
+        self.gae_query = [Query(self.db_table, keys_only=self.pks_only)]
 
     # This is needed for debugging
     def __repr__(self):
 
     @safe_call
     def fetch(self, low_mark, high_mark):
-        query = self.gae_query
+        query = self._build_query()
         if self.pk_filters is not None:
             results = self.get_matching_pk(low_mark, high_mark)
         else:
             if high_mark is None:
-                results = query.Run(offset=low_mark, prefetch_count=25,
-                                    next_count=75)
+                kw = {}
+                if low_mark:
+                    kw['offset'] = low_mark
+                results = query.Run(**kw)
             elif high_mark > low_mark:
                 results = query.Get(high_mark - low_mark, low_mark)
             else:
                 results = ()
 
         for entity in results:
+            if isinstance(entity, Key):
+                key = entity
+            else:
+                key = entity.key()
+            if key in self.excluded_pks:
+                continue
             yield self._make_entity(entity)
 
     @safe_call
     def count(self, limit=None):
         if self.pk_filters is not None:
             return len(self.get_matching_pk(0, limit))
-        return self.gae_query.Count(limit)
+        if self.excluded_pks:
+            raise DatabaseError("Counting with excluded primary keys is not "
+                                "supported.")
+        return self._build_query().Count(limit)
 
     @safe_call
     def delete(self):
     @safe_call
     def order_by(self, ordering):
         self.ordering = ordering
-        gae_ordering = []
         for order in self.ordering:
             if order.startswith('-'):
                 order, direction = order[1:], Query.DESCENDING
                 direction = Query.ASCENDING
             if order == self.query.get_meta().pk.column:
                 order = '__key__'
-            gae_ordering.append((order, direction))
-        self.gae_query.Order(*gae_ordering)
+            self.gae_ordering.append((order, direction))
 
     # This function is used by the default add_filters() implementation
     @safe_call
             column = '__key__'
             db_table = self.query.get_meta().db_table
             if lookup_type in ('exact', 'in'):
+                # Optimization: batch-get by key
                 if self.pk_filters is not None:
                     raise DatabaseError("You can't apply multiple AND filters "
                                         "on the primary key. "
                                         "Did you mean __in=[...]?")
-                # Optimization: batch-get by key
-                if negated:
-                    raise DatabaseError("You can't negate equality lookups on "
-                                        "the primary key.")
                 if not isinstance(value, (tuple, list)):
                     value = [value]
-                self.pk_filters = [create_key(db_table, pk) for pk in value if pk]
+                pks = [create_key(db_table, pk) for pk in value if pk]
+                if negated:
+                    self.excluded_pks = pks
+                else:
+                    self.pk_filters = pks
                 return
             else:
                 # XXX: set db_type to 'gae_key' in order to allow
                 raise DatabaseError("Can't have inequality filters on multiple "
                     "columns (here: %r and %r)" % (self.inequality_field, column))
             self.inequality_field = column
+        elif lookup_type == 'in':
+            # Create sub-query combinations, one for each value
+            if len(self.gae_query) * len(value) > 30:
+                raise DatabaseError("You can't query against more than "
+                                    "30 __in filter value combinations")
+            gae_query = self.gae_query
+            combined = []
+            values = [self.convert_value_for_db(db_type, v) for v in value]
+            for query in gae_query:
+                for value in values:
+                    self.gae_query = [Query(self.db_table,
+                                            keys_only=self.pks_only)]
+                    self.gae_query[0].update(query)
+                    self._add_filter(column, '=', db_type, value)
+                    combined.append(self.gae_query[0])
+            self.gae_query = combined
+            return
         elif lookup_type == 'startswith':
             self._add_filter(column, '>=', db_type, value)
             if isinstance(value, str):
         self._add_filter(column, op, db_type, value)
 
     def _add_filter(self, column, op, db_type, value):
-        query = self.gae_query
-        key = '%s %s' % (column, op)
-        value = self.convert_value_for_db(db_type, value)
-        if key in query:
-            existing_value = query[key]
-            if isinstance(existing_value, list):
-                existing_value.append(value)
+        for query in self.gae_query:
+            key = '%s %s' % (column, op)
+            value = self.convert_value_for_db(db_type, value)
+            if key in query:
+                existing_value = query[key]
+                if isinstance(existing_value, list):
+                    existing_value.append(value)
+                else:
+                    query[key] = [existing_value, value]
             else:
-                query[key] = [existing_value, value]
-        else:
-            query[key] = value
+                query[key] = value
 
     # ----------------------------------------------
     # Internal API
         entity[self.query.get_meta().pk.column] = key
         return entity
 
+    @safe_call
+    def _build_query(self):
+        if len(self.gae_query) > 1:
+            return MultiQuery(self.gae_query, self.gae_ordering)
+        query = self.gae_query[0]
+        query.Order(*self.gae_ordering)
+        return query
+
     def get_matching_pk(self, low_mark=0, high_mark=None):
         pk_filters = [key for key in self.pk_filters if key is not None]
         if not pk_filters:

File tests/filter.py

                             foreign_key__gt=ordered_instance)]),
                             ['app-engine@scholardocs.com', 'sharingan@uchias.com',])
 
+    def test_exclude_pk(self):
+        self.assertEquals([entity.pk for entity in
+                           OrderedModel.objects.exclude(pk__in=[2, 3])
+                           .order_by('pk')],
+                          [1, 4])
 
     def test_chained_filter(self):
         # additionally tests count :)
                             'rasengan@naruto.com'])], ['app-engine@scholardocs.com',
                             'rasengan@naruto.com'])
 
+    def test_in(self):
+        self.assertEquals([entity.email for entity in
+                           FieldsWithOptionsModel.objects.filter(
+                           floating_point__in=[5.3, 2.6, 1.58]).filter(
+                           integer__in=[1, 5, 9])],
+                          ['app-engine@scholardocs.com', 'rasengan@naruto.com'])
+
+    def test_in_with_pk_in(self):
+        self.assertEquals([entity.email for entity in
+                           FieldsWithOptionsModel.objects.filter(
+                           floating_point__in=[5.3, 2.6, 1.58]).filter(
+                           email__in=['app-engine@scholardocs.com',
+                                      'rasengan@naruto.com'])],
+                          ['app-engine@scholardocs.com', 'rasengan@naruto.com'])
+
     def test_values(self):
         # test values()
         self.assertEquals([entity['pk'] for entity in