charettes avatar charettes committed 516a37d

Make sure `delete_foreign_key` also works with recursive relationships

Comments (0)

Files changed (1)

south/db/generic.py

 
     @cached_property
     def has_ddl_transactions(self):
-        "Tests the database using feature detection to see if it has DDL transactional support"
+        """
+        Tests the database using feature detection to see if it has
+        transactional DDL support.
+        """
         self._possibly_initialise()
         connection = self._get_connection()
         if hasattr(connection.features, "confirm") and not connection.features._confirmed:
 
     @invalidate_table_constraints
     def delete_foreign_key(self, table_name, column):
-        "Drop a foreign key constraint"
+        """
+        Drop a foreign key constraint
+        """
         if self.dry_run:
             if self.debug:
                 print '   - no dry run output for delete_foreign_key() due to dynamic DDL, sorry'
     drop_foreign_key = alias('delete_foreign_key')
 
     def _find_foreign_constraints(self, table_name, column_name=None):
-        return list(self._constraints_affecting_columns(
-                    table_name, [column_name], "FOREIGN KEY"))
+        constraints = self._constraints_affecting_columns(
+                            table_name, [column_name], "FOREIGN KEY")
+        
+        primary_key_columns = self._find_primary_key_columns(table_name)
+        
+        if len(primary_key_columns) > 1:
+            # Composite primary keys cannot be referenced by a foreign key
+            return list(constraints)
+        else:
+            primary_key_columns.add(column_name)
+            recursive_constraints = set(self._constraints_affecting_columns(
+                                table_name, primary_key_columns, "FOREIGN KEY"))
+            return list(recursive_constraints.union(constraints))
 
     def _digest(self, *args):
         """
             "columns": ", ".join(map(self.quote_name, columns)),
         })
 
+    def _find_primary_key_columns(self, table_name):
+        """
+        Find all columns of the primary key of the specified table
+        """
+        db_name = self._get_setting('NAME')
+        
+        primary_key_columns = set()
+        for col, constraints in self.lookup_constraint(db_name, table_name):
+            for kind, cname in constraints:
+                if kind == 'PRIMARY KEY':
+                    primary_key_columns.add(col.lower())
+                    
+        return primary_key_columns
+
     def start_transaction(self):
         """
         Makes sure the following commands are inside a transaction.
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.