Mike Bayer avatar Mike Bayer committed a04c6e9

added indexes to schema/ansisql/engine
slightly different index syntax for mysql
fixed mysql Time type to convert from a timedelta to time
tweaks to date unit tests for mysql

Comments (0)

Files changed (8)

lib/sqlalchemy/ansisql.py

     def visit_fromclause(self, fromclause):
         self.froms[fromclause] = fromclause.from_name
 
+    def visit_index(self, index):
+        self.strings[index] = index.name
+        
     def visit_textclause(self, textclause):
         if textclause.parens and len(textclause.text):
             self.strings[textclause] = "(" + textclause.text + ")"
 
     def visit_function(self, func):
         self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
-    
+        
     def visit_compound_select(self, cs):
         text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
         for tup in cs.clauses:
 
     def visit_column(self, column):
         pass
+
+    def visit_index(self, index):
+        self.append('CREATE ')
+        if index.unique:
+            self.append('UNIQUE ')
+        self.append('INDEX %s ON %s (%s)' \
+                    % (index.name, index.table.name,
+                       string.join([c.name for c in index.columns], ', ')))
+        self.execute()
+        
     
 class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
+    def visit_index(self, index):
+        self.append("\nDROP INDEX " + index.name)
+        self.execute()
+        
     def visit_table(self, table):
         self.append("\nDROP TABLE " + table.fullname)
         self.execute()
 
 
 class ANSIDefaultRunner(sqlalchemy.engine.DefaultRunner):
-    pass
+    pass

lib/sqlalchemy/databases/mysql.py

 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sys, StringIO, string, types, re
+import sys, StringIO, string, types, re, datetime
 
 import sqlalchemy.sql as sql
 import sqlalchemy.engine as engine
 class MSTime(sqltypes.Time):
     def get_col_spec(self):
         return "TIME"
+    def convert_result_value(self, value, engine):
+        # convert from a timedelta value
+        if value is not None:
+            return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60))
+        else:
+            return None
+            
 class MSText(sqltypes.TEXT):
     def get_col_spec(self):
         return "TEXT"
     def schemagenerator(self, **params):
         return MySQLSchemaGenerator(self, **params)
 
+    def schemadropper(self, **params):
+        return MySQLSchemaDropper(self, **params)
+
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
             self._default_schema_name = text("select database()", self).scalar()
         else:
             return ""
 
+class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
+    def visit_index(self, index):
+        self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
+        self.execute()

lib/sqlalchemy/engine.py

         for the "rowcount" function on a statement handle.  """
         return True
         
-    def create(self, table, **params):
-        """creates a table within this engine's database connection given a schema.Table object."""
-        table.accept_visitor(self.schemagenerator(**params))
+    def create(self, entity, **params):
+        """creates a table or index within this engine's database connection given a schema.Table object."""
+        entity.accept_visitor(self.schemagenerator(**params))
 
-    def drop(self, table, **params):
-        """drops a table within this engine's database connection given a schema.Table object."""
-        table.accept_visitor(self.schemadropper(**params))
+    def drop(self, entity, **params):
+        """drops a table or index within this engine's database connection given a schema.Table object."""
+        entity.accept_visitor(self.schemadropper(**params))
 
     def compile(self, statement, parameters, **kwargs):
         """given a sql.ClauseElement statement plus optional bind parameters, creates a new
         database-specific behavior."""
         return sql.ColumnImpl(column)
 
+    def indeximpl(self, index):
+        """returns a new sql.IndexImpl object to correspond to the given Index
+        object. An IndexImpl provides SQL statement builder operations on an
+        Index metadata object, and a subclass of this object may be provided
+        by a SQLEngine subclass to provide database-specific behavior.
+        """
+        return sql.IndexImpl(index)
+    
     def get_default_schema_name(self):
         """returns the currently selected schema in the current connection."""
         return None

lib/sqlalchemy/schema.py

 from sqlalchemy.types import *
 import copy, re, string
 
-__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
-
+__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
+           'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
 
 class SchemaItem(object):
     """base class for items that define a database schema."""
         """calls the visit_seauence method on the given visitor."""
         return visitor.visit_sequence(self)
 
+class Index(SchemaItem):
+    """Represents an index of columns from a database table
+    """
+
+    def __init__(self, name, *columns, **kw):
+        """Constructs an index object. Arguments are:
+
+        name : the name of the index
+
+        *columns : columns to include in the index. All columns must belong to
+        the same table, and no column may appear more than once.
+
+        **kw : keyword arguments include:
+
+        unique=True : create a unique index
+        """
+        self.name = name
+        self.columns = columns
+        self.unique = kw.pop('unique', False)
+        self._init_items()
+
+    def _init_items(self):
+        # make sure all columns are from the same table
+        # FIXME: and no column is repeated
+        self.table = None
+        for column in self.columns:
+            if self.table is None:
+                self.table = column.table
+            elif column.table != self.table:
+                # all columns muse be from same table
+                raise ValueError("All index columns must be from same table. "
+                                 "%s is from %s not %s" % (column,
+                                                           column.table,
+                                                           self.table))
+        # set my _impl from col.table.engine
+        self._impl = self.table.engine.indeximpl(self)
+        
+    def accept_visitor(self, visitor):
+        visitor.visit_index(self)
+    def __str__(self):
+        return repr(self)
+    def __repr__(self):
+        return 'Index("%s", %s%s)' % (self.name,
+                                      ', '.join([repr(c)
+                                                 for c in self.columns]),
+                                      (self.unique and ', unique=True') or '')
+        
 class SchemaEngine(object):
     """a factory object used to create implementations for schema objects.  This object
     is the ultimate base class for the engine.SQLEngine class."""
     def columnimpl(self, column):
         """returns a new implementation object for a Column (usually sql.ColumnImpl)"""
         raise NotImplementedError()
+    def indeximpl(self, index):
+        """returns a new implementation object for an Index (usually
+        sql.IndexImpl)
+        """
+        raise NotImplementedError()
     def reflecttable(self, table):
         """given a table, will query the database and populate its Column and ForeignKey 
         objects."""

lib/sqlalchemy/sql.py

             self.whereclause.accept_visitor(visitor)
         visitor.visit_delete(self)
 
+class IndexImpl(ClauseElement):
+
+    def __init__(self, index):
+        self.index = index
+        self.name = index.name
+        self._engine = self.index.table.engine
+
+    table = property(lambda s: s.index.table)
+    columns = property(lambda s: s.index.columns)
         
+    def hash_key(self):
+        return self.index.hash_key()
+    def accept_visitor(self, visitor):
+        visitor.visit_index(self.index)
+    def compare(self, other):
+        return self.index is other
+    def create(self):
+        self._engine.create(self.index)
+    def drop(self):
+        self._engine.drop(self.index)
+    def execute(self):
+        self.create()
         # schema/tables
         'engines', 
         'testtypes',
-        
+	'indexes',
+	        
         # SQL syntax
         'select',
         'selectable',
+from sqlalchemy import *
+import sys
+import testbase
+
+class IndexTest(testbase.AssertMixin):
+    
+    def setUp(self):
+        self.created = []
+
+    def tearDown(self):
+        if self.created:
+            self.created.reverse()
+            for entity in self.created:
+                entity.drop()
+    
+    def test_index_create(self):
+        employees = Table('employees', testbase.db,
+                          Column('id', Integer, primary_key=True),
+                          Column('first_name', String(30)),
+                          Column('last_name', String(30)),
+                          Column('email_address', String(30)))
+        employees.create()
+        self.created.append(employees)
+        
+        i = Index('employee_name_index',
+                  employees.c.last_name, employees.c.first_name)
+        i.create()
+        self.created.append(i)
+        
+        i = Index('employee_email_index',
+                  employees.c.email_address, unique=True)        
+        i.create()
+        self.created.append(i)
+        
+if __name__ == "__main__":    
+    testbase.main()

test/testtypes.py

     def setUpAll(self):
         global users_with_date, insert_data
 
-        insert_data =  [[7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)],
+        insert_data =  [
+                        [7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)],
                         [8, 'roy', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.date(2005,10,10), datetime.time(0,0,0)],
                         [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 54839), datetime.date(1970,4,1), datetime.time(23,59,59,999)],
-                        [10, 'colber', None, None, None]]
+                        [10, 'colber', None, None, None]
+        ]
 
         fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time']
 
         collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime),
                    Column('user_date', Date), Column('user_time', Time)]
-
-
         
         if db.engine.__module__.endswith('mysql'):
             # strip microseconds -- not supported by this engine (should be an easier way to detect this)
             for d in insert_data:
-                d[2] = d[2].replace(microsecond=0)
-                d[4] = d[4].replace(microsecond=0)
+                if d[2] is not None:
+                    d[2] = d[2].replace(microsecond=0)
+                if d[4] is not None:
+                    d[4] = d[4].replace(microsecond=0)
         
         try:
             db.type_descriptor(types.TIME).get_col_spec()
-            print  "HI"
         except:
             # don't test TIME type -- not supported by this engine
             insert_data = [d[:-1] for d in insert_data]
             fnames = fnames[:-1]
             collist = collist[:-1]
 
-
         users_with_date = Table('query_users_with_date', db, redefine = True, *collist)
         users_with_date.create()
-
         insert_dicts = [dict(zip(fnames, d)) for d in insert_data]
         for idict in insert_dicts:
             users_with_date.insert().execute(**idict) # insert the data
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.