Commits

Anonymous committed 7dc9c53

added support for update-queries as well as F() objects

Comments (0)

Files changed (7)

             else:
                 print 'Aborting'
                 exit()
-#        elif on_production_server or have_appserver:
-#            delete_all_entities()
         else:
             destroy_datastore(self._get_paths())
         self._setup_stubs()
 from .db_settings import get_model_indexes
+from .utils import commit_locked
+from .expressions import ExpressionEvaluator
 
 import datetime
 import sys
 
+from django.db.models import F
 from django.db.models.sql import aggregates as sqlaggregates
 from django.db.models.sql.constants import LOOKUP_SEP, MULTI, SINGLE
 from django.db.models.sql.where import AND, OR
         return key.id_or_name()
 
 class SQLUpdateCompiler(NonrelUpdateCompiler, SQLCompiler):
-    pass
+    def execute_sql(self, result_type=MULTI):
+        # modify query to fetch pks only and then execute the query
+        # to get all pks 
+        self.query.add_immediate_loading(['id'])
+        pks = [row for row in self.results_iter()]
+        self.update_entities(pks)
+        return len(pks)
+    
+    def update_entities(self, pks):
+        for pk in pks:
+            self.update_entity(pk[0])
+    
+    @commit_locked    
+    def update_entity(self, pk):
+        gae_query = self.build_query()
+        key = create_key(self.query.get_meta().db_table, pk)
+        entity = Get(key)
+        if not gae_query.matches_filters(entity):
+            return
+        
+        qn = self.quote_name_unless_alias
+        update_dict = {}
+        for field, o, value in self.query.values:
+            if hasattr(value, 'prepare_database_save'):
+                value = value.prepare_database_save(field)
+            else:
+                value = field.get_db_prep_save(value, connection=self.connection)
+            
+            if hasattr(value, "evaluate"):
+                assert not value.negated
+                assert not value.subtree_parents
+                value = ExpressionEvaluator(value, self.query, entity,
+                                                allow_joins=False)
+                
+            if hasattr(value, 'as_sql'):
+                # evaluate expression and return the new value
+                val = value.as_sql(qn, self.connection)
+                update_dict[field] = val
+            else:
+                update_dict[field] = value
+
+        for field, value in update_dict.iteritems():
+            db_type = field.db_type(connection=self.connection)
+            entity[qn(field.column)] = self.convert_value_for_db(db_type, value)
+
+        key = Put(entity)
 
 class SQLDeleteCompiler(NonrelDeleteCompiler, SQLCompiler):
     pass
+from django.db.models.sql.expressions import SQLEvaluator 
+
+OPERATION_MAP = {
+    '+': lambda x, y: x+y, 
+    '-': lambda x, y: x-y,
+    '*': lambda x, y: x*y,
+    '/': lambda x, y: x/y,
+}
+
+class ExpressionEvaluator(SQLEvaluator):
+    def __init__(self, expression, query, entity, allow_joins=True):
+        super(ExpressionEvaluator, self).__init__(expression, query, allow_joins)
+        self.entity = entity
+    
+    ##################################################
+    # Vistor methods for final expression evaluation #
+    ##################################################
+
+    def evaluate_node(self, node, qn, connection):
+        values = []
+        for child in node.children:
+            if hasattr(child, 'evaluate'):
+                value = child.evaluate(self, qn, connection)
+            else:
+                value = child
+
+            if value:
+                values.append(value)
+
+        return OPERATION_MAP[node.connector](*values)
+
+    def evaluate_leaf(self, node, qn, connection):
+        return self.entity[qn(self.cols[node][1])]
 from google.appengine.datastore.datastore_query import Cursor
+from django.db import models, DEFAULT_DB_ALIAS
+try:
+    from functools import wraps
+except ImportError:
+    from django.utils.functional import wraps  # Python 2.3, 2.4 fallback.
 
 class CursorQueryMixin(object):
     def clone(self, *args, **kwargs):
         end = Cursor.from_websafe_string(end)
     queryset.query._gae_end_cursor = end
     return queryset
+
+def commit_locked(func_or_using=None):
+    """
+    Decorator that locks rows on DB reads.
+    """
+    def inner_commit_locked(func, using=None):
+        def _commit_locked(*args, **kw):
+            from google.appengine.api.datastore import RunInTransaction
+            return RunInTransaction(func, *args, **kw)
+        return wraps(func)(_commit_locked)
+    if func_or_using is None:
+        func_or_using = DEFAULT_DB_ALIAS
+    if callable(func_or_using):
+        return inner_commit_locked(func_or_using, DEFAULT_DB_ALIAS)
+    return lambda func: inner_commit_locked(func, func_or_using)
+
 from .order import OrderTest
 from .not_return_sets import NonReturnSetsTest
 from .decimals import DecimalTest
+from .transactions import TransactionTest
 
 class EmailModel(models.Model):
     email = models.EmailField()
+    number = models.IntegerField(null=True)
 
 class DateTimeModel(models.Model):
     datetime = models.DateTimeField()

tests/transactions.py

+from .testmodels import EmailModel
+from django.db.models import F
+from django.test import TestCase
+
+class TransactionTest(TestCase):
+    emails = ['app-engine@scholardocs.com', 'sharingan@uchias.com',
+        'rinnengan@sage.de', 'rasengan@naruto.com']
+
+    def setUp(self):
+        EmailModel(email=self.emails[0], number=1).save()
+        EmailModel(email=self.emails[0], number=2).save()
+        EmailModel(email=self.emails[1], number=3).save()
+
+    def test_update(self):
+        self.assertEqual(2, len(EmailModel.objects.all().filter(
+            email=self.emails[0])))
+        
+        self.assertEqual(1, len(EmailModel.objects.all().filter(
+            email=self.emails[1])))
+        
+        EmailModel.objects.all().filter(email=self.emails[0]).update(
+            email=self.emails[1])
+        
+        self.assertEqual(0, len(EmailModel.objects.all().filter(
+            email=self.emails[0])))
+        self.assertEqual(3, len(EmailModel.objects.all().filter(
+            email=self.emails[1])))
+        
+    def test_f_object_updates(self):
+        self.assertEqual(1, len(EmailModel.objects.all().filter(
+            number=1)))
+        self.assertEqual(1, len(EmailModel.objects.all().filter(
+            number=2)))
+        
+       # test add
+        EmailModel.objects.all().filter(email=self.emails[0]).update(number=
+            F('number') + F('number'))
+        
+        self.assertEqual(1, len(EmailModel.objects.all().filter(
+            number=2)))
+        self.assertEqual(1, len(EmailModel.objects.all().filter(
+            number=4)))
+        
+        EmailModel.objects.all().filter(email=self.emails[1]).update(number=
+            F('number') + 10, email=self.emails[0])
+        
+        self.assertEqual(1, len(EmailModel.objects.all().filter(number=13)))
+        self.assertEqual(self.emails[0], EmailModel.objects.all().get(number=13).
+                         email)
+        
+        # complex expression test
+        EmailModel.objects.all().filter(number=13).update(number=
+            F('number')*(F('number') + 10) - 5, email=self.emails[0])
+        self.assertEqual(1, len(EmailModel.objects.all().filter(number=294)))
+       
+       # TODO: tests for
+       # test sub
+       # test muld
+       # test div
+       # test mod, ....