Commits

F. Gabriel Gosselin committed 8121be5

Add reverse lookup to detect inbound foreign references (foreign keys referring to given table.column)
Step towards fixing table/column rename failure

Comments (0)

Files changed (2)

south/db/mysql.py

 
     def __init__(self, db_alias):
         self._constraint_references = {}
-
+        self._reverse_cache = {}
         super(DatabaseOperations, self).__init__(db_alias)
  
     def _is_valid_cache(self, db_name, table_name):
         # for MySQL grab all constraints for this database.  It's just as cheap as a single column.
         self._constraint_cache[db_name] = {}
         self._constraint_cache[db_name][table_name] = {}
+        self._reverse_cache[db_name] = {}
         self._constraint_references[db_name] = {}
 
         name_query = """
             cnames[key].add((column, ref_table, ref_column))
 
         type_query = """
-            SELECT c.`constraint_name`, c.`table_name`, c.`constraint_type`
-            FROM `information_schema`.`table_constraints` AS c
+            SELECT c.constraint_name, c.table_name, c.constraint_type
+            FROM information_schema.table_constraints AS c
             WHERE
-                c.`table_schema` = %s
+                c.table_schema = %s
         """
         rows = self.execute(type_query, [db_name])
         for constraint, table, kind in rows:
                     # Create constraint lookup, see constraint_references
                     self._constraint_references[db_name][(table,
                         constraint)] = (ref_table, ref_column)
+                    # Create reverse table lookup, reverse_lookup
+                    self._reverse_cache[db_name].setdefault(ref_table, {})
+                    self._reverse_cache[db_name][ref_table].setdefault(ref_column,
+                            set())
+                    self._reverse_cache[db_name][ref_table][ref_column].add(
+                            (constraint, table, column))
                 else:
                     self._constraint_cache[db_name][table][column].add((kind,
                     constraint))
         params = (self.quote_name(old_table_name), self.quote_name(table_name))
         self.execute('RENAME TABLE %s TO %s;' % params)
 
-    def constraint_references(self, table_name, cname):
+    def _lookup_constraint_references(self, table_name, cname):
         """
-        Provide an existing table and constraint, returns tuple of (foreign
+        Provided an existing table and constraint, returns tuple of (foreign
         table, column)
         """
         db_name = self._get_setting('NAME')
         except KeyError:
             return None
 
+    def _lookup_reverse_constraint(self, table_name, column_name=None):
+        """Look for the column referenced by a foreign constraint"""
+        db_name = self._get_setting('NAME')
+        if self.dry_run:
+            raise DryRunError("Cannot get constraints for columns.")
+
+        if not self._is_valid_cache(db_name, table_name):
+            # Piggy-back on lookup_constraint, ensures cache exists
+            self.lookup_constraint(db_name, table_name)
+
+        try:
+            table = self._reverse_cache[db_name][table_name]
+            if column_name == None:
+                return table.items()
+            else:
+                return table[column_name]
+        except KeyError, e:
+            return []
+
     def _field_sanity(self, field):
         """
         This particular override stops us sending DEFAULTs for BLOB/TEXT columns.
         is_text = True in [ type.find(t) > -1 for t in self.text_types ]
         if not is_geom and not is_text:
             super(DatabaseOperations, self)._alter_set_defaults(field, name, params, sqls)
+
+class DryRunError(ValueError):
+    pass

south/tests/db_mysql.py

         constraint = db._find_foreign_constraints(main_table, 'foreign_id')[0]
         constraint_name = 'foreign_id_refs_id_%x' % (abs(hash((main_table,
             reference_table))))
-        print constraint + ': ' + constraint_name
         self.assertEquals(constraint_name, constraint)
-        references = db.constraint_references(main_table, constraint)
+        references = db._lookup_constraint_references(main_table, constraint)
         self.assertEquals((reference_table, 'id'), references)
 
+    def test_reverse_column_constraint(self):
+        """Tests that referred column in a foreign key (ex. id) is found"""
+        main_table = 'test_reverse_ref'
+        reference_table = 'test_rr_foreign'
+        db.start_transaction()
+        self._create_foreign_tables(main_table, reference_table)
+        db.execute_deferred_sql()
+        db_name = db._get_setting('NAME')
+        inverse = db._lookup_reverse_constraint(reference_table, 'id')
+        # Hard to extract single value from set, .pop affects cache
+        (cname, rev_table, rev_column) = tuple(inverse)[0]
+        self.assertEquals(main_table, rev_table)
+        self.assertEquals('foreign_id', rev_column)
+
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.