Commits

Andrew Godwin  committed f7a9743 Merge

Branch merge

  • Participants
  • Parent commits f1e34ab, db135ed

Comments (0)

Files changed (9)

File south/creator/freezer.py

 import sys
 
 from django.db import models
+from django.db.models.base import ModelBase, Model
 from django.contrib.contenttypes.generic import GenericRelation
 
 from south.orm import FakeORM
-from south.utils import auto_model
+from south.utils import get_attribute, auto_through
 from south import modelsinspector
 
 def freeze_apps(apps):
     checked_models = checked_models or set()
     # Get deps for each field
     for field in model._meta.fields + model._meta.many_to_many:
-        depends.update(field_dependencies(field))
+        depends.update(field_dependencies(field, checked_models))
     # Add in any non-abstract bases
     for base in model.__bases__:
         if issubclass(base, models.Model) and hasattr(base, '_meta') and not base._meta.abstract:
 def field_dependencies(field, checked_models=None):
     checked_models = checked_models or set()
     depends = set()
-    if isinstance(field, (models.OneToOneField, models.ForeignKey, models.ManyToManyField, GenericRelation)):
-        if field.rel.to in checked_models:
-            return depends
-        checked_models.add(field.rel.to)
-        depends.add(field.rel.to)
-        depends.update(field_dependencies(field.rel.to._meta.pk, checked_models))
-        # Also include M2M throughs
-        if isinstance(field, models.ManyToManyField):
-            if field.rel.through:
-                if hasattr(field.rel, "through_model"): # 1.1 and below
-                    depends.add(field.rel.through_model)
-                else:
-                    # Make sure it's not an automatic one
-                    if not auto_model(field.rel.through):
-                        depends.add(field.rel.through) # 1.2 and up
+    arg_defs, kwarg_defs = modelsinspector.matching_details(field)
+    for attrname, options in arg_defs + kwarg_defs.values():
+        if options.get("ignore_if_auto_through", False) and auto_through(field):
+            continue
+        if options.get("is_value", False):
+            value = attrname
+        elif attrname == 'rel.through' and hasattr(getattr(field, 'rel', None), 'through_model'):
+            # Hack for django 1.1 and below, where the through model is stored
+            # in rel.through_model while rel.through stores only the model name.
+            value = field.rel.through_model
+        else:
+            try:
+                value = get_attribute(field, attrname)
+            except AttributeError:
+                if options.get("ignore_missing", False):
+                    continue
+                raise
+        if isinstance(value, Model):
+            value = value.__class__
+        if not isinstance(value, ModelBase):
+            continue
+        if getattr(value._meta, "proxy", False):
+            value = value._meta.proxy_for_model
+        if value in checked_models:
+            continue
+        checked_models.add(value)
+        depends.add(value)
+        depends.update(model_dependencies(value, checked_models))
+
     return depends
 
 ### Prettyprinters

File south/db/generic.py

     delete_column_string = 'ALTER TABLE %s DROP COLUMN %s CASCADE;'
     create_primary_key_string = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s PRIMARY KEY (%(columns)s)"
     delete_primary_key_sql = "ALTER TABLE %(table)s DROP CONSTRAINT %(constraint)s"
+    add_check_constraint_fragment = "ADD CONSTRAINT %(constraint)s CHECK (%(check)s)"
     backend_name = None
     default_schema_name = "public"
 
         except TypeError:
             return field.db_type()
         
+    def _alter_add_column_mods(self, field, name, params, sqls):
+        """
+        Subcommand of alter_column that modifies column definitions beyond
+        the type string -- e.g. adding constraints where they cannot be specified
+        as part of the type (overrideable)
+        """
+        pass
+
     def _alter_set_defaults(self, field, name, params, sqls): 
         "Subcommand of alter_column that sets default values (overrideable)"
         # Next, set any default
         if params["type"] is not None:
             sqls.append((self.alter_string_set_type % params, []))
         
+        # Add any field- and backend- specific modifications
+        self._alter_add_column_mods(field, name, params, sqls)
         # Next, nullity
         if field.null:
             sqls.append((self.alter_string_set_null % params, []))
         MockModel._meta.model = MockModel
         return MockModel
 
+    def _db_positive_type_for_alter_column(self, field):
+        """
+        A helper for subclasses overriding _db_type_for_alter_column:
+        Remove the check constraint from the type string for PositiveInteger
+        and PositiveSmallInteger fields.
+        @param field: The field to generate type for
+        """
+        super_result = super(type(self), self)._db_type_for_alter_column(field)
+        if isinstance(field, (models.PositiveSmallIntegerField, models.PositiveIntegerField)):
+            return super_result.split(" ", 1)[0]
+        return super_result
+        
+    def _alter_add_positive_check(self, field, name, params, sqls):
+        """
+        A helper for subclasses overriding _alter_add_column_mods:
+        Add a check constraint verifying positivity to PositiveInteger and
+        PositiveSmallInteger fields.
+        """
+        super(type(self), self)._alter_add_column_mods(field, name, params, sqls)
+        if isinstance(field, (models.PositiveSmallIntegerField, models.PositiveIntegerField)):
+            uniq_hash = abs(hash(tuple(params.values()))) 
+            d = dict(
+                     constraint = "CK_%s_PSTV_%s" % (name, hex(uniq_hash)[2:]),
+                     check = "%s > 0" % self.quote_name(name))
+            sqls.append((self.add_check_constraint_fragment % d, []))
+    
+
 
 # Single-level flattening of lists
 def flatten(ls):

File south/db/oracle.py

 import os.path
 import sys
 import re
+import warnings
 import cx_Oracle
 
 
 from django.db import connection, models
 from django.db.backends.util import truncate_name
 from django.core.management.color import no_style
-from django.db.backends.oracle.base import get_sequence_name
 from django.db.models.fields import NOT_PROVIDED
 from django.db.utils import DatabaseError
+
+# In revision r16016 function get_sequence_name has been transformed into
+# method of DatabaseOperations class. To make code backward-compatible we
+# need to handle both situations.
+try:
+    from django.db.backends.oracle.base import get_sequence_name\
+        as original_get_sequence_name
+except ImportError:
+    original_get_sequence_name = None
+
 from south.db import generic
 
-print >> sys.stderr, " ! WARNING: South's Oracle support is still alpha."
-print >> sys.stderr, " !          Be wary of possible bugs."
+warnings.warn("! WARNING: South's Oracle support is still alpha. "
+              "Be wary of possible bugs.")
 
 class DatabaseOperations(generic.DatabaseOperations):    
     """
         'R': 'FOREIGN KEY'
     }
 
+    def get_sequence_name(self, table_name):
+        if original_get_sequence_name is None:
+            return self._get_connection().ops._get_sequence_name(table_name)
+        else:
+            return original_get_sequence_name(table_name)
+
     def adj_column_sql(self, col):
         col = re.sub('(?P<constr>CHECK \(.*\))(?P<any>.*)(?P<default>DEFAULT \d+)', 
                      lambda mo: '%s %s%s'%(mo.group('default'), mo.group('constr'), mo.group('any')), col) #syntax fix for boolean/integer field only
         EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
     END IF;
 END;
-/""" % {'sq_name': get_sequence_name(table_name)}
+/""" % {'sq_name': self.get_sequence_name(table_name)}
         self.execute(sequence_sql)
 
     @generic.invalidate_table_constraints

File south/db/postgresql_psycopg2.py

         "Rename an index individually"
         generic.DatabaseOperations.rename_table(self, old_index_name, index_name)
 
-    def _db_type_for_alter_column(self, field):
-        """
-        Returns a field's type suitable for ALTER COLUMN.
-        Strips CHECKs from PositiveSmallIntegerField) and PositiveIntegerField
-        @param field: The field to generate type for
-        """
-        super_result = super(DatabaseOperations, self)._db_type_for_alter_column(field)
-        if isinstance(field, models.PositiveSmallIntegerField) or isinstance(field, models.PositiveIntegerField):
-            return super_result.split(" ")[0]
-        return super_result
+    _db_type_for_alter_column = generic.alias("_db_positive_type_for_alter_column")
+    _alter_add_column_mods = generic.alias("_alter_add_positive_check")

File south/db/sql_server/pyodbc.py

     drop_constraint_string = 'ALTER TABLE %(table_name)s DROP CONSTRAINT %(constraint_name)s'
     delete_column_string = 'ALTER TABLE %s DROP COLUMN %s'
 
-    create_check_constraint_sql = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s CHECK (%(check)s)"
+    #create_check_constraint_sql = "ALTER TABLE %(table)s " + \
+    #                              generic.DatabaseOperations.add_check_constraint_fragment 
     create_foreign_key_sql = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s " + \
                              "FOREIGN KEY (%(column)s) REFERENCES %(target)s"
     create_unique_sql = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s UNIQUE (%(columns)s)"
         ret_val = super(DatabaseOperations, self).alter_column(table_name, name, field, explicit_name, ignore_constraints=True)
         
         if not ignore_constraints:
+            unique_field_handled = False
             for cname, (ctype,args) in constraints.items():
                 params = dict(table = table,
                               constraint = qn(cname))
                 if ctype=='UNIQUE':
-                    #TODO: This preserves UNIQUE constraints, but does not yet create them when necessary
+                    if len(args)==1:
+                        unique_field_handled = True # 
                     if len(args)>1 or field.unique:
                         params['columns'] = ", ".join(map(qn,args))
                         sql = self.create_unique_sql % params
+                    else:
+                        continue
                 elif ctype=='PRIMARY KEY':
                     params['columns'] = ", ".join(map(qn,args))
                     sql = self.create_primary_key_string % params
                 else:
                     raise NotImplementedError("Don't know how to handle constraints of type "+ type)                    
                 self.execute(sql, [])
+            # Create unique constraint if necessary
+            if field.unique and not unique_field_handled:
+                self.create_unique(table_name, (name,))
             # Create foreign key if necessary
             if field.rel and self.supports_foreign_keys:
                 self.execute(
         params = (self.quote_name(old_table_name), self.quote_name(table_name))
         self.execute('EXEC sp_rename %s, %s' % params)
 
-    # Copied from South's psycopg2 backend
-    def _db_type_for_alter_column(self, field):
-        """
-        Returns a field's type suitable for ALTER COLUMN.
-        Strips CHECKs from PositiveSmallIntegerField) and PositiveIntegerField
-        @param field: The field to generate type for
-        """
-        super_result = super(DatabaseOperations, self)._db_type_for_alter_column(field)
-        if isinstance(field, models.PositiveSmallIntegerField) or isinstance(field, models.PositiveIntegerField):
-            return super_result.split(" ")[0]
-        return super_result
+    _db_type_for_alter_column = generic.alias("_db_positive_type_for_alter_column")
+    _alter_add_column_mods = generic.alias("_alter_add_positive_check")
 
     @invalidate_table_constraints
     def delete_foreign_key(self, table_name, column):

File south/introspection_plugins/geodjango.py

                     "srid": ["srid", {"default": 4326}],
                     "spatial_index": ["spatial_index", {"default": True}],
                     "dim": ["dim", {"default": 2}],
+                    "geography": ["geography", {"default": False}],
                 },
             ),
         ]

File south/tests/__init__.py

     from south.tests.autodetection import *
     from south.tests.logger import *
     from south.tests.inspector import *
+    from south.tests.freezer import *

File south/tests/fakeapp/models.py

 from django.db import models
 from django.contrib.auth.models import User as UserAlias
 
+from south.modelsinspector import add_introspection_rules
+
 def default_func():
     return "yays"
 
 # Special case.
 class Other2(models.Model):
     # Try loading a field without a newline after it (inspect hates this)
-    close_but_no_cigar = models.PositiveIntegerField(primary_key=True)
+    close_but_no_cigar = models.PositiveIntegerField(primary_key=True)
+
+class CustomField(models.IntegerField):
+    def __init__(self, an_other_model, **kwargs):
+        super(CustomField, self).__init__(**kwargs)
+        self.an_other_model = an_other_model
+
+add_introspection_rules([
+    (
+        [CustomField],
+        [],
+        {'an_other_model': ('an_other_model', {})},
+    ),
+], ['^south\.tests\.fakeapp\.models\.CustomField'])
+
+class BaseModel(models.Model):
+    pass
+
+class SubModel(BaseModel):
+    others = models.ManyToManyField(Other1)
+    custom = CustomField(Other2)
+
+class CircularA(models.Model):
+    c = models.ForeignKey('CircularC')
+
+class CircularB(models.Model):
+    a = models.ForeignKey(CircularA)
+
+class CircularC(models.Model):
+    b = models.ForeignKey(CircularB)
+
+class Recursive(models.Model):
+   self = models.ForeignKey('self')

File south/tests/freezer.py

+import unittest
+
+from south.creator.freezer import model_dependencies
+from south.tests.fakeapp import models
+
+class TestFreezer(unittest.TestCase):
+    def test_dependencies(self):
+        self.assertEqual(set(model_dependencies(models.SubModel)),
+                         set([models.BaseModel, models.Other1, models.Other2]))
+
+        self.assertEqual(set(model_dependencies(models.CircularA)),
+                         set([models.CircularA, models.CircularB, models.CircularC]))
+
+        self.assertEqual(set(model_dependencies(models.Recursive)),
+                         set([models.Recursive]))