Commits

Mike Bayer committed db1b486

got MS-SQL support largely working, including reflection, basic types, fair amount of ORM stuff, etc.
'rowcount' label is reseved in MS-SQL and had to change in sql.py count() as well as orm.query

Comments (0)

Files changed (8)

 working around new setuptools PYTHONPATH-killing behavior
 - further fixes with attributes/dependencies/etc....
 - improved error handling for when DynamicMetaData is not connected
+- MS-SQL support largely working (tested with pymssql)
 
 0.2.4
 - try/except when the mapper sets init.__name__ on a mapped class,

lib/sqlalchemy/databases/mssql.py

         [["Provider=SQLOLEDB;Data Source=%s;User Id=%s;Password=%s;Initial Catalog=%s" % (
             keys["host"], keys["user"], keys["password"], keys["database"])], {}]
     do_commit = False
+    sane_rowcount = True
 except:
     try:
         import pymssql as dbmodule
     except:
         dbmodule = None
         make_connect_string = lambda keys: [[],{}]
+    sane_rowcount = False
     
 class MSNumeric(sqltypes.Numeric):
     def convert_result_value(self, value, dialect):
             for c in compiled.statement.table.c:
                 if hasattr(c,'sequence'):
                     self.HASIDENT = True
-                    if parameters.has_key(c.name):
+                    if isinstance(parameters, list):
+                        if parameters[0].has_key(c.name):
+                            self.IINSERT = True
+                    elif parameters.has_key(c.name):
                         self.IINSERT = True
                     break
             if self.IINSERT:
                 proxy("SET IDENTITY_INSERT %s ON" % compiled.statement.table.name)
-
+	super(MSSQLExecutionContext, self).pre_exec(engine, proxy, compiled, parameters, **kwargs)
+	
     def post_exec(self, engine, proxy, compiled, parameters, **kwargs):
         """ Turn off the INDENTITY_INSERT mode if it's been activated, and fetch recently inserted IDENTIFY values (works only for one column) """
         if getattr(compiled, "isinsert", False):
             elif self.HASIDENT:
                 cursor = proxy("SELECT @@IDENTITY AS lastrowid")
                 row = cursor.fetchone()
-                self.last_inserted_ids = [row[0]]
+                self._last_inserted_ids = [int(row[0])]
+                print "LAST ROW ID", self._last_inserted_ids
             self.HASIDENT = False
 
 class MSSQLDialect(ansisql.ANSIDialect):            
         return self.context.last_inserted_ids
 
     def supports_sane_rowcount(self):
-        return True
+        return sane_rowcount
 
     def compiler(self, statement, bindparams, **kwargs):
         return MSSQLCompiler(self, statement, bindparams, **kwargs)
     def dbapi(self):
         return self.module
 
+    def has_table(self, connection, tablename):
+        import sqlalchemy.databases.information_schema as ischema
+
+        current_schema = self.get_default_schema_name()
+        columns = ischema.columns
+        s = sql.select([columns],
+                   current_schema and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema) or columns.c.table_name==tablename,
+                   )
+        
+        c = connection.execute(s)
+        row  = c.fetchone()
+        return row is not None
+        
     def reflecttable(self, connection, table):
         import sqlalchemy.databases.information_schema as ischema
         
             current_schema = self.get_default_schema_name()
 
         columns = ischema.columns
-        s = select([columns],
+        s = sql.select([columns],
                    current_schema and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema) or columns.c.table_name==table.name,
                    order_by=[columns.c.ordinal_position])
         
             for a in (charlen, numericprec, numericscale):
                 if a is not None:
                     args.append(a)
-                    coltype = ischema_names[type]
+            coltype = ischema_names[type]
             coltype = coltype(*args)
             colargs= []
             if default is not None:
-                colargs.append(PassiveDefault(sql.text(default)))
+                colargs.append(schema.PassiveDefault(sql.text(default)))
                 
             table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
         
             col_name, type_name = row[3], row[5]
             if type_name.endswith("identity"):
                 ic = table.c[col_name]
+                ic.primary_key = True
                 # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
                 ic.sequence = schema.Sequence(ic.name + '_identity')
 
         # Add constraints
-        RR = ischema.ref_constraints(self)    #information_schema.referential_constraints
+        RR = ischema.ref_constraints    #information_schema.referential_constraints
         TC = ischema.constraints        #information_schema.table_constraints
         C  = ischema.column_constraints.alias('C') #information_schema.constraint_column_usage: the constrained column 
         R  = ischema.column_constraints.alias('R') #information_schema.constraint_column_usage: the referenced column
         fromjoin = TC.join(RR, RR.c.constraint_name == TC.c.constraint_name).join(C, C.c.constraint_name == RR.c.constraint_name)
         fromjoin = fromjoin.join(R, R.c.constraint_name == RR.c.unique_constraint_name)
 
-        s = select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name,
+        s = sql.select([TC.c.constraint_type, C.c.table_schema, C.c.table_name, C.c.column_name,
                     R.c.table_schema, R.c.table_name, R.c.column_name],
-                   and_(RR.c.constraint_schema == current_schema,  C.c.table_name == table.name),
-                   from_obj = [fromjoin]
+                   sql.and_(RR.c.constraint_schema == current_schema,  C.c.table_name == table.name),
+                   from_obj = [fromjoin], use_labels=True
                    )
+        colmap = [TC.c.constraint_type, C.c.column_name, R.c.table_schema, R.c.table_name, R.c.column_name]
                
         c = connection.execute(s)
 
             row = c.fetchone()
             if row is None:
                 break
+            print "CCROW", row.keys(), row
             (type, constrained_column, referred_schema, referred_table, referred_column) = (
                 row[colmap[0]],
+                row[colmap[1]],
+                row[colmap[2]],
                 row[colmap[3]],
-                row[colmap[4]],
-                row[colmap[5]],
-                row[colmap[6]]
+                row[colmap[4]]
                 )
 
             if type=='PRIMARY KEY':
                 table.c[constrained_column]._set_primary_key()
             elif type=='FOREIGN KEY':
-                remotetable = Table(referred_table, self, autoload = True, schema=referred_schema)
+                if current_schema == referred_schema:
+                    referred_schema = table.schema
+                remotetable = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema)
                 table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column]))
-        
 
 
 class MSSQLCompiler(ansisql.ANSICompiler):
         super(MSSQLCompiler, self).visit_column(column)
         if column.table is not None and self.tablealiases.has_key(column.table):
             self.strings[column] = \
-                self.strings[self.tablealiases[column.table].corresponding_column(column.original)]
+                self.strings[self.tablealiases[column.table].corresponding_column(column)]
 
         
 class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):

lib/sqlalchemy/engine/base.py

     class AmbiguousColumn(object):
         def __init__(self, key):
             self.key = key
+        def dialect_impl(self, dialect):
+            return self
         def convert_result_value(self, arg, engine):
-            raise InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
+            raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % (self.key))
     
     def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None):
         """ResultProxy objects are constructed via the execute() method on SQLEngine."""

lib/sqlalchemy/engine/default.py

                     self._last_inserted_ids = None
                 else:
                     self._last_inserted_ids = last_inserted_ids
+                print "LAST INSERTED PARAMS", param
                 self._last_inserted_params = param
         elif getattr(compiled, 'isupdate', False):
             if isinstance(parameters, list):

lib/sqlalchemy/orm/query.py

 #            raise "ok first thing", str(s2)
             if not kwargs.get('distinct', False) and order_by:
                 s2.order_by(*util.to_list(order_by))
-            s3 = s2.alias('rowcount')
+            s3 = s2.alias('tbl_row_count')
             crit = []
             for i in range(0, len(self.table.primary_key)):
                 crit.append(s3.primary_key[i] == self.table.primary_key[i])

lib/sqlalchemy/sql.py

             col = self.primary_key[0]
         else:
             col = list(self.columns)[0]
-        return select([func.count(col).label('rowcount')], whereclause, from_obj=[self], **params)
+        return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
     def join(self, right, *args, **kwargs):
         return Join(self, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):
             col = self.primary_key[0]
         else:
             col = list(self.columns)[0]
-        return select([func.count(col).label('rowcount')], whereclause, from_obj=[self], **params)
+        return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
     def join(self, right, *args, **kwargs):
         return Join(self, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):

test/orm/objectstore.py

         version_table.delete().execute()
         SessionTest.tearDown(self)
     
-    @testbase.unsupported('mysql')
+    @testbase.unsupported('mysql', 'mssql')
     def testbasic(self):
         s = create_session()
         class Foo(object):pass
         assert len(t1.t2s) == 2
         
 class PKTest(SessionTest):
+    @testbase.unsupported('mssql')
     def setUpAll(self):
         SessionTest.setUpAll(self)
         db.echo = False
         global table2
         global table3
         table = Table(
-            'multi', db, 
+            'multipk', db, 
             Column('multi_id', Integer, Sequence("multi_id_seq", optional=True), primary_key=True),
             Column('multi_rev', Integer, primary_key=True),
             Column('name', String(50), nullable=False),
             Column('value', String(100))
         )
         
-        table2 = Table('multi2', db,
+        table2 = Table('multipk2', db,
             Column('pk_col_1', String(30), primary_key=True),
             Column('pk_col_2', String(30), primary_key=True),
             Column('data', String(30), )
             )
-        table3 = Table('multi3', db,
+        table3 = Table('multipk3', db,
             Column('pri_code', String(30), key='primary', primary_key=True),
             Column('sec_code', String(30), key='secondary', primary_key=True),
             Column('date_assigned', Date, key='assigned', primary_key=True),
         table2.create()
         table3.create()
         db.echo = testbase.echo
+    @testbase.unsupported('mssql')
     def tearDownAll(self):
         db.echo = False
         table.drop()
         db.echo = testbase.echo
         SessionTest.tearDownAll(self)
         
-    @testbase.unsupported('sqlite')
+    @testbase.unsupported('sqlite', 'mssql')
     def testprimarykey(self):
         class Entry(object):
             pass
         ctx.current.clear()
         e2 = Entry.mapper.get((e.multi_id, 2))
         self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+    @testbase.unsupported('mssql')
     def testmanualpk(self):
         class Entry(object):
             pass
         e.pk_col_2 = 'pk1_related'
         e.data = 'im the data'
         ctx.current.flush()
+    @testbase.unsupported('mssql')
     def testkeypks(self):
         import datetime
         class Entity(object):
             db_uri = 'oracle://scott:tiger@127.0.0.1:1521'
             opts = {'use_ansi':False}
         elif DBTYPE == 'mssql':
-            db_uri = 'mssql://scott:tiger@/test'
+            db_uri = 'mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test'
 
     if not db_uri:
         raise "Could not create engine.  specify --db <sqlite|sqlite_file|postgres|mysql|oracle|oracle8|mssql> to test runner."