Commits

Andrew Godwin  committed a7cc9cb Merge

Branch merge

  • Participants
  • Parent commits 5cc9f23, b241f6a

Comments (0)

Files changed (4)

File south/db/sql_server/pyodbc.py

+from datetime import date, datetime, time
+from warnings import warn
 from django.db import models
 from django.db.models import fields
 from south.db import generic
+from south.db.generic import delete_column_constraints, invalidate_table_constraints, copy_column_constraints
+from south.exceptions import ConstraintDropped
 
 class DatabaseOperations(generic.DatabaseOperations):
     """
     drop_index_string = 'DROP INDEX %(index_name)s ON %(table_name)s'
     drop_constraint_string = 'ALTER TABLE %(table_name)s DROP CONSTRAINT %(constraint_name)s'
     delete_column_string = 'ALTER TABLE %s DROP COLUMN %s'
+
+    create_check_constraint_sql = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s CHECK (%(check)s)"
+    create_foreign_key_sql = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s " + \
+                             "FOREIGN KEY (%(column)s) REFERENCES %(target)s"
+    create_unique_sql = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s UNIQUE (%(columns)s)"
+    
     
     default_schema_name = "dbo"
 
 
+    @delete_column_constraints
     def delete_column(self, table_name, name):
         q_table_name, q_name = (self.quote_name(table_name), self.quote_name(name))
 
         return [i[0] for i in idx]
 
 
-    def _find_constraints_for_column(self, table_name, name):
+    def _find_constraints_for_column(self, table_name, name, just_names=True):
         """
         Find the constraints that apply to a column, needed when deleting. Defaults not included.
         This is more general than the parent _constraints_affecting_columns, as on MSSQL this
         """
 
         sql = """
-        SELECT  CONSTRAINT_NAME
-        FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE
-        WHERE CONSTRAINT_CATALOG = TABLE_CATALOG
-          AND CONSTRAINT_SCHEMA = TABLE_SCHEMA
-          AND TABLE_CATALOG = %s
-          AND TABLE_SCHEMA = %s
-          AND TABLE_NAME = %s
-          AND COLUMN_NAME = %s 
+         SELECT CC.[CONSTRAINT_NAME]
+              ,TC.[CONSTRAINT_TYPE]
+              ,CHK.[CHECK_CLAUSE]
+              ,RFD.TABLE_SCHEMA
+              ,RFD.TABLE_NAME
+              ,RFD.COLUMN_NAME
+              -- used for normalized names
+              ,CC.TABLE_NAME
+              ,CC.COLUMN_NAME
+          FROM [INFORMATION_SCHEMA].[TABLE_CONSTRAINTS] TC
+          JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE CC
+               ON TC.CONSTRAINT_CATALOG = CC.CONSTRAINT_CATALOG 
+              AND TC.CONSTRAINT_SCHEMA = CC.CONSTRAINT_SCHEMA
+              AND TC.CONSTRAINT_NAME = CC.CONSTRAINT_NAME
+          LEFT JOIN INFORMATION_SCHEMA.CHECK_CONSTRAINTS CHK
+               ON CHK.CONSTRAINT_CATALOG = CC.CONSTRAINT_CATALOG
+              AND CHK.CONSTRAINT_SCHEMA = CC.CONSTRAINT_SCHEMA
+              AND CHK.CONSTRAINT_NAME = CC.CONSTRAINT_NAME
+              AND 'CHECK' = TC.CONSTRAINT_TYPE
+          LEFT JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS REF
+               ON REF.CONSTRAINT_CATALOG = CC.CONSTRAINT_CATALOG
+              AND REF.CONSTRAINT_SCHEMA = CC.CONSTRAINT_SCHEMA
+              AND REF.CONSTRAINT_NAME = CC.CONSTRAINT_NAME
+              AND 'FOREIGN KEY' = TC.CONSTRAINT_TYPE
+          LEFT JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE RFD
+               ON RFD.CONSTRAINT_CATALOG = REF.UNIQUE_CONSTRAINT_CATALOG
+              AND RFD.CONSTRAINT_SCHEMA = REF.UNIQUE_CONSTRAINT_SCHEMA
+              AND RFD.CONSTRAINT_NAME = REF.UNIQUE_CONSTRAINT_NAME
+          WHERE CC.CONSTRAINT_CATALOG = CC.TABLE_CATALOG
+            AND CC.CONSTRAINT_SCHEMA = CC.TABLE_SCHEMA
+            AND CC.TABLE_CATALOG = %s
+            AND CC.TABLE_SCHEMA = %s
+            AND CC.TABLE_NAME = %s
+            AND CC.COLUMN_NAME = %s 
         """
         db_name = self._get_setting('name')
         schema_name = self._get_schema_name()
-        cons = self.execute(sql, [db_name, schema_name, table_name, name])
-        return [c[0] for c in cons]
+        table = self.execute(sql, [db_name, schema_name, table_name, name])
+        
+        if just_names:
+            return [r[0] for r in table]
+        
+        all = {}
+        for r in table:
+            cons_name, type = r[:2]
+            if type=='PRIMARY KEY' or type=='UNIQUE':
+                cons = all.setdefault(cons_name, (type,[]))
+                cons[1].append(r[7])
+            elif type=='CHECK':
+                cons = (type, r[2])
+            elif type=='FOREIGN KEY':
+                if cons_name in all:
+                    raise NotImplementedError("Multiple-column foreign keys are not supported")
+                else:
+                    cons = (type, r[3:6])
+            else:
+                raise NotImplementedError("Don't know how to handle constraints of type "+ type)
+            all[cons_name] = cons
+        return all
 
+    @invalidate_table_constraints        
+    def alter_column(self, table_name, name, field, explicit_name=True, ignore_constraints=False):
+        """
+        Alters the given column name so it will match the given field.
+        Note that conversion between the two by the database must be possible.
+        Will not automatically add _id by default; to have this behavour, pass
+        explicit_name=False.
+
+        @param table_name: The name of the table to add the column to
+        @param name: The name of the column to alter
+        @param field: The new field definition to use
+        """
+        self._fix_field_definition(field)
+
+        if not ignore_constraints:
+            qn = self.quote_name
+            sch = qn(self._get_schema_name())
+            tab = qn(table_name)
+            table = ".".join([sch, tab])
+            constraints = self._find_constraints_for_column(table_name, name, False)
+            for constraint in constraints.keys():
+                params = dict(table_name = table,
+                              constraint_name = qn(constraint))
+                sql = self.drop_constraint_string % params
+                self.execute(sql, [])
+                
+        ret_val = super(DatabaseOperations, self).alter_column(table_name, name, field, explicit_name, ignore_constraints=True)
+        
+        if not ignore_constraints:
+            for cname, (ctype,args) in constraints.items():
+                params = dict(table = table,
+                              constraint = qn(cname))
+                if ctype=='UNIQUE':
+                    #TODO: This preserves UNIQUE constraints, but does not yet create them when necessary
+                    if len(args)>1 or field.unique:
+                        params['columns'] = ", ".join(map(qn,args))
+                        sql = self.create_unique_sql % params
+                elif ctype=='PRIMARY KEY':
+                    params['columns'] = ", ".join(map(qn,args))
+                    sql = self.create_primary_key_string % params
+                elif ctype=='FOREIGN KEY':
+                    continue
+                    # Foreign keys taken care of below 
+                    #target = "%s.%s(%s)" % tuple(map(qn,args))
+                    #params.update(column = qn(name), target = target)
+                    #sql = self.create_foreign_key_sql % params
+                elif ctype=='CHECK':
+                    warn(ConstraintDropped("CHECK "+ args, table_name, name))
+                    continue
+                    #TODO: Some check constraints should be restored; but not before the generic
+                    #      backend restores them.
+                    #params['check'] = args
+                    #sql = self.create_check_constraint_sql % params
+                else:
+                    raise NotImplementedError("Don't know how to handle constraints of type "+ type)                    
+                self.execute(sql, [])
+            # Create foreign key if necessary
+            if field.rel and self.supports_foreign_keys:
+                self.execute(
+                    self.foreign_key_sql(
+                        table_name,
+                        field.column,
+                        field.rel.to._meta.db_table,
+                        field.rel.to._meta.get_field(field.rel.field_name).column
+                    )
+                )
+
+
+        return ret_val
+    
     def _alter_set_defaults(self, field, name, params, sqls): 
         "Subcommand of alter_column that sets default values (overrideable)"
         # First drop the current default if one exists
             
         # Next, set any default
         
-        if field.has_default(): # was: and not field.null
+        if field.has_default():
             default = field.get_default()
-            sqls.append(('ADD DEFAULT %%s for %s' % (self.quote_name(name),), [default]))
-        #else:
-        #    sqls.append(('ALTER COLUMN %s DROP DEFAULT' % (self.quote_name(name),), []))
+            literal = self._value_to_unquoted_literal(field, default)
+            sqls.append(('ADD DEFAULT %s for %s' % (self._quote_string(literal), self.quote_name(name),), []))
+
+    def _value_to_unquoted_literal(self, field, value):
+        # Start with the field's own translation
+        conn = self._get_connection()
+        value = field.get_db_prep_save(value, connection=conn)
+        # This is still a Python object -- nobody expects to need a literal.
+        if isinstance(value, basestring):
+            return smart_unicode(value)
+        elif isinstance(value, (date,time,datetime)):
+            return value.isoformat()
+        else:
+            #TODO: Anybody else needs special translations?
+            return str(value) 
+
+    def _quote_string(self, s):
+        return "'" + s.replace("'","''") + "'"
+    
 
     def drop_column_default_sql(self, table_name, name, q_name=None):
         "MSSQL specific drop default, which is a pain"
         return None
 
     def _fix_field_definition(self, field):
-        if isinstance(field, fields.BooleanField):
+        if isinstance(field, (fields.BooleanField, fields.NullBooleanField)):
             if field.default == True:
                 field.default = 1
             if field.default == False:
                 field.default = 0
 
+    # This is copied from South's generic add_column, with two modifications:
+    # 1) The sql-server-specific call to _fix_field_definition
+    # 2) Removing a default, when needed, by calling drop_default and not the more general alter_column
+    @invalidate_table_constraints
     def add_column(self, table_name, name, field, keep_default=True):
+        """
+        Adds the column 'name' to the table 'table_name'.
+        Uses the 'field' paramater, a django.db.models.fields.Field instance,
+        to generate the necessary sql
+
+        @param table_name: The name of the table to add the column to
+        @param name: The name of the column to add
+        @param field: The field to use
+        """
         self._fix_field_definition(field)
-        generic.DatabaseOperations.add_column(self, table_name, name, field, keep_default)
+        sql = self.column_sql(table_name, name, field)
+        if sql:
+            params = (
+                self.quote_name(table_name),
+                sql,
+            )
+            sql = self.add_column_string % params
+            self.execute(sql)
 
+            # Now, drop the default if we need to
+            if not keep_default and field.default is not None:
+                field.default = fields.NOT_PROVIDED
+                #self.alter_column(table_name, name, field, explicit_name=False, ignore_constraints=True)
+                self.drop_default(table_name, name, field)
+
+    @invalidate_table_constraints
+    def drop_default(self, table_name, name, field):
+        fragment = self.drop_column_default_sql(table_name, name)
+        if fragment:
+            table_name = self.quote_name(table_name)
+            sql = " ".join(["ALTER TABLE", table_name, fragment])
+            self.execute(sql)        
+
+
+    @invalidate_table_constraints
     def create_table(self, table_name, field_defs):
         # Tweak stuff as needed
         for _, f in field_defs:
         schema_name = self._get_schema_name()
         return self.execute(sql, [db_name, schema_name, table_name])
                 
+    @invalidate_table_constraints
     def delete_table(self, table_name, cascade=True):
         """
         Deletes the table 'table_name'.
             cascade = False
         super(DatabaseOperations, self).delete_table(table_name, cascade)
             
+    @copy_column_constraints
+    @delete_column_constraints
     def rename_column(self, table_name, old, new):
         """
         Renames the column of 'table_name' from 'old' to 'new'.
         params = (table_name, self.quote_name(old), self.quote_name(new))
         self.execute("EXEC sp_rename '%s.%s', %s, 'COLUMN'" % params)
 
+    @invalidate_table_constraints
     def rename_table(self, old_table_name, table_name):
         """
         Renames the table 'old_table_name' to 'table_name'.
             return super_result.split(" ")[0]
         return super_result
 
+    @invalidate_table_constraints
+    def delete_foreign_key(self, table_name, column):
+        super(DatabaseOperations, self).delete_foreign_key(table_name, column)
+        # A FK also implies a non-unique index
+        find_index_sql = """
+            SELECT i.name -- s.name, t.name,  c.name
+            FROM sys.tables t
+            INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
+            INNER JOIN sys.indexes i ON i.object_id = t.object_id
+            INNER JOIN sys.index_columns ic ON ic.object_id = t.object_id
+            INNER JOIN sys.columns c ON c.object_id = t.object_id 
+                                     AND ic.column_id = c.column_id
+            WHERE i.is_unique=0 AND i.is_primary_key=0 AND i.is_unique_constraint=0
+              AND s.name = %s
+              AND t.name = %s
+              AND c.name = %s
+            """
+        schema = self._get_schema_name()
+        indexes = self.execute(find_index_sql, [schema, table_name, column])
+        qn = self.quote_name
+        for index in (i[0] for i in indexes):
+            self.execute("DROP INDEX %s on %s.%s" % (qn(index), qn(schema), qn(table_name) ))
+            

File south/exceptions.py

 class SouthError(RuntimeError):
     pass
 
+class SouthWarning(RuntimeWarning):
+    pass
 
 class BrokenMigration(SouthError):
     def __init__(self, migration, exc_info):
 class ImpossibleORMUnfreeze(SouthError):
     """Raised if the ORM can't manage to unfreeze all the models in a linear fashion."""
     pass
+
+class ConstraintDropped(SouthWarning):
+    def __init__(self, constraint, table, column=None):
+        self.table = table
+        self.column = (".%s" % column) if column else ""
+        self.constraint = constraint
+    
+    def __str__(self):
+        return "Constraint %(constraint)s was dropped from %(table)s%(column)s -- was this intended?" % self.__dict__  

File south/tests/db.py

         db.create_unique("test_unique", ["spam"])
         db.commit_transaction()
         db.start_transaction()
+
+        # Special preparations for Sql Server
+        if db.backend_name == "pyodbc":
+            db.execute("SET IDENTITY_INSERT test_unique2 ON;")
         
         # Test it works
+        TRUE = (True,)
+        FALSE = (False,)
         db.execute("INSERT INTO test_unique2 (id) VALUES (1)")
         db.execute("INSERT INTO test_unique2 (id) VALUES (2)")
-        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (true, 0, 1)")
-        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (false, 1, 2)")
+        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 0, 1)", TRUE)
+        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 1, 2)", FALSE)
         try:
-            db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (true, 2, 1)")
+            db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 2, 1)", FALSE)
         except:
             db.rollback_transaction()
         else:
         db.start_transaction()
         
         # Test similarly
-        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (true, 0, 1)")
-        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (false, 1, 2)")
+        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 0, 1)", TRUE)
+        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 1, 2)", FALSE)
         try:
-            db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (true, 1, 1)")
+            db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 1, 1)", TRUE)
         except:
             db.rollback_transaction()
         else:
         db.create_unique("test_unique", ["spam", "eggs", "ham_id"])
         db.start_transaction()
         # Test similarly
-        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (true, 0, 1)")
-        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (false, 1, 1)")
+        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 0, 1)", TRUE)
+        db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 1, 1)", FALSE)
         try:
-            db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (true, 0, 1)")
+            db.execute("INSERT INTO test_unique (spam, eggs, ham_id) VALUES (%s, 0, 1)", TRUE)
         except:
             db.rollback_transaction()
         else:
     
     def test_capitalised_constraints(self):
         """
-        Under PostgreSQL at least, capitalised constrains must be quoted.
+        Under PostgreSQL at least, capitalised constraints must be quoted.
         """
         db.create_table("test_capconst", [
             ('SOMECOL', models.PositiveIntegerField(primary_key=True)),

File south/tests/logic.py

 import unittest
 
 import datetime
+import sys
 
 from south import exceptions
 from south.migration import migrate_app
     
     def test_alter_column_null(self):
         
-        def null_ok():
+        def null_ok(eat_exception=True):
             from django.db import connection, transaction
             # the DBAPI introspection module fails on postgres NULLs.
             cursor = connection.cursor()
             # SQLite has weird now()
             if db.backend_name == "sqlite3":
                 now_func = "DATETIME('NOW')"
+            # So does SQLServer... should we be using a backend attribute?
+            elif db.backend_name == "pyodbc":
+                now_func = "GETDATE()"
             else:
                 now_func = "NOW()"
             
             try:
+                if db.backend_name == "pyodbc":
+                    cursor.execute("SET IDENTITY_INSERT southtest_spam ON;")
                 cursor.execute("INSERT INTO southtest_spam (id, weight, expires, name) VALUES (100, 10.1, %s, NULL);" % now_func)
             except:
-                transaction.rollback()
-                return False
+                if eat_exception:
+                    transaction.rollback()
+                    return False
+                else:
+                    raise
             else:
                 cursor.execute("DELETE FROM southtest_spam")
                 transaction.commit()
         
         # after 0003, it should be NULL
         migrate_app(migrations, target_name="0003", fake=False)
-        self.assert_(null_ok())
+        self.assert_(null_ok(False))
         self.assertListEqual(
             ((u"fakeapp", u"0001_spam"),
              (u"fakeapp", u"0002_eggs"),