Mike Bayer avatar Mike Bayer committed 18905ba

- [bug] UPDATE..FROM syntax with SQL Server
requires that the updated table be present
in the FROM clause when an alias of that
table is also present in the FROM clause.
The updated table is now always present
in the FROM, when FROM is present
in the first place. Courtesy sayap.
[ticket:2468]

Comments (0)

Files changed (6)

     INSERT to get at the last inserted ID,
     for those tables which have "implicit_returning"
     set to False.
+ 
+  - [bug] UPDATE..FROM syntax with SQL Server
+    requires that the updated table be present
+    in the FROM clause when an alias of that
+    table is also present in the FROM clause.
+    The updated table is now always present
+    in the FROM, when FROM is present 
+    in the first place.  Courtesy sayap.
+    [ticket:2468]
 
 - postgresql
   - [feature] Added new for_update/with_lockmode()

lib/sqlalchemy/dialects/mssql/base.py

         else:
             return ""
 
+    def update_from_clause(self, update_stmt,
+                                from_table, extra_froms,
+                                from_hints,
+                                **kw):
+        """Render the UPDATE..FROM clause specific to MSSQL.
+        
+        In MSSQL, if the UPDATE statement involves an alias of the table to
+        be updated, then the table itself must be added to the FROM list as
+        well. Otherwise, it is optional. Here, we add it regardless.
+        
+        """
+        return "FROM " + ', '.join(
+                    t._compiler_dispatch(self, asfrom=True,
+                                    fromhints=from_hints, **kw)
+                    for t in [from_table] + extra_froms)
+
 class MSSQLStrictCompiler(MSSQLCompiler):
     """A subclass of MSSQLCompiler which disables the usage of bind
     parameters where not allowed natively by MS-SQL.

lib/sqlalchemy/sql/compiler.py

         """Provide a hook to override the generation of an 
         UPDATE..FROM clause.
 
-        MySQL overrides this.
+        MySQL and MSSQL override this.
 
         """
         return "FROM " + ', '.join(

test/dialect/test_mssql.py

                                 selectable=t2, 
                                 dialect_name=darg),
                 "UPDATE sometable SET somecolumn=:somecolumn "
-                "FROM othertable WITH (PAGLOCK) "
+                "FROM sometable, othertable WITH (PAGLOCK) "
                 "WHERE sometable.somecolumn = othertable.somecolumn"
             )
 

test/sql/test_compiler.py

                 "UPDATE mytable SET name=:name "
                 "FROM myothertable WHERE myothertable.otherid = mytable.myid")
 
+        self.assert_compile(u,
+                "UPDATE mytable SET name=:name "
+                "FROM mytable, myothertable WHERE "
+                "myothertable.otherid = mytable.myid",
+                dialect=mssql.dialect())
+
+        self.assert_compile(u.where(table2.c.othername == mt.c.name),
+                "UPDATE mytable SET name=:name "
+                "FROM mytable, myothertable, mytable AS mytable_1 "
+                "WHERE myothertable.otherid = mytable.myid "
+                "AND myothertable.othername = mytable_1.name",
+                dialect=mssql.dialect())
+
     def test_delete(self):
         self.assert_compile(
                         delete(table1, table1.c.myid == 7), 

test/sql/test_update.py

         )
 
     @testing.requires.update_from
+    def test_exec_two_table_plus_alias(self):
+        users, addresses = self.tables.users, self.tables.addresses
+        a1 = addresses.alias()
+
+        testing.db.execute(
+            addresses.update().\
+                values(email_address=users.c.name).\
+                where(users.c.id==a1.c.user_id).\
+                where(users.c.name=='ed').\
+                where(a1.c.id==addresses.c.id)
+        )
+        eq_(
+            testing.db.execute(
+                addresses.select().\
+                    order_by(addresses.c.id)).fetchall(),
+            [
+                (1, 7, 'x', "jack@bean.com"),
+                (2, 8, 'x', "ed"),
+                (3, 8, 'x', "ed"),
+                (4, 8, 'x', "ed"),
+                (5, 9, 'x', "fred@fred.com")
+            ]
+        )
+
+    @testing.requires.update_from
     def test_exec_three_table(self):
         users, addresses, dingalings = \
                 self.tables.users, \
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.