F. Gabriel Gosselin avatar F. Gabriel Gosselin committed 3bbf98e

Adding referenced table/column lookup for constraints (extension of existing lookup)
First feature addition toward copying column constraints when renaming tables or columns

Comments (0)

Files changed (4)

south/db/generic.py

             if self.debug:
                 print '   - no dry run output for delete_foreign_key() due to dynamic DDL, sorry'
             return # We can't look at the DB to get the constraints
-        constraints = list(self._constraints_affecting_columns(table_name, [column], "FOREIGN KEY"))
+        constraints = self._find_foreign_contraints(table_name, column)
         if not constraints:
             raise ValueError("Cannot find a FOREIGN KEY constraint on table %s, column %s" % (table_name, column))
         for constraint_name in constraints:
 
     drop_foreign_key = alias('delete_foreign_key')
 
+    def _find_foreign_constraints(self, table_name, column_name):
+          return list(self._constraints_affecting_columns(
+                    table_name, [column_name], "FOREIGN KEY"))
 
     def create_index_name(self, table_name, column_names, suffix=""):
         """

south/db/mysql.py

     """
     MySQL implementation of database operations.
     
-    MySQL is an 'interesting' database; it has no DDL transaction support,
-    among other things. This can confuse people when they ask how they can
-    roll back - hence the dry runs, etc., found in the migration code.
-    Alex agrees, and Alex is always right.
-    [19:06] <Alex_Gaynor> Also, I want to restate once again that MySQL is a special database
-    
-    (Still, if you want a key-value store with relational tendancies, go MySQL!)
+    MySQL has no DDL transaction support This can confuse people when they ask
+    how to roll back - hence the dry runs, etc., found in the migration code.
     """
     
     backend_name = "mysql"
     geom_types = ['geometry', 'point', 'linestring', 'polygon']
     text_types = ['text', 'blob',]
 
+    def __init__(self, db_alias):
+        self._constraint_references = {}
+
+        super(DatabaseOperations, self).__init__(db_alias)
+ 
     def _is_valid_cache(self, db_name, table_name):
         cache = self._constraint_cache
         # we cache the whole db so if there are any tables table_name is valid
         # for MySQL grab all constraints for this database.  It's just as cheap as a single column.
         self._constraint_cache[db_name] = {}
         self._constraint_cache[db_name][table_name] = {}
+        self._constraint_references[db_name] = {}
 
         name_query = """
-            SELECT kc.constraint_name, kc.column_name, kc.table_name
+            SELECT kc.`constraint_name`, kc.`column_name`, kc.`table_name`,
+                kc.`referenced_table_name`, kc.`referenced_column_name`
             FROM information_schema.key_column_usage AS kc
             WHERE
                 kc.table_schema = %s
         if not rows:
             return
         cnames = {}
-        for constraint, column, table in rows:
+        for constraint, column, table, ref_table, ref_column in rows:
             key = (table, constraint)
             cnames.setdefault(key, set())
-            cnames[key].add(column)
+            cnames[key].add((column, ref_table, ref_column))
 
         type_query = """
-            SELECT c.constraint_name, c.table_name, c.constraint_type
-            FROM information_schema.table_constraints AS c
+            SELECT c.`constraint_name`, c.`table_name`, c.`constraint_type`
+            FROM `information_schema`.`table_constraints` AS c
             WHERE
-                c.table_schema = %s
+                c.`table_schema` = %s
         """
         rows = self.execute(type_query, [db_name])
         for constraint, table, kind in rows:
                 cols = cnames[key]
             except KeyError:
                 cols = set()
-            for column in cols:
+            for column_set in cols:
+                (column, ref_table, ref_column) = column_set
                 self._constraint_cache[db_name][table].setdefault(column, set())
-                self._constraint_cache[db_name][table][column].add((kind, constraint))
-
+                if kind == 'FOREIGN KEY':
+                    self._constraint_cache[db_name][table][column].add((kind,
+                        constraint))
+                    # Create constraint lookup, see constraint_references
+                    self._constraint_references[db_name][(table,
+                        constraint)] = (ref_table, ref_column)
+                else:
+                    self._constraint_cache[db_name][table][column].add((kind,
+                    constraint))
 
     def connection_init(self):
         """
         params = (self.quote_name(old_table_name), self.quote_name(table_name))
         self.execute('RENAME TABLE %s TO %s;' % params)
 
+    def constraint_references(self, table_name, cname):
+        """
+        Provide an existing table and constraint, returns tuple of (foreign
+        table, column)
+        """
+        db_name = self._get_setting('NAME')
+        try:
+            return self._constraint_references[db_name][(table_name, cname)]
+        except KeyError:
+            return None
+
     def _field_sanity(self, field):
         """
         This particular override stops us sending DEFAULTs for BLOB/TEXT columns.

south/tests/__init__.py

 
 if not skiptest:
     from south.tests.db import *
+    from south.tests.db_mysql import *
     from south.tests.logic import *
     from south.tests.autodetection import *
     from south.tests.logger import *

south/tests/db_mysql.py

+import unittest
+
+from south.db import db, generic, mysql
+from django.db import connection, models
+
+class TestMySQLOperations(unittest.TestCase):
+    """MySQL-specific tests"""
+    def setUp(self):
+        db.debug = False
+        db.clear_deferred_sql()
+
+    def tearDown(self):
+        pass
+
+    def _create_foreign_tables(self, main_name, reference_name):
+        # Create foreign table and model
+        Foreign = db.mock_model(model_name='Foreign', db_table=reference_name,
+                                db_tablespace='', pk_field_name='id',
+                                pk_field_type=models.AutoField,
+                                pk_field_args=[])
+        db.create_table(reference_name, [
+                ('id', models.AutoField(primary_key=True)),
+            ])
+        # Create table with foreign key
+        db.create_table(main_name, [
+                ('id', models.AutoField(primary_key=True)),
+                ('foreign', models.ForeignKey(Foreign)),
+            ])
+        return Foreign
+
+    def test_constraint_references(self):
+        """Tests that referred table is reported accurately"""
+        main_table = 'test_cns_ref'
+        reference_table = 'test_cr_foreign'
+        db.start_transaction()
+        self._create_foreign_tables(main_table, reference_table)
+        db.execute_deferred_sql()
+        db_name = db._get_setting('NAME')
+        constraint = db._find_foreign_constraints(main_table, 'foreign_id')[0]
+        constraint_name = 'foreign_id_refs_id_%x' % (abs(hash((main_table,
+            reference_table))))
+        print constraint + ': ' + constraint_name
+        self.assertEquals(constraint_name, constraint)
+        references = db.constraint_references(main_table, constraint)
+        self.assertEquals((reference_table, 'id'), references)
+
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.