Mike Bayer avatar Mike Bayer committed ddaea56

- oracle reflection of case-sensitive names all fixed up
- other unit tests corrected for oracle

Comments (0)

Files changed (6)

lib/sqlalchemy/databases/oracle.py

         return OracleDefaultRunner(connection, **kwargs)
 
     def has_table(self, connection, table_name, schema=None):
-        cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
+        cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':self._denormalize_name(table_name)})
         return bool( cursor.fetchone() is not None )
 
     def has_sequence(self, connection, sequence_name):
-        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name""", {'name':sequence_name.upper()})
+        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name""", {'name':self._denormalize_name(sequence_name)})
         return bool( cursor.fetchone() is not None )
 
     def _locate_owner_row(self, owner, name, rows, raiseerr=False):
                     dblink = ''
                 return name, owner, dblink
             raise
-        
+
+    def _normalize_name(self, name):
+        if name is None:
+            return None
+        elif name.upper() == name and not self.identifier_preparer._requires_quotes(name.lower(), True):
+            return name.lower()
+        else:
+            return name
+    
+    def _denormalize_name(self, name):
+        if name is None:
+            return None
+        elif name.lower() == name and not self.identifier_preparer._requires_quotes(name.lower(), True):
+            return name.upper()
+        else:
+            return name
+    
     def table_names(self, connection, schema):
-        # sorry, I have no idea what that dblink stuff is about :)
+        # note that table_names() isnt loading DBLINKed or synonym'ed tables
         s = "select table_name from all_tables where tablespace_name NOT IN ('SYSTEM', 'SYSAUX')"
-        return [row[0] for row in connection.execute(s)]
+        return [self._normalize_name(row[0]) for row in connection.execute(s)]
 
     def reflecttable(self, connection, table, include_columns):
         preparer = self.identifier_preparer
-        if not preparer.should_quote(table):
-            name = table.name.upper()
-        else:
-            name = table.name
 
         # search for table, including across synonyms and dblinks.
         # locate the actual name of the table, the real owner, and any dblink clause needed.
-        actual_name, owner, dblink = self._resolve_table_owner(connection, name, table)
+        actual_name, owner, dblink = self._resolve_table_owner(connection, self._denormalize_name(table.name), table)
+
+        print "ACTUALNAME:", actual_name
 
         c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner})
 
+                
         while True:
             row = c.fetchone()
             if row is None:
             found_table = True
 
             #print "ROW:" , row
-            (colname, coltype, length, precision, scale, nullable, default) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
-
-            # if name comes back as all upper, assume its case folded
-            if (colname.upper() == colname):
-                colname = colname.lower()
+            (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
 
             if include_columns and colname not in include_columns:
                 continue
         c = connection.execute("""SELECT
              ac.constraint_name,
              ac.constraint_type,
-             LOWER(loc.column_name) AS local_column,
-             LOWER(rem.table_name) AS remote_table,
-             LOWER(rem.column_name) AS remote_column,
-             LOWER(rem.owner) AS remote_owner
+             loc.column_name AS local_column,
+             rem.table_name AS remote_table,
+             rem.column_name AS remote_column,
+             rem.owner AS remote_owner
            FROM all_constraints%(dblink)s ac,
              all_cons_columns%(dblink)s loc,
              all_cons_columns%(dblink)s rem
             if row is None:
                 break
             #print "ROW:" , row
-            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
             if cons_type == 'P':
                 table.primary_key.add(table.c[local_column])
             elif cons_type == 'R':
 class OracleSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_sequence(self, sequence):
         if not self.checkfirst or self.dialect.has_sequence(self.connection, sequence.name):
-            self.append("DROP SEQUENCE %s" % sequence.name)
+            self.append("DROP SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
 
 class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
         return self.connection.execute(c).scalar()
 
     def visit_sequence(self, seq):
-        return self.connection.execute("SELECT " + seq.name + ".nextval FROM DUAL").scalar()
+        return self.connection.execute("SELECT " + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL").scalar()
 
 dialect = OracleDialect

test/engine/reflection.py

             # the colon thing isnt working out for PG reflection just yet
             #defval3 = '1999-09-09 00:00:00'
             deftype3 = Date
-            defval3 = '1999-09-09'
+            if testbase.db.engine.name == 'oracle':
+                defval3 = text("to_date('09-09-1999', 'MM-DD-YYYY')")
+            else:
+                defval3 = '1999-09-09'
         else:
             deftype2, deftype3 = Integer, Integer
             defval2, defval3 = "15", "16"

test/sql/defaults.py

     def testwithautoincrement(self):
         meta = MetaData(testbase.db)
         table = Table("aitest", meta, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, Sequence('ai_id_seq', optional=True), primary_key=True),
             Column('data', String(20)))
         table.create(checkfirst=True)
         try:
         
         meta = MetaData(testbase.db)
         table = Table("aitest", meta, 
-            Column('id', Integer, primary_key=True),
+            Column('id', Integer, Sequence('ai_id_seq', optional=True), primary_key=True),
             Column('data', String(20)))
         table.create(checkfirst=True)
 
             # simulate working on a table that doesn't already exist
             meta2 = MetaData(testbase.db)
             table2 = Table("aitest", meta2,
-                Column('id', Integer, primary_key=True),
+                Column('id', Integer, Sequence('ai_id_seq', optional=True), primary_key=True),
                 Column('data', String(20)))
             class AiTest(object):
                 pass
         sometable = Table( 'Manager', metadata,
                Column( 'obj_id', Integer, Sequence('obj_id_seq'), ),
                Column( 'name', String, ),
-               Column( 'id', Integer, primary_key= True, ),
+               Column( 'id', Integer, Sequence('Manager_id_seq', optional=True), primary_key=True),
            )
         
         metadata.create_all()

test/sql/quote.py

         x = select([sql.literal_column("'FooCol'").label("SomeLabel")], from_obj=[table])
         x = x.select()
         assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")'''
-        
+   
+    # oracle doesn't support non-case-sensitive until ticket #726 is fixed 
+    @testing.unsupported('oracle')    
     def testlabelsnocase(self):
         metadata = MetaData()
         table1 = Table('SomeCase1', metadata,

test/sql/rowcount.py

         global employees_table
 
         employees_table = Table('employees', metadata,
-            Column('employee_id', Integer, primary_key=True),
+            Column('employee_id', Integer, Sequence('employee_id_seq', optional=True), primary_key=True),
             Column('name', String(50)),
             Column('department', String(1)),
         )

test/sql/unicode.py

 
 
 class UnicodeSchemaTest(PersistTest):
+    @testing.unsupported('oracle')
     def setUpAll(self):
         global unicode_bind, metadata, t1, t2
 
             )
         metadata.create_all()
 
+    @testing.unsupported('oracle')
     def tearDown(self):
         if metadata.tables:
             t2.delete().execute()
             t1.delete().execute()
         
+    @testing.unsupported('oracle')
     def tearDownAll(self):
         global unicode_bind
         metadata.drop_all()
         del unicode_bind
         
+    @testing.unsupported('oracle')
     def test_insert(self):
         t1.insert().execute({u'méil':1, u'\u6e2c\u8a66':5})
         t2.insert().execute({'a':1, 'b':1})
         assert t1.select().execute().fetchall() == [(1, 5)]
         assert t2.select().execute().fetchall() == [(1, 1)]
     
+    @testing.unsupported('oracle')
     def test_reflect(self):
         t1.insert().execute({u'méil':2, u'\u6e2c\u8a66':7})
         t2.insert().execute({'a':2, 'b':2})
         meta.drop_all()
         metadata.create_all()
         
+    @testing.unsupported('oracle')
     def test_mapping(self):
         # TODO: this test should be moved to the ORM tests, tests should be
         # added to this module testing SQL syntax and joins, etc.
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.