Michael Manfre avatar Michael Manfre committed fab0dcd

Fixed #20 - the insert ID will be returned, even for tables with triggers.

Comments (0)

Files changed (3)

docs/changelog.txt

 
 - Backend now supports returning the ID from an insert without needing an additional query. This is disabled
   for SQL Server 2000 (assuming that version still works with this backend). :issue:`17`
+
+  - This will work even if the table has a trigger. :issue:`20`
+
 - Subqueries will have their ordering removed because SQL Server only supports it when using TOP or FOR XML. 
   This relies upon the ``with_col_aliases`` argument to ``SQLCompiler.as_sql`` only being ``True`` when the query 
   is a subquery, which is currently the case for all usages in Django 1.5 master. :issue:`18`

sqlserver_ado/compiler.py

     # search for after table/column list
     _re_values_sub = re.compile(r'(?P<prefix>\)|\])(?P<default>\s*|\s*default\s*)values(?P<suffix>\s*|\s+\()?', re.IGNORECASE)
     # ... and insert the OUTPUT clause between it and the values list (or DEFAULT VALUES).
-    _values_repl = r'\g<prefix> OUTPUT INSERTED.{col}\g<default>VALUES\g<suffix>'
+    _values_repl = r'\g<prefix> OUTPUT INSERTED.{col} INTO @sqlserver_ado_return_id\g<default>VALUES\g<suffix>'
 
     def as_sql(self, *args, **kwargs):
         # Fix for Django ticket #14019
         # mangle SQL to return ID from insert
         # http://msdn.microsoft.com/en-us/library/ms177564.aspx
         if self.return_id and self.connection.features.can_return_id_from_insert:
-            sql = 'SET NOCOUNT ON; {sql}'.format(sql=sql)
+            col = self.connection.ops.quote_name(meta.pk.db_column or meta.pk.get_attname())
+
+            # Determine datatype for use with the table variable that will return the inserted ID            
+            pk_db_type = meta.pk.db_type(self.connection)
+            if ' IDENTITY ' in pk_db_type:
+                # separate off IDENTITY clause
+                pk_db_type, _ = pk_db_type.split(' IDENTITY ', 2)
+            if ' CHECK ' in pk_db_type:
+                # separate off CHECK clause
+                pk_db_type, _ = pk_db_type.split(' CHECK ', 2)
             
-            col = self.connection.ops.quote_name(meta.pk.db_column or meta.pk.get_attname())
+            # NOCOUNT ON to prevent additional trigger/stored proc related resultsets
+            sql = 'SET NOCOUNT ON;{declare_table_var};{sql};{select_return_id}'.format(
+                sql=sql,
+                declare_table_var="DECLARE @sqlserver_ado_return_id table ({col_name} {pk_type})".format(
+                    col_name=col,
+                    pk_type=pk_db_type,
+                ),
+                select_return_id="SELECT * FROM @sqlserver_ado_return_id",
+            )
+            
             output = self._values_repl.format(col=col)
             sql = self._re_values_sub.sub(output, sql)
 

tests/test_main/regressiontests/tests.py

 import datetime
 import decimal
 from django.core.exceptions import ImproperlyConfigured
-from django.db import models
+from django.db import models, connection
 from django.test import TestCase
 
 from regressiontests.models import Bug69Table1, Bug69Table2, Bug70Table, Bug93Table, IntegerIdTable
                 'HOST': 'my.fqdn.com',
                 'PORT': '1433',
             })
+
+
+class PkPlusOne(models.Model):
+    id = models.IntegerField(primary_key=True)
+    a = models.IntegerField(null=True)
+
+class AutoPkPlusOne(models.Model):
+    id = models.AutoField(primary_key=True)
+    a = models.IntegerField(null=True)
+
+class TextPkPlusOne(models.Model):
+    id = models.CharField(primary_key=True, max_length=10)
+    a = models.IntegerField(null=True)
+
+class ReturnIdOnInsertWithTriggersTestCase(TestCase):
+    def create_trigger(self, model):
+        """Create a trigger for the provided model"""
+        qn = connection.ops.quote_name
+        table_name = qn(model._meta.db_table)
+        trigger_name = qn('test_trigger_%s' % model._meta.db_table)
+        
+        with connection.cursor() as cur:
+            # drop trigger if it exists
+            drop_sql = """
+IF OBJECT_ID(N'[dbo].{trigger}') IS NOT NULL
+    DROP TRIGGER [dbo].{trigger}
+""".format(trigger=trigger_name)
+            
+            create_sql = """
+CREATE TRIGGER [dbo].{trigger} ON {tbl} FOR INSERT
+AS UPDATE {tbl} set [a] = 100""".format(
+                trigger=trigger_name,
+                tbl=table_name,
+            )
+            
+            cur.execute(drop_sql)
+            cur.execute(create_sql)
+
+    def test_pk(self):
+        self.create_trigger(PkPlusOne)
+        id = 1
+        obj = PkPlusOne.objects.create(id=id)
+        self.assertEqual(obj.pk, id)
+        self.assertEqual(PkPlusOne.objects.get(pk=id).a, 100)
+
+    def test_auto_pk(self):
+        self.create_trigger(AutoPkPlusOne)
+        id = 1
+        obj = AutoPkPlusOne.objects.create()
+        self.assertEqual(obj.pk, id)
+        self.assertEqual(AutoPkPlusOne.objects.get(pk=id).a, 100)
+
+    def test_text_pk(self):
+        self.create_trigger(TextPkPlusOne)
+        id = 'asdf'
+        obj = TextPkPlusOne.objects.create(id=id)
+        self.assertEqual(obj.pk, id)
+        self.assertEqual(TextPkPlusOne.objects.get(pk=id).a, 100)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.