Commits

Andrew Godwin committed 3b73b3e Merge

Merge in schema-caching branch.

Comments (0)

Files changed (6)

south/db/generic.py

         return getattr(self, attrname)(*args, **kwds)
     return func
 
+def invalidate_table_constraints(func):
+    def _cache_clear(self, table, *args, **opts):
+        self._set_cache(table, value=INVALID)
+        return func(self, table, *args, **opts)
+    return _cache_clear
+
+def delete_column_constraints(func):
+    def _column_rm(self, table, column, *args, **opts):
+        self._set_cache(table, column, value=[])
+        return func(self, table, column, *args, **opts)
+    return _column_rm
+
+def copy_column_constraints(func):
+    def _column_cp(self, table, column_old, column_new, *args, **opts):
+        db_name = self._get_setting('NAME')
+        self._set_cache(table, column_new, value=self.lookup_constraint(db_name, table, column_old))
+        return func(self, table, column_old, column_new, *args, **opts)
+    return _column_cp
+
+class INVALID(Exception):
+    def __repr__(self):
+        return 'INVALID'
 
 class DatabaseOperations(object):
 
         self.pending_transactions = 0
         self.pending_create_signals = []
         self.db_alias = db_alias
+        self._constraint_cache = {}
         self._initialised = False
-    
+
+    def lookup_constraint(self, db_name, table_name, column_name=None):
+        """ return a set() of constraints for db_name.table_name.column_name """
+        def _lookup():
+            table = self._constraint_cache[db_name][table_name]
+            if table is INVALID:
+                raise INVALID
+            elif column_name is None:
+                return table.items()
+            else:
+                return table[column_name]
+
+        try:
+            ret = _lookup()
+            return ret
+        except INVALID as e:
+            del self._constraint_cache[db_name][table_name]
+            self._fill_constraint_cache(db_name, table_name)
+        except KeyError as e:
+            if self._is_valid_cache(db_name, table_name):
+                return []
+            self._fill_constraint_cache(db_name, table_name)
+
+        return self.lookup_constraint(db_name, table_name, column_name)
+
+    def _set_cache(self, table_name, column_name=None, value=INVALID):
+        db_name = self._get_setting('NAME')
+        try:
+            if column_name is not None:
+                self._constraint_cache[db_name][table_name][column_name] = value
+            else:
+                self._constraint_cache[db_name][table_name] = value
+        except (LookupError, TypeError):
+            pass
+
+    def _is_valid_cache(self, db_name, table_name):
+        # we cache per-table so if the table is there it is valid
+        try:
+            return self._constraint_cache[db_name][table_name] is not INVALID
+        except KeyError:
+            return False
+
     def _is_multidb(self):
         try: 
             from django.db import connections
         return self.pending_create_signals
 
 
+    @invalidate_table_constraints
     def create_table(self, table_name, fields):
         """
         Creates the table 'table_name'. 'fields' is a tuple of fields,
     add_table = alias('create_table') # Alias for consistency's sake
 
 
+    @invalidate_table_constraints
     def rename_table(self, old_table_name, table_name):
         """
         Renames the table 'old_table_name' to 'table_name'.
         self.execute('ALTER TABLE %s RENAME TO %s;' % params)
 
 
+    @invalidate_table_constraints
     def delete_table(self, table_name, cascade=True):
         """
         Deletes the table 'table_name'.
     drop_table = alias('delete_table')
 
 
+    @invalidate_table_constraints
     def clear_table(self, table_name):
         """
         Deletes all rows from 'table_name'.
 
 
 
+    @invalidate_table_constraints
     def add_column(self, table_name, name, field, keep_default=True):
         """
         Adds the column 'name' to the table 'table_name'.
         else:
             sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), []))
 
+    @invalidate_table_constraints
     def alter_column(self, table_name, name, field, explicit_name=True, ignore_constraints=False):
         """
         Alters the given column name so it will match the given field.
                     )
                 )
 
+    def _fill_constraint_cache(self, db_name, table_name):
+
+        schema = self._get_schema_name()            
+        ifsc_tables = ["constraint_column_usage", "key_column_usage"]
+
+        self._constraint_cache.setdefault(db_name, {})
+        self._constraint_cache[db_name][table_name] = {}
+
+        for ifsc_table in ifsc_tables:
+            rows = self.execute("""
+                SELECT kc.constraint_name, kc.column_name, c.constraint_type
+                FROM information_schema.%s AS kc
+                JOIN information_schema.table_constraints AS c ON
+                    kc.table_schema = c.table_schema AND
+                    kc.table_name = c.table_name AND
+                    kc.constraint_name = c.constraint_name
+                WHERE
+                    kc.table_schema = %%s AND
+                    kc.table_name = %%s
+            """ % ifsc_table, [schema, table_name])
+            for constraint, column, kind in rows:
+                self._constraint_cache[db_name][table_name].setdefault(column, set())
+                self._constraint_cache[db_name][table_name][column].add((kind, constraint))
+        return
 
     def _constraints_affecting_columns(self, table_name, columns, type="UNIQUE"):
         """
         Gets the names of the constraints affecting the given columns.
         If columns is None, returns all constraints of the type on the table.
         """
-
         if self.dry_run:
             raise ValueError("Cannot get constraints for columns during a dry run.")
 
         if columns is not None:
             columns = set(columns)
 
-        if type == "CHECK":
-            ifsc_table = "constraint_column_usage"
-        else:
-            ifsc_table = "key_column_usage"
+        db_name = self._get_setting('NAME')
 
-        schema = self._get_schema_name()            
+        cnames = {}
+        for col, constraints in self.lookup_constraint(db_name, table_name):
+            for kind, cname in constraints:
+                if kind == type:
+                    cnames.setdefault(cname, set())
+                    cnames[cname].add(col)
 
-        # First, load all constraint->col mappings for this table.
-        rows = self.execute("""
-            SELECT kc.constraint_name, kc.column_name
-            FROM information_schema.%s AS kc
-            JOIN information_schema.table_constraints AS c ON
-                kc.table_schema = c.table_schema AND
-                kc.table_name = c.table_name AND
-                kc.constraint_name = c.constraint_name
-            WHERE
-                kc.table_schema = %%s AND
-                kc.table_name = %%s AND
-                c.constraint_type = %%s
-        """ % ifsc_table, [schema, table_name, type])
-        
-        # Load into a dict
-        mapping = {}
-        for constraint, column in rows:
-            mapping.setdefault(constraint, set())
-            mapping[constraint].add(column)
-        
-        # Find ones affecting these columns
-        for constraint, itscols in mapping.items():
-            # If columns is None we definitely want this field! (see docstring)
-            if itscols == columns or columns is None:
-                yield constraint
+        for cname, cols in cnames.items():
+            if cols == columns or columns is None:
+                yield cname
 
+    @invalidate_table_constraints
     def create_unique(self, table_name, columns):
         """
         Creates a UNIQUE constraint on the columns on the given table.
         ))
         return name
 
+    @invalidate_table_constraints
     def delete_unique(self, table_name, columns):
         """
         Deletes a UNIQUE constraint on precisely the columns on the given table.
         )
     
 
+    @invalidate_table_constraints
     def delete_foreign_key(self, table_name, column):
         "Drop a foreign key constraint"
         if self.dry_run:
             tablespace_sql
         )
 
+    @invalidate_table_constraints
     def create_index(self, table_name, column_names, unique=False, db_tablespace=''):
         """ Executes a create index statement """
         sql = self.create_index_sql(table_name, column_names, unique, db_tablespace)
         self.execute(sql)
 
 
+    @invalidate_table_constraints
     def delete_index(self, table_name, column_names, db_tablespace=''):
         """
         Deletes an index created with create_index.
     drop_index = alias('delete_index')
 
 
+    @delete_column_constraints
     def delete_column(self, table_name, name):
         """
         Deletes the column 'column_name' from the table 'table_name'.
         """
+        db_name = self._get_setting('NAME')
         params = (self.quote_name(table_name), self.quote_name(name))
         self.execute(self.delete_column_string % params, [])
 
         raise NotImplementedError("rename_column has no generic SQL syntax")
 
 
+    @invalidate_table_constraints
     def delete_primary_key(self, table_name):
         """
         Drops the old primary key.
     drop_primary_key = alias('delete_primary_key')
 
 
+    @invalidate_table_constraints
     def create_primary_key(self, table_name, columns):
         """
         Creates a new primary key on the specified columns.

south/db/mysql.py

     has_ddl_transactions = False
     has_check_constraints = False
     delete_unique_sql = "ALTER TABLE %s DROP INDEX %s"
-    
-    
+
+    def _is_valid_cache(self, db_name, table_name):
+        cache = self._constraint_cache
+        # we cache the whole db so if there are any tables table_name is valid
+        return db_name in cache and cache[db_name].get(table_name, None) is not generic.INVALID
+
+    def _fill_constraint_cache(self, db_name, table_name):
+        # for MySQL grab all constraints for this database.  It's just as cheap as a single column.
+        self._constraint_cache[db_name] = {}
+        self._constraint_cache[db_name][table_name] = {}
+
+        name_query = """
+            SELECT kc.constraint_name, kc.column_name, kc.table_name
+            FROM information_schema.key_column_usage AS kc
+            WHERE
+                kc.table_schema = %s AND
+                kc.table_catalog IS NULL
+        """
+        rows = self.execute(name_query, [db_name])
+        if not rows:
+            return
+        cnames = {}
+        for constraint, column, table in rows:
+            key = (table, constraint)
+            cnames.setdefault(key, set())
+            cnames[key].add(column)
+
+        type_query = """
+            SELECT c.constraint_name, c.table_name, c.constraint_type
+            FROM information_schema.table_constraints AS c
+            WHERE
+                c.table_schema = %s
+        """
+        rows = self.execute(type_query, [db_name])
+        for constraint, table, kind in rows:
+            key = (table, constraint)
+            self._constraint_cache[db_name].setdefault(table, {})
+            try:
+                cols = cnames[key]
+            except KeyError:
+                cols = set()
+            for column in cols:
+                self._constraint_cache[db_name][table].setdefault(column, set())
+                self._constraint_cache[db_name][table][column].add((kind, constraint))
+
+
     def connection_init(self):
         """
         Run before any SQL to let database-specific config be sent as a command,
         cursor.execute("SET FOREIGN_KEY_CHECKS=0;")
         self.deferred_sql.append("SET FOREIGN_KEY_CHECKS=1;")
 
-    
+    @generic.copy_column_constraints
+    @generic.delete_column_constraints
     def rename_column(self, table_name, old, new):
         if old == new or self.dry_run:
             return []
             self.execute(sql, (rows[0][4],))
         else:
             self.execute(sql)
-    
-    
+
+    @generic.delete_column_constraints
     def delete_column(self, table_name, name):
         db_name = self._get_setting('NAME')
-        
+
         # See if there is a foreign key on this column
-        cursor = self._get_connection().cursor()
-        get_fkeyname_query = "SELECT tc.constraint_name FROM \
-                              information_schema.table_constraints tc, \
-                              information_schema.key_column_usage kcu \
-                              WHERE tc.table_name=kcu.table_name \
-                              AND tc.table_schema=kcu.table_schema \
-                              AND tc.constraint_name=kcu.constraint_name \
-                              AND tc.constraint_type='FOREIGN KEY' \
-                              AND tc.table_schema='%s' \
-                              AND tc.table_name='%s' \
-                              AND kcu.column_name='%s'"
-
-        result = cursor.execute(get_fkeyname_query % (db_name, table_name, name))
-        
-        # If a foreign key exists, we need to delete it first
-        if result > 0:
+        result = 0
+        for kind, cname in self.lookup_constraint(db_name, table_name, name):
+            if kind == 'FOREIGN_KEY':
+                result += 1
+                fkey_name = cname
+        if result:
             assert result == 1 # We should only have one result, otherwise there's Issues
-            fkey_name = cursor.fetchone()[0]
+            cursor = self._get_connection().cursor()
             drop_query = "ALTER TABLE %s DROP FOREIGN KEY %s"
             cursor.execute(drop_query % (self.quote_name(table_name), self.quote_name(fkey_name)))
 
         super(DatabaseOperations, self).delete_column(table_name, name)
 
-    
+    @generic.invalidate_table_constraints
     def rename_table(self, old_table_name, table_name):
         """
         Renames the table 'old_table_name' to 'table_name'.
             return
         params = (self.quote_name(old_table_name), self.quote_name(table_name))
         self.execute('RENAME TABLE %s TO %s;' % params)
-    
-    
-    def _constraints_affecting_columns(self, table_name, columns, type="UNIQUE"):
-        """
-        Gets the names of the constraints affecting the given columns.
-        If columns is None, returns all constraints of the type on the table.
-        """
-        
-        if self.dry_run:
-            raise ValueError("Cannot get constraints for columns during a dry run.")
-        
-        if columns is not None:
-            columns = set(columns)
-        
-        db_name = self._get_setting('NAME')
-        
-        # First, load all constraint->col mappings for this table.
-        rows = self.execute("""
-            SELECT kc.constraint_name, kc.column_name
-            FROM information_schema.key_column_usage AS kc
-            JOIN information_schema.table_constraints AS c ON
-                kc.table_schema = c.table_schema AND
-                kc.table_name = c.table_name AND
-                kc.constraint_name = c.constraint_name
-            WHERE
-                kc.table_schema = %s AND
-                kc.table_catalog IS NULL AND
-                kc.table_name = %s AND
-                c.constraint_type = %s
-        """, [db_name, table_name, type])
-        
-        # Load into a dict
-        mapping = {}
-        for constraint, column in rows:
-            mapping.setdefault(constraint, set())
-            mapping[constraint].add(column)
-        
-        # Find ones affecting these columns
-        for constraint, itscols in mapping.items():
-            if itscols == columns or columns is None:
-                yield constraint
-    
-    
+
     def _field_sanity(self, field):
         """
         This particular override stops us sending DEFAULTs for BLOB/TEXT columns.

south/db/oracle.py

 import re
 import cx_Oracle
 
+
 from django.db import connection, models
 from django.db.backends.util import truncate_name
 from django.core.management.color import no_style
 
         return upper and tn.upper() or tn.lower()
 
+    @generic.invalidate_table_constraints
     def create_table(self, table_name, fields): 
         qn = self.quote_name(table_name, upper = False)
         qn_upper = qn.upper()
             self.execute(autoinc_sql[0])
             self.execute(autoinc_sql[1])
 
+    @generic.invalidate_table_constraints
     def delete_table(self, table_name, cascade=True):
         qn = self.quote_name(table_name, upper = False)
 
             self.execute('DROP TABLE %s;' % qn.upper())
         self.execute('DROP SEQUENCE %s;'%get_sequence_name(qn))
 
+    @generic.invalidate_table_constraints
     def alter_column(self, table_name, name, field, explicit_name=True):
         qn = self.quote_name(table_name)
 
                 if str(exc).find('ORA-01442') == -1:
                     raise
 
+    @generic.invalidate_table_constraints
     def add_column(self, table_name, name, field, keep_default=True):
         qn = self.quote_name(table_name, upper = False)
         sql = self.column_sql(qn, name, field)
             field.default = int(field.to_python(field.get_default()))
         return field
 
-    def _constraints_affecting_columns(self, table_name, columns, type='UNIQUE'):
-        """
-        Gets the names of the constraints affecting the given columns.
-        """
+
+
+    def _fill_constraint_cache(self, db_name, table_name):
         qn = self.quote_name
 
-        if self.dry_run:
-            raise ValueError("Cannot get constraints for columns during a dry run.")
-        columns = set(columns)
         rows = self.execute("""
-            SELECT user_cons_columns.constraint_name, user_cons_columns.column_name
+            SELECT user_cons_columns.constraint_name,
+                   user_cons_columns.column_name,
+                   user_constraints.constraint_type
             FROM user_constraints
             JOIN user_cons_columns ON
                  user_constraints.table_name = user_cons_columns.table_name AND 
                  user_constraints.constraint_name = user_cons_columns.constraint_name
-            WHERE user_constraints.table_name = '%s' AND
-                  user_constraints.constraint_type = '%s'
-        """ % (qn(table_name), self.constraits_dict[type]))
-        # Load into a dict
-        mapping = {}
-        for constraint, column in rows:
-            mapping.setdefault(constraint, set())
-            mapping[constraint].add(column)
-        # Find ones affecting these columns
-        for constraint, itscols in mapping.items():
-            if itscols == columns:
-                yield constraint
+            WHERE user_constraints.table_name = '%s'
+        """ % (qn(table_name)))
+
+        for constraint, column, kind in rows:
+            self._constraint_cache[db_name][table_name].setdefault(column, set())
+            self._constraint_cache[db_name][table_name][column].add((kind, constraint))
+        return

south/db/postgresql_psycopg2.py

     
     backend_name = "postgres"
 
+    @generic.copy_column_constraints
+    @generic.delete_column_constraints
     def rename_column(self, table_name, old, new):
         if old == new:
             # Short-circuit out
             self.quote_name(old),
             self.quote_name(new),
         ))
-    
+
+    @generic.invalidate_table_constraints
     def rename_table(self, old_table_name, table_name):
         "will rename the table and an associated ID sequence and primary key index"
         # First, rename the table

south/db/sqlite3.py

         self._remake_table(table_name, added={
             field.column: self._column_sql_for_create(table_name, name, field, False),
         })
-    
-    def _remake_table(self, table_name, added={}, renames={}, deleted=[], altered={},
-                      primary_key_override=None, uniques_deleted=[]):
+
+    @generic.invalidate_table_constraints
+    def _remake_table(self, table_name, added={}, renames={}, deleted=[], altered={}, primary_key_override=None, uniques_deleted=[]):
         """
         Given a table and three sets of changes (renames, deletes, alters),
         recreates it with the modified schema.

south/tests/db.py

 import unittest
 
-from south.db import db
+from south.db import db, generic
 from django.db import connection, models
 
 # Create a list of error classes from the various database libraries
     pass
 errors = tuple(errors)
 
+try:
+    from south.db import mysql
+except ImportError:
+    mysql = None
+
 class TestOperations(unittest.TestCase):
 
     """
         db.add_column("test_add_unique_fk", "mock2", models.OneToOneField(db.mock_model('Mock', 'mock'), null=True))
         
         db.delete_table("test_add_unique_fk")
+
+class TestCacheGeneric(unittest.TestCase):
+    base_ops_cls = generic.DatabaseOperations
+    def setUp(self):
+        class CacheOps(self.base_ops_cls):
+            def __init__(self):
+                self._constraint_cache = {}
+                self.cache_filled = 0
+                self.settings = {'NAME' : 'db'}
+
+            def _fill_constraint_cache(self, db, table):
+                self.cache_filled += 1
+                self._constraint_cache.setdefault(db, {})
+                self._constraint_cache[db].setdefault(table, {})
+
+            @generic.invalidate_table_constraints
+            def clear_con(self, table):
+                pass
+
+            @generic.copy_column_constraints
+            def cp_column(self, table, column_old, column_new):
+                pass
+
+            @generic.delete_column_constraints
+            def rm_column(self, table, column):
+                pass
+
+            @generic.copy_column_constraints
+            @generic.delete_column_constraints
+            def mv_column(self, table, column_old, column_new):
+                pass
+
+            def _get_setting(self, attr):
+                return self.settings[attr]
+        self.CacheOps = CacheOps
+
+    def test_cache(self):
+        ops = self.CacheOps()
+        self.assertEqual(0, ops.cache_filled)
+        self.assertFalse(ops.lookup_constraint('db', 'table'))
+        self.assertEqual(1, ops.cache_filled)
+        self.assertFalse(ops.lookup_constraint('db', 'table'))
+        self.assertEqual(1, ops.cache_filled)
+        ops.clear_con('table')
+        self.assertEqual(1, ops.cache_filled)
+        self.assertFalse(ops.lookup_constraint('db', 'table'))
+        self.assertEqual(2, ops.cache_filled)
+        self.assertFalse(ops.lookup_constraint('db', 'table', 'column'))
+        self.assertEqual(2, ops.cache_filled)
+
+        cache = ops._constraint_cache
+        cache['db']['table']['column'] = 'constraint'
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column'))
+        self.assertEqual([('column', 'constraint')], ops.lookup_constraint('db', 'table'))
+        self.assertEqual(2, ops.cache_filled)
+
+        # invalidate_table_constraints
+        ops.clear_con('new_table')
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column'))
+        self.assertEqual(2, ops.cache_filled)
+
+        self.assertFalse(ops.lookup_constraint('db', 'new_table'))
+        self.assertEqual(3, ops.cache_filled)
+
+        # delete_column_constraints
+        cache['db']['table']['column'] = 'constraint'
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column'))
+        ops.rm_column('table', 'column')
+        self.assertEqual([], ops.lookup_constraint('db', 'table', 'column'))
+        self.assertEqual([], ops.lookup_constraint('db', 'table', 'noexist_column'))
+
+        # copy_column_constraints
+        cache['db']['table']['column'] = 'constraint'
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column'))
+        import sys
+        ops.cp_column('table', 'column', 'column_new')
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column_new'))
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column'))
+
+        # copy + delete
+        cache['db']['table']['column'] = 'constraint'
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column'))
+        ops.mv_column('table', 'column', 'column_new')
+        self.assertEqual('constraint', ops.lookup_constraint('db', 'table', 'column_new'))
+        self.assertEqual([], ops.lookup_constraint('db', 'table', 'column'))
+        return
+
+    def test_valid(self):
+        ops = self.CacheOps()
+        # none of these should vivify a table into a valid state
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+        ops.clear_con('table')
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+        ops.rm_column('table', 'column')
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+
+        # these should change the cache state
+        ops.lookup_constraint('db', 'table')
+        self.assertTrue(ops._is_valid_cache('db', 'table'))
+        ops.lookup_constraint('db', 'table', 'column')
+        self.assertTrue(ops._is_valid_cache('db', 'table'))
+        ops.clear_con('table')
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+
+    def test_valid_implementation(self):
+        # generic fills the cache on a per-table basis
+        ops = self.CacheOps()
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+        self.assertFalse(ops._is_valid_cache('db', 'other_table'))
+        ops.lookup_constraint('db', 'table')
+        self.assertTrue(ops._is_valid_cache('db', 'table'))
+        self.assertFalse(ops._is_valid_cache('db', 'other_table'))
+        ops.lookup_constraint('db', 'other_table')
+        self.assertTrue(ops._is_valid_cache('db', 'table'))
+        self.assertTrue(ops._is_valid_cache('db', 'other_table'))
+        ops.clear_con('table')
+        self.assertFalse(ops._is_valid_cache('db', 'table'))
+        self.assertTrue(ops._is_valid_cache('db', 'other_table'))
+
+if mysql:
+    class TestCacheMysql(TestCacheGeneric):
+        base_ops_cls = mysql.DatabaseOperations
+
+        def test_valid_implementation(self):
+            # mysql fills the cache on a per-db basis
+            ops = self.CacheOps()
+            self.assertFalse(ops._is_valid_cache('db', 'table'))
+            self.assertFalse(ops._is_valid_cache('db', 'other_table'))
+            ops.lookup_constraint('db', 'table')
+            cache = ops._constraint_cache
+            self.assertTrue(ops._is_valid_cache('db', 'table'))
+            self.assertTrue(ops._is_valid_cache('db', 'other_table'))
+            ops.lookup_constraint('db', 'other_table')
+            self.assertTrue(ops._is_valid_cache('db', 'table'))
+            self.assertTrue(ops._is_valid_cache('db', 'other_table'))
+            ops.clear_con('table')
+            self.assertFalse(ops._is_valid_cache('db', 'table'))
+            self.assertTrue(ops._is_valid_cache('db', 'other_table'))