Commits

Mike Bayer committed 8d34951

- deprecated scalar=True argument on select(). its replaced
by select().scalar() which returns a _ScalarSelect object, that obeys
the ColumnElement interface fully
- removed _selectable() method. replaced with __selectable__() as an optional
duck-typer; subclassing Selectable (without any __selectable__()) is equivalent
- query._col_aggregate() was assuming bound metadata. ick !
- probably should deprecate ClauseElement.scalar(), in favor of ClauseElement.execute().scalar()...
otherwise might need to rename select().scalar()

Comments (0)

Files changed (7)

lib/sqlalchemy/ext/sqlsoup.py

     def update(cls, whereclause=None, values=None, **kwargs):
         _ddl_error(cls)
 
-    def _selectable(cls):
+    def __selectable__(cls):
         return cls._table
 
     def __getattr__(cls, attr):
         return x
 
 def class_for_table(selectable, **mapper_kwargs):
-    if not hasattr(selectable, '_selectable') \
-    or selectable._selectable() != selectable:
-        raise ArgumentError('class_for_table requires a selectable as its argument')
+    selectable = sql._selectable(selectable)
     mapname = 'Mapped' + _selectable_name(selectable)
     if isinstance(selectable, Table):
         klass = TableClassType(mapname, (object,), {})
 
     def with_labels(self, item):
         # TODO give meaningful aliases
-        return self.map(item._selectable().select(use_labels=True).alias('foo'))
+        return self.map(sql._selectable(item).select(use_labels=True).alias('foo'))
 
     def join(self, *args, **kwargs):
         j = join(*args, **kwargs)

lib/sqlalchemy/orm/query.py

 
         if self._order_by is not False:
             s1 = sql.select([col], self._criterion, **ops).alias('u')
-            return sql.select([func(s1.corresponding_column(col))]).scalar()
+            return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar()
         else:
-            return sql.select([func(col)], self._criterion, **ops).scalar()
+            return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar()
 
     def min(self, col):
         """Execute the SQL ``min()`` function against the given column."""

lib/sqlalchemy/sql.py

           will attempt to provide similar functionality.
         
         scalar=False
-          when ``True``, indicates that the resulting ``Select`` object
-          is to be used in the "columns" clause of another select statement,
-          where the evaluated value of the column is the scalar result of 
-          this statement.  Normally, placing any ``Selectable`` within the 
-          columns clause of a ``select()`` call will expand the member 
-          columns of the ``Selectable`` individually.
+          deprecated.  use select(...).scalar() to create a "scalar column"
+          proxy for an existing Select object.
 
         correlate=True
           indicates that this ``Select`` object should have its contained
           rendered in the ``FROM`` clause of this select statement.
       
     """
-
-    return Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+    scalar = kwargs.pop('scalar', False)
+    s = Select(columns, whereclause=whereclause, from_obj=from_obj, **kwargs)
+    if scalar:
+        return s.scalar()
+    else:
+        return s
 
 def subquery(alias, *args, **kwargs):
     """Return an [sqlalchemy.sql#Alias] object derived from a [sqlalchemy.sql#Select].
             return _BindParamClause(name, element, shortname=name, type_=type_, unique=True)
     else:
         return element
+
+def _selectable(element):
+    if hasattr(element, '__selectable__'):
+        return element.__selectable__()
+    elif isinstance(element, Selectable):
+        return element
+    else:
+        raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
         
 def is_column(col):
     return isinstance(col, ColumnElement)
             if _is_literal(o) or isinstance( o, _CompareMixin):
                 return self.__eq__( o)    #single item -> ==
             else:
-                assert hasattr( o, '_selectable')   #better check?
+                assert isinstance(o, Selectable)
                 return self.__compare( op, o, negate=negate_op)   #single selectable
 
         args = []
 
     columns = util.NotImplProperty("""a [sqlalchemy.sql#ColumnCollection] containing ``ColumnElement`` instances.""")
 
-    def _selectable(self):
-        return self
-
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
 
 
-        
 class ColumnElement(Selectable, _CompareMixin):
     """Represent an element that is useable within the 
     "column clause" portion of a ``SELECT`` statement. 
         """return the list of ColumnElements represented within this FromClause's _exportable_columns"""
         export = self._exportable_columns()
         for column in export:
-            if hasattr(column, '_selectable'):
-                s = column._selectable()
+            # TODO: is this conditional needed ?
+            if isinstance(column, Selectable):
+                s = column
             else:
                 continue
             for co in s.columns:
         return select([self])
 
     def scalar(self):
-        return select([self]).scalar()
+        return select([self]).execute().scalar()
 
     def execute(self):
         return select([self]).execute()
     
     """
     def __init__(self, left, right, onclause=None, isouter = False):
-        self.left = left._selectable()
-        self.right = right._selectable().self_group()
+        self.left = _selectable(left)
+        self.right = _selectable(right).self_group()
         if onclause is None:
             self.onclause = self._match_primaries(self.left, self.right)
         else:
 
     bind = property(lambda s: s.selectable.bind)
 
-class _Grouping(ColumnElement):
+class _ColumnElementAdapter(ColumnElement):
+    """adapts a ClauseElement which may or may not be a
+    ColumnElement subclass itself into an object which
+    acts like a ColumnElement.
+    """
+    
     def __init__(self, elem):
         self.elem = elem
         self.type = getattr(elem, 'type', None)
+        self.orig_set = getattr(elem, 'orig_set', util.Set())
         
-            
     key = property(lambda s: s.elem.key)
     _label = property(lambda s: s.elem._label)
-    orig_set = property(lambda s:s.elem.orig_set)
     columns = c = property(lambda s:s.elem.columns)
-    
+
     def _copy_internals(self):
-        print "GROPING COPY INTERNALS"
         self.elem = self.elem._clone()
-        print "NEW ID", id(self.elem)
-        
+
     def get_children(self, **kwargs):
         return self.elem,
-        
+
     def _hide_froms(self, **modifiers):
         return self.elem._hide_froms(**modifiers)
-        
+
     def _get_from_objects(self, **modifiers):
         return self.elem._get_from_objects(**modifiers)
 
     def __getattr__(self, attr):
         return getattr(self.elem, attr)
 
+class _Grouping(_ColumnElementAdapter):
+    pass
+
 class _Label(ColumnElement):
     """represent a label, as typically applied to any column-level element
     using the ``AS`` sql keyword.
     def _get_from_objects(self, **modifiers):
         return [self]
 
+    
 class _SelectBaseMixin(object):
     """Base class for ``Select`` and ``CompoundSelects``."""
 
-    def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, scalar=False):
+    def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None):
         self.use_labels = use_labels
         self.for_update = for_update
         self._limit = limit
         self._offset = offset
         self._bind = bind
-        self.is_scalar = scalar
-        if self.is_scalar:
-            # allow corresponding_column to return None
-            self.orig_set = util.Set()
         
         self.append_order_by(*util.to_list(order_by, []))
         self.append_group_by(*util.to_list(group_by, []))
+    
+    def scalar(self):
+        return _ScalarSelect(self)
+    
+    def label(self, name):
+        return self.scalar().label(name)
         
     def supports_execution(self):
         return True
         return select([self], whereclauses, **params)
 
     def _get_from_objects(self, is_where=False, **modifiers):
-        if is_where or self.is_scalar:
+        if is_where:
             return []
         else:
             return [self]
 
+class _ScalarSelect(_Grouping):
+    __visit_name__ = 'grouping'
+
+    def __init__(self, elem):
+        super(_ScalarSelect, self).__init__(elem)
+        self.type = list(elem.inner_columns)[0].type
+
+    columns = property(lambda self:[self])
+    
+    def self_group(self, **kwargs):
+        return self
+
+    def _make_proxy(self, selectable, name):
+        return list(self.inner_columns)[0]._make_proxy(selectable, name)
+
+    def _get_from_objects(self, **modifiers):
+        return []
+
 class CompoundSelect(_SelectBaseMixin, FromClause):
     def __init__(self, keyword, *selects, **kwargs):
         self._should_correlate = kwargs.pop('correlate', False)
 
     def _get_inner_columns(self):
         for c in self._raw_columns:
-            # TODO: need to have Select, as well as a Select inside a _Grouping,
-            # give us a clearer idea of if we want its column list or not
-            if hasattr(c, '_selectable') and not getattr(c, 'is_scalar', False):
-                for co in c._selectable().columns:
+            if isinstance(c, Selectable):
+                for co in c.columns:
                     yield co
             else:
                 yield c
         if _is_literal(column):
             column = literal_column(str(column))
 
-        if isinstance(column, Select) and column.is_scalar:
+        if isinstance(column, _ScalarSelect):
             column = column.self_group(against=ColumnOperators.comma_op)
 
         self._raw_columns.append(column)
             fromclause = FromClause(fromclause)
         self._froms.add(fromclause)
 
-    def _make_proxy(self, selectable, name):
-        if self.is_scalar:
-            return list(self.inner_columns)[0]._make_proxy(selectable, name)
-        else:
-            raise exceptions.InvalidRequestError("Not a scalar select statement")
-
-    def label(self, name):
-        if not self.is_scalar:
-            raise exceptions.InvalidRequestError("Not a scalar select statement")
-        else:
-            return label(name, self)
-
-    def _get_type(self):
-        if self.is_scalar:
-            return list(self.inner_columns)[0].type
-        else:
-            return None
-    type = property(_get_type)
-
     def _exportable_columns(self):
         return [c for c in self._raw_columns if isinstance(c, Selectable)]
         

test/orm/generative.py

         mapper(Foo, foo)
         metadata.create_all()
         
-        sess = create_session()
+        sess = create_session(bind=testbase.db)
         for i in range(100):
             sess.save(Foo(bar=i, range=i%10))
         sess.flush()
         clear_mappers()
     
     def test_selectby(self):
-        res = create_session().query(Foo).filter_by(range=5)
+        res = create_session(bind=testbase.db).query(Foo).filter_by(range=5)
         assert res.order_by([Foo.c.bar])[0].bar == 5
         assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
         
     @testing.unsupported('mssql')
     def test_slice(self):
-        sess = create_session()
+        sess = create_session(bind=testbase.db)
         query = sess.query(Foo)
         orig = query.all()
         assert query[1] == orig[1]
 
     @testing.supported('mssql')
     def test_slice_mssql(self):
-        sess = create_session()
+        sess = create_session(bind=testbase.db)
         query = sess.query(Foo)
         orig = query.all()
         assert list(query[:10]) == orig[:10]
         assert list(query[:10]) == orig[:10]
 
     def test_aggregate(self):
-        sess = create_session()
+        sess = create_session(bind=testbase.db)
         query = sess.query(Foo)
         assert query.count() == 100
         assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
     @testing.unsupported('mysql')
     def test_aggregate_1(self):
         # this one fails in mysql as the result comes back as a string
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
 
     @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_2(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
 
     @testing.supported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_2_int(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert int(query.filter(foo.c.bar<30).avg(foo.c.bar)) == 14
 
     @testing.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_3(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5
         assert query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5
         
     def test_filter(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert query.count() == 100
         assert query.filter(Foo.c.bar < 30).count() == 30
         res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
         assert res2.count() == 19
     
     def test_options(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         class ext1(MapperExtension):
             def populate_instance(self, mapper, selectcontext, row, instance, **flags):
                 instance.TEST = "hello world"
         assert query.options(extension(ext1()))[0].TEST == "hello world"
         
     def test_order_by(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert query.order_by([Foo.c.bar])[0].bar == 0
         assert query.order_by([desc(Foo.c.bar)])[0].bar == 99
 
     def test_offset(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10
         
     def test_offset(self):
-        query = create_session().query(Foo)
+        query = create_session(bind=testbase.db).query(Foo)
         assert len(list(query.limit(10))) == 10
 
 class Obj1(object):
 class GenerativeTest2(PersistTest):
     def setUpAll(self):
         global metadata, table1, table2
-        metadata = MetaData(testbase.db)
+        metadata = MetaData()
         table1 = Table('Table1', metadata,
             Column('id', Integer, primary_key=True),
             )
             )
         mapper(Obj1, table1)
         mapper(Obj2, table2)
-        metadata.create_all()
-        table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
-        table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
+        metadata.create_all(bind=testbase.db)
+        testbase.db.execute(table1.insert(), {'id':1},{'id':2},{'id':3},{'id':4})
+        testbase.db.execute(table2.insert(), {'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
 {'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
 
     def tearDownAll(self):
-        metadata.drop_all()
+        metadata.drop_all(bind=testbase.db)
         clear_mappers()
 
     def test_distinctcount(self):
-        query = create_session().query(Obj1)
+        query = create_session(bind=testbase.db).query(Obj1)
         assert query.count() == 4
         res = query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
         assert res.count() == 3
                 'items':relation(mapper(tables.Item, tables.orderitems))
             }))
         })
-        session = create_session()
+        session = create_session(bind=testbase.db)
         query = session.query(tables.User)
         x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2)
         print x.compile()
                 'items':relation(mapper(tables.Item, tables.orderitems))
             }))
         })
-        session = create_session()
+        session = create_session(bind=testbase.db)
         query = session.query(tables.User)
         x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
         print x.compile()
                 'items':relation(mapper(tables.Item, tables.orderitems))
             }))
         })
-        session = create_session()
+        session = create_session(bind=testbase.db)
         query = session.query(tables.User)
         x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
         assert x==2
                 'items':relation(mapper(tables.Item, tables.orderitems))
             }))
         })
-        session = create_session()
+        session = create_session(bind=testbase.db)
         query = session.query(tables.User)
         x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\
             filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
         clear_mappers()
         
     def test_distinctcount(self):
-        q = create_session().query(Obj1)
+        q = create_session(bind=testbase.db).query(Obj1)
         assert q.count() == 4
         res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
         assert res.count() == 3
     def test_noautojoin(self):
         class T(object):pass
         mapper(T, t1, properties={'children':relation(T)})
-        sess = create_session()
+        sess = create_session(bind=testbase.db)
         try:
             sess.query(T).join('children').select_by(id=7)
             assert False

test/orm/query.py

 
         mapper(User, users, properties={
             'concat': column_property(f),
-            'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id, scalar=True).correlate(users).label('count'))
+            'count': column_property(select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users).label('count'))
         })
 
         mapper(Address, addresses, properties={

test/sql/query.py

         x = testbase.db.func.current_date().execute().scalar()
         y = testbase.db.func.current_date().select().execute().scalar()
         z = testbase.db.func.current_date().scalar()
-        assert x == y == z
+        assert (x == y == z) is True
         
         x = testbase.db.func.current_date(type_=Date)
         assert isinstance(x.type, Date)
             z = conn.scalar(func.current_date())
         finally:
             conn.close()
-        assert x == y == z
-        
+        assert (x == y == z) is True
+         
     def test_update_functions(self):
         """test sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances,
         and that column-level defaults get overridden"""

test/sql/select.py

 
         s = select([table1.c.myid], scalar=True)
         self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
-        
+
+        s = select([table1.c.myid]).correlate(None).scalar()
+        self.runtest(select([table1, s]), "SELECT mytable.myid, mytable.name, mytable.description, (SELECT mytable.myid FROM mytable) FROM mytable")
+
+        s = select([table1.c.myid]).scalar()
+        self.runtest(select([table2, s]), "SELECT myothertable.otherid, myothertable.othername, (SELECT mytable.myid FROM mytable) FROM myothertable")
+
+        # test expressions against scalar selects
+        self.runtest(select([s - literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) - :literal")
+        self.runtest(select([select([table1.c.name]).scalar() + literal('x')]), "SELECT (SELECT mytable.name FROM mytable) || :literal")
+        self.runtest(select([s > literal(8)]), "SELECT (SELECT mytable.myid FROM mytable) > :literal")
+
+        self.runtest(select([select([table1.c.name]).label('foo')]), "SELECT (SELECT mytable.name FROM mytable) AS foo")
+
 
         zips = table('zips',
             column('zipcode'),
             column('nm')
         )
         zip = '12345'
-        qlat = select([zips.c.latitude], zips.c.zipcode == zip, scalar=True, correlate=False)
-        qlng = select([zips.c.longitude], zips.c.zipcode == zip, scalar=True, correlate=False)
+        qlat = select([zips.c.latitude], zips.c.zipcode == zip).correlate(None).scalar()
+        qlng = select([zips.c.longitude], zips.c.zipcode == zip).correlate(None).scalar()
  
         q = select([places.c.id, places.c.nm, zips.c.zipcode, func.latlondist(qlat, qlng).label('dist')],
                          zips.c.zipcode==zip,