Anonymous avatar Anonymous committed c408d14

changed create_table and add_column to use a list of django.db.models.fields.Field classes to represent fields now.
django now takes care of abstracting the SQL create syntax for the various DB's it supports, so far less code that south needs.

NOTE: startmigration is currently broken with this change.
Need to write automated tests for this stuff, too.

Comments (0)

Files changed (3)

 
 from django.db import connection, transaction, models
+from django.db.backends.util import truncate_name
 from django.dispatch import dispatcher
 
 class DatabaseOperations(object):
     Some of this code comes from Django Evolution.
     """
 
-    types = {
-        "varchar": "VARCHAR",
-        "text": "TEXT",
-        "integer": "INT",
-        "boolean": "BOOLEAN",
-        "serial": "SERIAL",
-        "datetime": "TIMESTAMP WITH TIME ZONE",
-        "float": "DOUBLE PRECISION",
-    }
-
     def __init__(self):
         self.debug = False
-
-
-    def get_type(self, name, param=None):
-        """
-        Generic type-converting method, to smooth things over.
-        """
-        if name in ["text", "string"]:
-            if param:
-                return "%s(%s)" % (self.types['varchar'], param)
-            else:
-                return self.types['text']
-        else:
-            return self.types[name]
+        self.deferred_sql = []
 
 
     def execute(self, sql, params=[]):
             return cursor.fetchall()
         except:
             return []
+            
+            
+    def add_deferred_sql(self, sql):
+        """
+        Add a SQL statement to the deferred list, that won't be executed until
+        this instance's execute_deferred_sql method is run.
+        """
+        self.deferred_sql.append(sql)
+        
+        
+    def execute_deferred_sql(self):
+        """
+        Executes all deferred SQL, resetting the deferred_sql list
+        """
+        for sql in self.deferred_sql:
+            self.execute(sql)
+            
+        self.deferred_sql = []
 
 
-    def get_column_value(self, column, name):
+    def create_table(self, table_name, fields):
         """
-        Gets a column's something value from either a list or dict.
-        Useful for when both are passed into create_table in the column list.
-        """
-        defaults = {
-            "type_param": 0,
-            "unique": False,
-            "null": True,
-            "related_to": None,
-            "default": None,
-            "primary": False,
-        }
-        if isinstance(column, (list, tuple)):
-            try:
-                return column[{
-                    "name": 0,
-                    "type": 1,
-                    "type_param": 2,
-                    "unique": 3,
-                    "null": 4,
-                    "related_to": 5,
-                    "default": 6,
-                    "primary": 7,
-                }[name]]
-            except IndexError:
-                return defaults[name]
-        else:
-            return column.get(name, defaults.get(name, None))
-
-
-    def create_table(self, table_name, columns):
-        """
-        Creates the table 'table_name'. 'columns' is a list of columns
-        in the same format used by add_column (but as a list - think of its
-        positional arguments).
+        Creates the table 'table_name'. 'fields' is a tuple of fields,
+        each repsented by a 2-part tuple of field name and a
+        django.db.models.fields.Field object
         """
         qn = connection.ops.quote_name
-        defaults = tuple(self.get_column_value(column, "default") for column in columns)
         columns = [
-            self.column_sql(
-                column_name = self.get_column_value(column, "name"),
-                type_name = self.get_column_value(column, "type"),
-                type_param = self.get_column_value(column, "type_param"),
-                unique = self.get_column_value(column, "unique"),
-                null = self.get_column_value(column, "null"),
-                related_to = self.get_column_value(column, "related_to"),
-                default = self.get_column_value(column, "default"),
-            )
-            for column in columns
+            self.column_sql(table_name, field_name, field)
+            for field_name, field in fields
         ]
-        sqlparams = tuple()
-        for s, p in columns:
-            sqlparams += p
-        params = (
-            qn(table_name),
-            ", ".join([s for s,p in columns]),
-        )
-        self.execute('CREATE TABLE %s (%s);' % params, sqlparams)
+        
+        self.execute('CREATE TABLE %s (%s);' % (table_name, ', '.join([col for col in columns])))
     
     add_table = create_table # Alias for consistency's sake
 
     drop_table = delete_table
 
 
-    def add_column(self, table_name, name, type, type_param=None, unique=False, null=True, related_to=None, default=None, primary=False):
+    def add_column(self, table_name, name, field):
         """
-        Adds the column 'column_name' to the table 'table_name'.
-        The column will have type 'type_name', which is one of the generic
-        types South offers, such as 'string' or 'integer'.
+        Adds the column 'name' to the table 'table_name'.
+        Uses the 'field' paramater, a django.db.models.fields.Field instance,
+        to generate the necessary sql
         
         @param table_name: The name of the table to add the column to
-        @param column_name: The name of the column to add
-        @param type_name: The (generic) name of this column's type
-        @param type_param: An optional parameter to the type - e.g., its length
-        @param unique: Whether this column has UNIQUE set. Defaults to False.
-        @param null: If this column will be allowed to contain NULL values. Defaults to True.
-        @param related_to: A tuple of (table_name, column_name) for the column this references if it is a ForeignKey.
-        @param default: The default value for this column.
-        @param primary: If this is the primary key column.
+        @param name: The name of the column to add
+        @param field: The field to use
         """
         qn = connection.ops.quote_name
-        sql, sqlparams = self.column_sql(name, type, type_param, unique, null, related_to, default, primary)
+        sql = self.column_sql(table_name, name, field)
         params = (
             qn(table_name),
             sql,
         )
         sql = 'ALTER TABLE %s ADD COLUMN %s;' % params
-        self.execute(sql, sqlparams)
+        self.execute(sql)
 
 
-    def column_sql(self, column_name, type_name, type_param=None, unique=False, null=True, related_to=None, default=None, primary=False):
+    def column_sql(self, table_name, field_name, field, tablespace=''):
         """
         Creates the SQL snippet for a column. Used by add_column and add_table.
         """
         qn = connection.ops.quote_name
-        no_default = (not default)
-        if type_name == "serial":
-            no_default = True
-            null = False
-        params = (
-            qn(column_name),
-            self.get_type(type_name, type_param),
-            (unique and "UNIQUE " or "") + (null and "NULL" or "NOT NULL"),
-            related_to and ("REFERENCES %s (%s) %s" % (
-                related_to[0],  # Table name
-                related_to[1],  # Column name
-                connection.ops.deferrable_sql(), # Django knows this
-            )) or "",
-            not no_default and "DEFAULT %s" or "",
+        field.set_attributes_from_name(field_name)
+        sql = field.db_type()
+        if not sql:
+            return None
+            
+        field_output = [field.column, sql]
+        field_output.append('%sNULL' % (not field.null and 'NOT ' or ''))
+        if field.primary_key:
+            field_output.append('PRIMARY KEY')
+        elif field.unique:
+            field_output.append('UNIQUE')
+        
+        tablespace = field.db_tablespace or tablespace
+        if tablespace and connection.features.supports_tablespaces and field.unique:
+            # We must specify the index tablespace inline, because we
+            # won't be generating a CREATE INDEX statement for this field.
+            field_output.append(connection.ops.tablespace_sql(tablespace, inline=True))
+            
+        sql = ' '.join(field_output)
+        sqlparams = ()
+        # if the field is "NOT NULL" and a default value is provided, create the column with it
+        # this allows the addition of a NOT NULL field to a table with existing rows
+        if not field.null and field.has_default():
+            sql += " DEFAULT %s"
+            sqlparams = (field.get_default())
+        
+        if field.rel:
+            self.add_deferred_sql(
+                self.foreign_key_sql(
+                    table_name,
+                    field.column,
+                    field.rel.to._meta.db_table,
+                    field.rel.to._meta.get_field(field.rel.field_name).column
+                )
+            )
+            
+        return sql % sqlparams
+        
+    def foreign_key_sql(self, from_table_name, from_column_name, to_table_name, to_column_name):
+        """
+        Generates a full SQL statement to add a foreign key constraint
+        """
+        constraint_name = '%s_refs_%s_%x' % (from_column_name, to_column_name, abs(hash((from_table_name, to_table_name))))
+        return 'ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' % (
+            from_table_name,
+            truncate_name(constraint_name, connection.ops.max_name_length()),
+            from_column_name,
+            to_table_name,
+            to_column_name,
+            connection.ops.deferrable_sql() # Django knows this
         )
-        sqlparams = not no_default and (default,) or tuple() 
-        return '%s %s %s %s %s' % params, sqlparams
+        
 
 
     def delete_column(self, table_name, name):
 add_column(
     table_name,
     name,
-    type,
-    type_param=None,
-    unique=False,
-    null=True,
-    related_to=None,
-    default=None,
-    primary=False,
+    field,
 )
 
-Adds the column with name 'name' to the table 'table_name'. The column will
-have type 'type', which is one of the generic types listed in the
-'Generic Types' section below.
-
-Optional parameters:
-
- type_param: The parameter to the type. For example, 255 if you have
-             type='string' will end up with something like VARCHAR(255) in most
-             databases.
-
- unique: If this is True, tells the database this column must have a unique
-         value for each row.
-
- null: If this is True, NULL values will be allowed in the column.
-
- related_to: Provide a tuple of (table_name, column_name) to specify that
-             this column is a foreign key to some other table.
-
- default: The default value for this column. Note that this only really takes
-          effect for existing rows; the Django ORM generally applies the default
-          on the field before it hits the database.
-
- primary: If True, this field will be specified as the table's primary key.
+Adds the column with name 'name' to the table 'table_name'. Uses the
+'field' instance to determine the type and other options for the column.
 
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-create_table(table_name, columns)
+create_table(table_name, fields)
 
-Creates the table 'table_name' with the given list of columns 'columns'.
+Creates the table 'table_name' with the given list of columns 'fields'.
 
-'columns' is a list of dicts, where the dict keys follow the same scheme as the
-arguments to add_column, without the table_name
- (and the same optionality; thus, you always need to at least have 'name' and
- 'type' as keys in your dict).
+'fields' is a list of 2-part tuples, where the first part is the field name,
+and the second part is a valid django.db.models.fields.Field instance.
 
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 
 
 
-Generic types
-=============
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+EXAMPLES:
 
-This is a list of the currently-supported generic types:
+Adding a new table:
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+from django.db import models
+from south.db import db
+from mysite.myapp.models import *
+from mysite.anotherapp.models import Image
 
- Type        Parameterised?    Example
- ----------------------------------------------------------------------
- string      Optional          string -> TEXT
-                               string, 255 -> VARCHAR(255)
- float       No                float -> DOUBLE PRECISION
- integer     No                integer -> INT
- boolean     No                boolean -> BOOLEAN
- datetime    No                datetime -> TIMESTAMP WITH TIME ZONE
- serial      No                serial -> SERIAL (auto-incrementing int)
+class Migration:
+    
+    def forwards(self):
+        db.create_table('myapp_foobar', (
+            ('name' , models.CharField(max_length=256)),
+            ('sort_name' , models.CharField(max_length=256, blank=True)),
+            ('description' , models.TextField(blank=True)),
+            ('image' , models.ForeignKey(Image, related_name="foobar_images")),
+            ('date_added' , models.DateTimeField(auto_now_add=True)),
+            ('date_last_updated' , models.DateTimeField(auto_now=True)),
+        ))
+    
+    def backwards(self):
+        db.delete_table('myapp_foobar')
+
+
+Adding a new column to the myapp_foobar table:
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+from django.db import models
+from south.db import db
+from mysite.myapp.models import *
+
+class Migration:
+    
+    def forwards(self):
+        db.add_column('myapp_foobar', 'bar', 
+            models.ForeignKey(AnotherModelInMyApp, related_name='awesome_foobars', null=True))
+    
+    def backwards(self):
+        db.delete_column('myapp_foobar', 'bar_id')
+
+
             db.start_transaction()
             try:
                 klass().forwards()
+                db.execute_deferred_sql()
             except:
                 db.rollback_transaction()
                 raise
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.