Jason Pellerin avatar Jason Pellerin committed 15c4306

Merge indexes [1047]:[1048] into trunk (for #6)

Comments (0)

Files changed (5)

lib/sqlalchemy/ansisql.py

             self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', '))
                     
         self.append("\n)%s\n\n" % self.post_create_table(table))
-        self.execute()
-
+        self.execute()        
+        if hasattr(table, 'indexes'):
+            for index in table.indexes:
+                self.visit_index(index)
+        
     def post_create_table(self, table):
         return ''
 
         self.execute()
         
     def visit_table(self, table):
+        # NOTE: indexes on the table will be automatically dropped, so
+        # no need to drop them individually
         self.append("\nDROP TABLE " + table.fullname)
         self.execute()
 

lib/sqlalchemy/databases/sqlite.py

             self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_key],', '))
 
         self.append("\n)\n\n")
-        self.execute()
+        self.execute()        
+        if hasattr(table, 'indexes'):
+            for index in table.indexes:
+                self.visit_index(index)
 
         

lib/sqlalchemy/schema.py

             self.primary_key.append(column)
         column.table = self
         column.type = self.engine.type_descriptor(column.type)
-            
+
+    def append_index(self, index):
+        self.indexes[index.name] = index
+        
     def _set_parent(self, schema):
         schema.tables[self.name] = self
         self.schema = schema
         for c in self.columns:
             c.accept_schema_visitor(visitor)
         return visitor.visit_table(self)
+
+    def append_index_column(self, column, index=None, unique=None):
+        """Add an index or a column to an existing index of the same name.
+        """
+        if index is not None and unique is not None:
+            raise ValueError("index and unique may not both be specified")
+        if index:
+            if index is True:
+                name = 'ix_%s' % column.name
+            else:
+                name = index
+        elif unique:
+            if unique is True:
+                name = 'ux_%s' % column.name
+            else:
+                name = unique
+        # find this index in self.indexes
+        # add this column to it if found
+        # otherwise create new
+        try:
+            index = self.indexes[name]
+            index.append_column(column)
+        except KeyError:
+            index = Index(name, column, unique=unique)
+        return index
+    
     def deregister(self):
         """removes this table from it's engines table registry.  this does not
         issue a SQL DROP statement."""
         which will be invoked upon insert if this column is not present in the insert list or is given a value
         of None.
         
-        hidden=False : indicates this column should not be listed in the table's list of columns.  Used for the "oid" 
-        column, which generally isnt in column lists.
-        """
+        hidden=False : indicates this column should not be listed in the
+        table's list of columns.  Used for the "oid" column, which generally
+        isnt in column lists.
+
+        index=None : True or index name. Indicates that this column is
+        indexed. Pass true to autogenerate the index name. Pass a string to
+        specify the index name. Multiple columns that specify the same index
+        name will all be included in the index, in the order of their
+        creation.
+
+        unique=None : True or undex name. Indicates that this column is
+        indexed in a unique index . Pass true to autogenerate the index
+        name. Pass a string to specify the index name. Multiple columns that
+        specify the same index name will all be included in the index, in the
+        order of their creation.  """
+        
         name = str(name) # in case of incoming unicode
         super(Column, self).__init__(name, None, type)
         self.args = args
         self.nullable = kwargs.pop('nullable', not self.primary_key)
         self.hidden = kwargs.pop('hidden', False)
         self.default = kwargs.pop('default', None)
+        self.index = kwargs.pop('index', None)
+        self.unique = kwargs.pop('unique', None)
+        if self.index is not None and self.unique is not None:
+            raise ArgumentError("Column may not define both index and unique")
         self._foreign_key = None
         self._orig = None
         self._parent = None
         if getattr(self, 'table', None) is not None:
             raise ArgumentError("this Column already has a table!")
         table.append_column(self)
+        if self.index or self.unique:
+            table.append_index_column(self, index=self.index,
+                                      unique=self.unique)
+        
         if self.default is not None:
             self.default = ColumnDefault(self.default)
             self._init_items(self.default)
 class Index(SchemaItem):
     """Represents an index of columns from a database table
     """
-
     def __init__(self, name, *columns, **kw):
         """Constructs an index object. Arguments are:
 
         unique=True : create a unique index
         """
         self.name = name
-        self.columns = columns
+        self.columns = []
+        self.table = None
         self.unique = kw.pop('unique', False)
-        self._init_items()
+        self._init_items(*columns)
 
     engine = property(lambda s:s.table.engine)
-    def _init_items(self):
+    def _init_items(self, *args):
+        for column in args:
+            self.append_column(column)
+            
+    def append_column(self, column):
         # make sure all columns are from the same table
-        # FIXME: and no column is repeated
-        self.table = None
-        for column in self.columns:
-            if self.table is None:
-                self.table = column.table
-            elif column.table != self.table:
-                # all columns muse be from same table
-                raise ArgumentError("All index columns must be from same table. "
-                                 "%s is from %s not %s" % (column,
-                                                           column.table,
-                                                           self.table))
+        # and no column is repeated
+        if self.table is None:
+            self.table = column.table
+            self.table.append_index(self)
+        elif column.table != self.table:
+            # all columns muse be from same table
+            raise ArgumentError("All index columns must be from same table. "
+                                "%s is from %s not %s" % (column,
+                                                          column.table,
+                                                          self.table))
+        elif column.name in [ c.name for c in self.columns ]:
+            raise ArgumentError("A column may not appear twice in the "
+                                "same index (%s already has column %s)"
+                                % (self.name, column))
+        self.columns.append(column)
+        
     def create(self):
        self.engine.create(self)
        return self
         """visit a ForeignKey."""
         pass
     def visit_index(self, index):
-        """visit an Index (not implemented yet)."""
+        """visit an Index."""
         pass
     def visit_passive_default(self, default):
         """visit a passive default"""

lib/sqlalchemy/sql.py

         super(TableClause, self).__init__(name)
         self.name = self.id = self.fullname = name
         self._columns = util.OrderedProperties()
+        self._indexes = util.OrderedProperties()
         self._foreign_keys = []
         self._primary_key = []
         for c in columns:
             self.append_column(c)
 
+    indexes = property(lambda s:s._indexes)
+    
     def append_column(self, c):
         self._columns[c.text] = c
         c.table = self
     
     def setUp(self):
         self.created = []
-
+        self.echo = testbase.db.echo
+        self.logger = testbase.db.logger
+        
     def tearDown(self):
+        testbase.db.echo = self.echo
+        testbase.db.logger = testbase.db.engine.logger = self.logger
         if self.created:
             self.created.reverse()
             for entity in self.created:
                   employees.c.last_name, employees.c.first_name)
         i.create()
         self.created.append(i)
+        assert employees.indexes['employee_name_index'] is i
         
-        i = Index('employee_email_index',
-                  employees.c.email_address, unique=True)        
+        i2 = Index('employee_email_index',
+                   employees.c.email_address, unique=True)        
+        i2.create()
+        self.created.append(i2)
+        assert employees.indexes['employee_email_index'] is i2
+
+    def test_index_create_camelcase(self):
+        """test that mixed-case index identifiers are legal"""
+        employees = Table('companyEmployees', testbase.db,
+                          Column('id', Integer, primary_key=True),
+                          Column('firstName', String),
+                          Column('lastName', String),
+                          Column('emailAddress', String))        
+        employees.create()
+        self.created.append(employees)
+        
+        i = Index('employeeNameIndex',
+                  employees.c.lastName, employees.c.firstName)
         i.create()
         self.created.append(i)
         
+        i = Index('employeeEmailIndex',
+                  employees.c.emailAddress, unique=True)        
+        i.create()
+        self.created.append(i)
+
+        # Check that the table is useable. This is mostly for pg,
+        # which can be somewhat sticky with mixed-case identifiers
+        employees.insert().execute(firstName='Joe', lastName='Smith')
+        ss = employees.select().execute().fetchall()
+        assert ss[0].firstName == 'Joe'
+        assert ss[0].lastName == 'Smith'
+
+    def test_index_create_inline(self):
+        """Test indexes defined with tables"""
+
+        testbase.db.echo = True
+        capt = []
+        class dummy:
+            pass
+        stream = dummy()
+        stream.write = capt.append
+        testbase.db.logger = testbase.db.engine.logger = stream
+        
+        events = Table('events', testbase.db,
+                       Column('id', Integer, primary_key=True),
+                       Column('name', String(30), unique=True),
+                       Column('location', String(30), index=True),
+                       Column('sport', String(30),
+                              unique='sport_announcer'),
+                       Column('announcer', String(30),
+                              unique='sport_announcer'),
+                       Column('winner', String(30), index='idx_winners'))
+        
+        index_names = [ ix.name for ix in events.indexes ]
+        assert 'ux_name' in index_names
+        assert 'ix_location' in index_names
+        assert 'sport_announcer' in index_names
+        assert 'idx_winners' in index_names
+        assert len(index_names) == 4
+
+        events.create()
+        self.created.append(events)
+
+        # verify that the table is functional
+        events.insert().execute(id=1, name='hockey finals', location='rink',
+                                sport='hockey', announcer='some canadian',
+                                winner='sweden')
+        ss = events.select().execute().fetchall()
+        
+        assert capt[0].strip().startswith('CREATE TABLE events')
+        assert capt[2].strip() == \
+            'CREATE UNIQUE INDEX ux_name ON events (name)'
+        assert capt[4].strip() == \
+            'CREATE INDEX ix_location ON events (location)'
+        assert capt[6].strip() == \
+            'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'
+        assert capt[8].strip() == \
+            'CREATE INDEX idx_winners ON events (winner)'
+            
 if __name__ == "__main__":    
     testbase.main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.