Mike Bayer avatar Mike Bayer committed ca305f7

- decoupled all ColumnElements from also being Selectables. this means
that anything which is a column expression does not have a "c" or a
"columns" attribute. Also works for select().as_scalar(); _ScalarSelect
is a columnelement, so you can't say select().as_scalar().c.foo, which is
a pretty confusing mistake to make. in the case of _ScalarSelect made
an explicit raise if you try to access 'c'.

Comments (0)

Files changed (3)

doc/build/testdocs.py

     return s
 
 for filename in ('ormtutorial', 'sqlexpression'):
-#for filename in ('sqlexpression',):
 	filename = 'content/%s.txt' % filename
 	s = open(filename).read()
 	#s = replace_file(s, ':memory:')

lib/sqlalchemy/sql.py

     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
-
-class ColumnElement(Selectable, _CompareMixin):
+class ColumnElement(ClauseElement, _CompareMixin):
     """Represent an element that is useable within the 
     "column clause" portion of a ``SELECT`` statement. 
     
         which each represent a foreign key placed on this column's ultimate
         ancestor.
         """)
-    columns = property(lambda self:[self],
-                       doc=\
-        """Columns accessor which returns ``self``, to provide compatibility 
-        with ``Selectable`` objects.
-        """)
 
     def _one_fkey(self):
         if self._foreign_keys:
         """return the list of ColumnElements represented within this FromClause's _exportable_columns"""
         export = self._exportable_columns()
         for column in export:
-            # TODO: is this conditional needed ?
             if isinstance(column, Selectable):
-                s = column
+                for co in column.columns:
+                    yield co
+            elif isinstance(column, ColumnElement):
+                yield column
             else:
                 continue
-            for co in s.columns:
-                yield co
         
     def _exportable_columns(self):
         return []
             self.append(c)
 
     key = property(lambda self:self.name)
-
+    columns = property(lambda self:[self])
+    
     def _copy_internals(self):
         _CalculatedClause._copy_internals(self)
         self._clone_from_clause()
     
     def __init__(self, *args, **kwargs):
         kwargs['correlate'] = True
-        s = select(*args, **kwargs).self_group()
+        s = select(*args, **kwargs).as_scalar().self_group()
         _UnaryExpression.__init__(self, s, operator=Operators.exists)
 
+    def select(self, whereclauses = None, **params):
+        return select([self], whereclauses, **params)
+
     def correlate(self, fromclause):
-      e = self._clone()
-      e.element = self.element.correlate(fromclause).self_group()
-      return e
+        e = self._clone()
+        e.element = self.element.correlate(fromclause).self_group()
+        return e
     
     def where(self, clause):
-      e = self._clone()
-      e.element = self.element.where(clause).self_group()
-      return e
+        e = self._clone()
+        e.element = self.element.where(clause).self_group()
+        return e
       
     def _hide_froms(self, **modifiers):
         return self._get_from_objects(**modifiers)
     primary_key = property(lambda s:s.__primary_key)
 
     def self_group(self, against=None):
-        return _Grouping(self)
+        return _FromGrouping(self)
         
     def _locate_oid_column(self):
         return self.left.oid_column
         
     key = property(lambda s: s.elem.key)
     _label = property(lambda s: s.elem._label)
-    columns = c = property(lambda s:s.elem.columns)
 
     def _copy_internals(self):
         self.elem = self.elem._clone()
         return getattr(self.elem, attr)
 
 class _Grouping(_ColumnElementAdapter):
+    """represent a grouping within a column expression"""
     pass
 
+class _FromGrouping(FromClause):
+    """represent a grouping of a FROM clause"""
+    __visit_name__ = 'grouping'
+
+    def __init__(self, elem):
+        self.elem = elem
+
+    columns = c = property(lambda s:s.elem.columns)
+
+    def get_children(self, **kwargs):
+        return self.elem,
+
+    def _hide_froms(self, **modifiers):
+        return self.elem._hide_froms(**modifiers)
+
+    def _copy_internals(self):
+        self.elem = self.elem._clone()
+
+    def _get_from_objects(self, **modifiers):
+        return self.elem._get_from_objects(**modifiers)
+
+    def __getattr__(self, attr):
+        return getattr(self.elem, attr)
+    
 class _Label(ColumnElement):
     """represent a label, as typically applied to any column-level element
     using the ``AS`` sql keyword.
         return self.obj._hide_froms(**modifiers)
         
     def _make_proxy(self, selectable, name = None):
-        if isinstance(self.obj, Selectable):
+        if isinstance(self.obj, (Selectable, ColumnElement)):
             return self.obj._make_proxy(selectable, name=self.name)
         else:
             return column(self.name)._make_proxy(selectable=selectable)
         super(_ScalarSelect, self).__init__(elem)
         self.type = list(elem.inner_columns)[0].type
 
-    columns = property(lambda self:[self])
+    def _no_cols(self):
+        raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
+    c = property(_no_cols)
+    columns = c
     
     def self_group(self, **kwargs):
         return self
     name = property(lambda s:s.keyword + " statement")
 
     def self_group(self, against=None):
-        return _Grouping(self)
+        return _FromGrouping(self)
 
     def _locate_oid_column(self):
         return self.selects[0].oid_column
             return froms
     
     froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
+
+    name = property(lambda self:"Select statement")
+
+    def expression_element(self):
+        return self.as_scalar()
     
     def locate_all_froms(self):
         froms = util.Set()
         self._froms.add(fromclause)
 
     def _exportable_columns(self):
-        return [c for c in self._raw_columns if isinstance(c, Selectable)]
+        return [c for c in self._raw_columns if isinstance(c, (Selectable, ColumnElement))]
         
     def _proxy_column(self, column):
         if self.use_labels:
     def self_group(self, against=None):
         if isinstance(against, CompoundSelect):
             return self
-        return _Grouping(self)
+        return _FromGrouping(self)
 
     def _locate_oid_column(self):
         for f in self.locate_all_froms():
                 return e
         # look through the columns (largely synomous with looking
         # through the FROMs except in the case of _CalculatedClause/_Function)
-        for cc in self._exportable_columns():
-            for c in cc.columns:
-                if getattr(c, 'table', None) is self:
-                    continue
-                e = c.bind
-                if e is not None:
-                    self._bind = e
-                    return e
+        for c in self._exportable_columns():
+            if getattr(c, 'table', None) is self:
+                continue
+            e = c.bind
+            if e is not None:
+                self._bind = e
+                return e
         return None
 
 class _UpdateBase(ClauseElement):

test/sql/select.py

                 self.assert_(c.get_params().get_original_dict() == checkparams, "params dont match" + repr(c.get_params()))
             
 class SelectTest(SQLTest):
+    
+    def test_attribute_sanity(self):
+        assert hasattr(table1, 'c')
+        assert hasattr(table1.select(), 'c')
+        assert not hasattr(table1.c.myid.self_group(), 'columns')
+        assert hasattr(table1.select().self_group(), 'columns')
+        assert not hasattr(table1.select().as_scalar().self_group(), 'columns')
+        assert not hasattr(table1.c.myid, 'columns')
+        assert not hasattr(table1.c.myid, 'c')
+        assert not hasattr(table1.select().c.myid, 'c')
+        assert not hasattr(table1.select().c.myid, 'columns')
+        assert not hasattr(table1.alias().c.myid, 'columns')
+        assert not hasattr(table1.alias().c.myid, 'c')
+        
     def testtableselect(self):
         self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
 
         self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""")
     
     def testexistsascolumnclause(self):
-        self.runtest(exists([table1.c.myid], table1.c.myid==5).select(), "SELECT EXISTS (SELECT mytable.myid AS myid FROM mytable WHERE mytable.myid = :mytable_myid)", params={'mytable_myid':5})
+        self.runtest(exists([table1.c.myid], table1.c.myid==5).select(), "SELECT EXISTS (SELECT mytable.myid FROM mytable WHERE mytable.myid = :mytable_myid)", params={'mytable_myid':5})
 
         self.runtest(select([table1, exists([1], from_obj=[table2])]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) FROM mytable", params={})
 
             select([users, s.c.street], from_obj=[s]),
             """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
 
-        # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet.
-        #self.runtest(
-        #    table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), ""
-        #)
+        self.runtest(
+            table1.select(table1.c.myid == select([table1.c.myid], table1.c.name=='jack')), 
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT mytable.myid FROM mytable WHERE mytable.name = :mytable_name)"
+        )
         
         self.runtest(
             table1.select(table1.c.myid == select([table2.c.otherid], table1.c.name == table2.c.othername)),
         )
         
         
-    def testcolumnsubquery(self):
+    def test_scalar_select(self):
         s = select([table1.c.myid], scalar=True, correlate=False)
         self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
 
 
         self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
 
+        # scalar selects should not have any attributes on their 'c' or 'columns' attribute
+        s = select([table1.c.myid]).as_scalar()
+        try:
+            s.c.foo
+        except exceptions.InvalidRequestError, err:
+            assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
 
+        try:
+            s.columns.foo
+        except exceptions.InvalidRequestError, err:
+            assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
+        
         zips = table('zips',
             column('zipcode'),
             column('latitude'),
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.