1. idank
  2. sqlalchemy

Commits

Mike Bayer  committed 7c541ef

streamlined engine.schemagenerator and engine.schemadropper methodology
added support for creating PassiveDefault (i.e. regular DEFAULT) on table columns
postgres can reflect default values via information_schema
added unittests for PassiveDefault values getting created, inserted, coming back in result sets

  • Participants
  • Parent commits 409ff01
  • Branches default

Comments (0)

Files changed (10)

File lib/sqlalchemy/ansisql.py

View file
 
 class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
 
-    def schemagenerator(self, proxy, **params):
-        return ANSISchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return ANSISchemaGenerator(self, **params)
     
-    def schemadropper(self, proxy, **params):
-        return ANSISchemaDropper(proxy, **params)
+    def schemadropper(self, **params):
+        return ANSISchemaDropper(self, **params)
 
     def compiler(self, statement, parameters, **kwargs):
         return ANSICompiler(self, statement, parameters, **kwargs)
 
 
 class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
-
     def get_column_specification(self, column, override_pk=False, first_pk=False):
         raise NotImplementedError()
         
     def post_create_table(self, table):
         return ''
 
+    def get_column_default_string(self, column):
+        if isinstance(column.default, schema.PassiveDefault):
+            if not isinstance(column.default.arg, str):
+                arg = str(column.default.arg.compile(self.engine))
+            else:
+                arg = column.default.arg
+            return arg
+        else:
+            return None
+
     def visit_column(self, column):
         pass
     

File lib/sqlalchemy/databases/information_schema.py

View file
     Column("character_maximum_length", Integer),
     Column("numeric_precision", Integer),
     Column("numeric_scale", Integer),
+    Column("column_default", Integer),
     schema="information_schema")
     
 gen_constraints = schema.Table("table_constraints", generic_engine,
         row = c.fetchone()
         if row is None:
             break
-#        print "row! " + repr(row)
+        #print "row! " + repr(row)
  #       continue
-        (name, type, nullable, charlen, numericprec, numericscale) = (
+        (name, type, nullable, charlen, numericprec, numericscale, default) = (
             row[columns.c.column_name], 
             row[columns.c.data_type], 
             row[columns.c.is_nullable] == 'YES', 
             row[columns.c.character_maximum_length],
             row[columns.c.numeric_precision],
             row[columns.c.numeric_scale],
+            row[columns.c.column_default]
             )
 
         args = []
         coltype = ischema_names[type]
         #print "coltype " + repr(coltype) + " args " +  repr(args)
         coltype = coltype(*args)
-        table.append_item(schema.Column(name, coltype, nullable = nullable))
+        colargs= []
+        if default is not None:
+            colargs.append(PassiveDefault(default))
+        table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
 
     s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True)
     if not use_mysql:

File lib/sqlalchemy/databases/mysql.py

View file
     def compiler(self, statement, bindparams, **kwargs):
         return MySQLCompiler(self, statement, bindparams, **kwargs)
 
-    def schemagenerator(self, proxy, **params):
-        return MySQLSchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return MySQLSchemaGenerator(self, **params)
 
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
         self.mysql_engine = mysql_engine
 
 class MySQLCompiler(ansisql.ANSICompiler):
+
+    def visit_function(self, func):
+        if len(func.clauses):
+            super(MySQLCompiler, self).visit_function(func)
+        else:
+            self.strings[func] = func.name
+
     def limit_clause(self, select):
         text = ""
         if select.limit is not None:
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
         colspec = column.name + " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
 
         if not column.nullable:
             colspec += " NOT NULL"

File lib/sqlalchemy/databases/oracle.py

View file
     def compiler(self, statement, bindparams, **kwargs):
         return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs)
 
-    def schemagenerator(self, proxy, **params):
-        return OracleSchemaGenerator(proxy, **params)
-    def schemadropper(self, proxy, **params):
-        return OracleSchemaDropper(proxy, **params)
+    def schemagenerator(self, **params):
+        return OracleSchemaGenerator(self, **params)
+    def schemadropper(self, **params):
+        return OracleSchemaDropper(self, **params)
     def defaultrunner(self, proxy):
         return OracleDefaultRunner(self, proxy)
         
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
         colspec += " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
 
         if not column.nullable:
             colspec += " NOT NULL"

File lib/sqlalchemy/databases/postgres.py

View file
     def compiler(self, statement, bindparams, **kwargs):
         return PGCompiler(self, statement, bindparams, **kwargs)
 
-    def schemagenerator(self, proxy, **params):
-        return PGSchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return PGSchemaGenerator(self, **params)
 
-    def schemadropper(self, proxy, **params):
-        return PGSchemaDropper(proxy, **params)
+    def schemadropper(self, **params):
+        return PGSchemaDropper(self, **params)
 
     def defaultrunner(self, proxy):
         return PGDefaultRunner(self, proxy)
 
 class PGCompiler(ansisql.ANSICompiler):
 
+    def visit_function(self, func):
+        if len(func.clauses):
+            super(PGCompiler, self).visit_function(func)
+        else:
+            self.strings[func] = func.name
+
     def visit_insert_column(self, column):
         # Postgres advises against OID usage and turns it off in 8.1,
         # effectively making cursor.lastrowid
         return text
         
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+        
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
-        if isinstance(column.default, schema.PassiveDefault):
-            colspec += " DEFAULT " + column.default.text
-        elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
             colspec += " " + column.type.get_col_spec()
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
 
         if not column.nullable:
             colspec += " NOT NULL"

File lib/sqlalchemy/databases/sqlite.py

View file
     def dbapi(self):
         return sqlite
 
-    def schemagenerator(self, proxy, **params):
-        return SQLiteSchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return SQLiteSchemaGenerator(self, **params)
 
     def reflecttable(self, table):
         c = self.execute("PRAGMA table_info(" + table.name + ")", {})
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name + " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
         if not column.nullable:
             colspec += " NOT NULL"
         if column.primary_key and not override_pk:

File lib/sqlalchemy/engine.py

View file
     
 class SchemaIterator(schema.SchemaVisitor):
     """a visitor that can gather text into a buffer and execute the contents of the buffer."""
-    def __init__(self, sqlproxy, **params):
+    def __init__(self, engine, **params):
         """initializes this SchemaIterator and initializes its buffer.
         
         sqlproxy - a callable function returned by SQLEngine.proxy(), which executes a
         statement plus optional parameters.
         """
-        self.sqlproxy = sqlproxy
+        self.engine = engine
         self.buffer = StringIO.StringIO()
 
     def append(self, s):
         """executes the contents of the SchemaIterator's buffer using its sql proxy and
         clears out the buffer."""
         try:
-            return self.sqlproxy(self.buffer.getvalue())
+            return self.engine.execute(self.buffer.getvalue(), None)
         finally:
             self.buffer.truncate(0)
 
         """returns a sql.text() object for performing literal queries."""
         return sql.text(text, engine=self, *args, **kwargs)
         
-    def schemagenerator(self, proxy, **params):
+    def schemagenerator(self, **params):
         """returns a schema.SchemaVisitor instance that can generate schemas, when it is
-        invoked to traverse a set of schema objects.  The 
-        "proxy" argument is a callable will execute a given string SQL statement
-        and a dictionary or list of parameters.  
+        invoked to traverse a set of schema objects. 
         
         schemagenerator is called via the create() method.
         """
         raise NotImplementedError()
 
-    def schemadropper(self, proxy, **params):
+    def schemadropper(self, **params):
         """returns a schema.SchemaVisitor instance that can drop schemas, when it is
-        invoked to traverse a set of schema objects.  The 
-        "proxy" argument is a callable will execute a given string SQL statement
-        and a dictionary or list of parameters.  
+        invoked to traverse a set of schema objects. 
         
         schemagenerator is called via the drop() method.
         """
         
     def create(self, table, **params):
         """creates a table within this engine's database connection given a schema.Table object."""
-        table.accept_visitor(self.schemagenerator(self.proxy(), **params))
+        table.accept_visitor(self.schemagenerator(**params))
 
     def drop(self, table, **params):
         """drops a table within this engine's database connection given a schema.Table object."""
-        table.accept_visitor(self.schemadropper(self.proxy(), **params))
+        table.accept_visitor(self.schemadropper(**params))
 
     def compile(self, statement, parameters, **kwargs):
         """given a sql.ClauseElement statement plus optional bind parameters, creates a new
         """implementations might want to put logic here for turning autocommit on/off, etc."""
         connection.commit()
 
-    def proxy(self, **kwargs):
-        """provides a callable that will execute the given string statement and parameters.
-        The statement and parameters should be in the format specific to the particular database;
-        i.e. named or positional."""
-        return lambda s, p = None: self.execute(s, p, **kwargs)
-
     def connection(self):
         """returns a managed DBAPI connection from this SQLEngine's connection pool."""
         return self._pool.connect()

File lib/sqlalchemy/schema.py

View file
 from sqlalchemy.types import *
 import copy, re, string
 
-__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor']
+__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
 
 
 class SchemaItem(object):
 
 class PassiveDefault(DefaultGenerator):
     """a default that takes effect on the database side"""
-    def __init__(self, text):
-        self.text = text
+    def __init__(self, arg):
+        self.arg = arg
     def accept_visitor(self, visitor):
-        return visitor_visit_passive_default(self)
+        return visitor.visit_passive_default(self)
     def __repr__(self):
-        return "PassiveDefault(%s)" % repr(self.text)
+        return "PassiveDefault(%s)" % repr(self.arg)
         
 class ColumnDefault(DefaultGenerator):
     """A plain default value on a column.  this could correspond to a constant, 

File test/engines.py

View file
 class EngineTest(PersistTest):
     def testbasic(self):
         # really trip it up with a circular reference
+        
+        use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle')
+        
+        if use_function_defaults:
+            defval = func.current_date()
+            deftype = Date
+        else:
+            defval = "3"
+            deftype = Integer
+            
         users = Table('engine_users', testbase.db,
             Column('user_id', INT, primary_key = True),
             Column('user_name', VARCHAR(20), nullable = False),
             Column('test6', DateTime, nullable = False),
             Column('test7', String),
             Column('test8', Binary),
+            Column('test_passivedefault', deftype, PassiveDefault(defval)),
             Column('test9', Binary(100)),
             mysql_engine='InnoDB'
         )

File test/query.py

View file
 import sqlalchemy.databases.sqlite as sqllite
 
 db = testbase.db
-
+db.echo='debug'
 from sqlalchemy import *
 from sqlalchemy.engine import ResultProxy, RowProxy
 
         def mydefault():
             x['x'] += 1
             return x['x']
-            
+
+        use_function_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle')
+        
         # select "count(1)" from the DB which returns different results
         # on different DBs
-        f = select([func.count(1)], engine=db).execute().fetchone()[0]
-        
+        f = select([func.count(1)], engine=db).scalar()
+        if use_function_defaults:
+            def1 = func.current_date()
+            def2 = "current_date"
+            deftype = Date
+            ts = select([func.current_date()], engine=db).scalar()
+        else:
+            def1 = def2 = "3"
+            ts = 3
+            deftype = Integer
+            
         t = Table('default_test1', db, 
             Column('col1', Integer, primary_key=True, default=mydefault),
             Column('col2', String(20), default="imthedefault"),
             Column('col3', Integer, default=func.count(1)),
+            Column('col4', deftype, PassiveDefault(def1)),
+            Column('col5', deftype, PassiveDefault(def2))
         )
         t.create()
         try:
             t.insert().execute()
         
             l = t.select().execute()
-            self.assert_(l.fetchall() == [(1, 'imthedefault', f), (2, 'imthedefault', f), (3, 'imthedefault', f)])
+            self.assert_(l.fetchall() == [(1, 'imthedefault', f, ts, ts), (2, 'imthedefault', f, ts, ts), (3, 'imthedefault', f, ts, ts)])
         finally:
             t.drop()