Mike Bayer avatar Mike Bayer committed 267b178

- column.in_(someselect) can now be used as
a columns-clause expression without the subquery
bleeding into the FROM clause [ticket:1074]

Comments (0)

Files changed (3)

       SessionExtension.before_flush() will take
       effect for that flush.
 
+- sql
+    - column.in_(someselect) can now be used as 
+      a columns-clause expression without the subquery
+      bleeding into the FROM clause [ticket:1074]
+      
 - mysql
     - Added MSMediumInteger type [ticket:1146].
 

lib/sqlalchemy/sql/expression.py

 
     def _in_impl(self, op, negate_op, *other):
         # Handle old style *args argument passing
-        if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)):
+        if len(other) != 1 or not isinstance(other[0], (_ScalarSelect, Selectable)) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)):
             util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable')
             seq_or_selectable = other
         else:
             seq_or_selectable = other[0]
 
-        if isinstance(seq_or_selectable, Selectable):
-            return self.__compare( op, seq_or_selectable, negate=negate_op)
+        if isinstance(seq_or_selectable, _ScalarSelect):
+             return self.__compare( op, seq_or_selectable, negate=negate_op)
+        elif isinstance(seq_or_selectable, _SelectBaseMixin):
+             return self.__compare( op, seq_or_selectable.as_scalar(), negate=negate_op)
+        elif isinstance(seq_or_selectable, Selectable):
+             return self.__compare( op, seq_or_selectable, negate=negate_op)
 
         # Handle non selectable arguments as sequences
         args = []
 
     def __init__(self, elem):
         self.elem = elem
-        cols = list(elem.inner_columns)
+        cols = list(elem.c)
         if len(cols) != 1:
             raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
         self.type = cols[0].type
                 self.selects.append(s)
 
         _SelectBaseMixin.__init__(self, **kwargs)
-        
+
     def self_group(self, against=None):
         return _FromGrouping(self)
 

test/sql/select.py

 
         self.assert_compile(select([table1], table1.c.myid.in_(
             union(
-                  select([table1], table1.c.myid == 5),
-                  select([table1], table1.c.myid == 12),
+                  select([table1.c.myid], table1.c.myid == 5),
+                  select([table1.c.myid], table1.c.myid == 12),
             )
         )), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable \
 WHERE mytable.myid IN (\
-SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_1 \
-UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :myid_2)")
+SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_1 \
+UNION SELECT mytable.myid FROM mytable WHERE mytable.myid = :myid_2)")
 
         # test that putting a select in an IN clause does not blow away its ORDER BY clause
         self.assert_compile(
         self.assert_compile(select([table1], table1.c.myid.in_([])),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
 
+        self.assert_compile(
+            select([table1.c.myid.in_(select([table2.c.otherid]))]),
+            "SELECT mytable.myid IN (SELECT myothertable.otherid FROM myothertable) AS anon_1 FROM mytable"
+        )
+        self.assert_compile(
+            select([table1.c.myid.in_(select([table2.c.otherid]).as_scalar())]),
+            "SELECT mytable.myid IN (SELECT myothertable.otherid FROM myothertable) AS anon_1 FROM mytable"
+        )
+
     def test_in_deprecated_api(self):
         self.assert_compile(select([table1], table1.c.myid.in_('abc')),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
 
         self.assert_compile(select([table1], table1.c.myid.in_()),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
+
     test_in_deprecated_api = testing.uses_deprecated('passing in_')(test_in_deprecated_api)
 
+
     def test_cast(self):
         tbl = table('casttest',
                     column('id', Integer),
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.