Lau Bech Lauritzen committed 919223c

Fixed compiler type comparison to support subclassed compilers

Comments (0)

Files changed (1)

 # XXX: Thread safety concerns?  Should we only need to patch once per process?
+from django.db.models.sql import compiler
+from django.db.models.sql.constants import MULTI
+from django.db.models.sql.datastructures import EmptyResultSet
 class QueryCacheBackend(object):
     """This class is the engine behind the query cache. It reads the queries
     going through the django Query and returns from the cache using
     call ``johnny.cache.get_backend`` to automatically get the proper class.
     __shared_state = {}
+    _read_compilers = (
+        compiler.SQLCompiler,
+        compiler.SQLAggregateCompiler,
+        compiler.SQLDateCompiler
+        )
+    _write_compilers = (
+        compiler.SQLInsertCompiler,
+        compiler.SQLDeleteCompiler,
+        compiler.SQLUpdateCompiler
+        )    
     def __init__(self, cache_backend=None, keyhandler=None, keygen=None):
         self.__dict__ = self.__shared_state
         self.prefix = getattr(settings, 'JOHNNY_MIDDLEWARE_KEY_PREFIX', 'jc')
         self._patched = getattr(self, '_patched', False)
     def _monkey_select(self, original):
-        from django.db.models.sql import compiler
-        from django.db.models.sql.constants import MULTI
-        from django.db.models.sql.datastructures import EmptyResultSet
         def newfun(cls, *args, **kwargs):
             result_type = args[0] if args else kwargs.get('result_type', MULTI)
-            if type(cls) in (compiler.SQLInsertCompiler, compiler.SQLDeleteCompiler, compiler.SQLUpdateCompiler):
+            if any([isinstance(cls, c) for c in self._write_compilers]):
                 return original(cls, *args, **kwargs)
                 sql, params = cls.as_sql()
         def newfun(cls, *args, **kwargs):
             db = getattr(cls, 'using', 'default')
-            from django.db.models.sql import compiler
             # we have to do this before we check the tables, since the tables
             # are actually being set in the original function
             ret = original(cls, *args, **kwargs)
-            if type(cls) == compiler.SQLInsertCompiler:
+            if isinstance(cls, compiler.SQLInsertCompiler):
                 #Inserts are a special case where cls.tables
                 #are not populated.
                 tables = [cls.query.model._meta.db_table]
     def patch(self):
         """monkey patches django.db.models.sql.compiler.SQL*Compiler series"""
         if not self._patched:
-            from django.db.models.sql import compiler
             self._original = {}
-            for reader in (compiler.SQLCompiler, compiler.SQLAggregateCompiler, compiler.SQLDateCompiler):
+            for reader in self._read_compilers:
                 self._original[reader] = reader.execute_sql
                 reader.execute_sql = self._monkey_select(reader.execute_sql)
-            for updater in (compiler.SQLInsertCompiler, compiler.SQLDeleteCompiler, compiler.SQLUpdateCompiler):
+            for updater in self._write_compilers:
                 self._original[updater] = updater.execute_sql
                 updater.execute_sql = self._monkey_write(updater.execute_sql)
             self._patched = True
         """un-applies this patch."""
         if not self._patched:
-        from django.db.models.sql import compiler
-        for func in (compiler.SQLCompiler, compiler.SQLAggregateCompiler, compiler.SQLDateCompiler,
-                compiler.SQLInsertCompiler, compiler.SQLDeleteCompiler, compiler.SQLUpdateCompiler):
+        for func in self._read_compilers + self._write_compilers:
             func.execute_sql = self._original[func]
         self._patched = False