Commits

Cesar Canassa committed 3469d36

Merged with svn version 190

Comments (0)

Files changed (4)

sql_server/pyodbc/base.py

 class DatabaseFeatures(BaseDatabaseFeatures):
     uses_custom_query_class = True
     can_use_chunked_reads = False
+    can_return_id_from_insert = True
     #uses_savepoints = True
 
 
             self.datefirst = self.settings_dict['OPTIONS'].get('datefirst', 7)
             self.unicode_results = self.settings_dict['OPTIONS'].get('unicode_results', False)
 
-        self.features = DatabaseFeatures()
-        self.ops = DatabaseOperations()
+        if _DJANGO_VERSION >= 13:
+            self.features = DatabaseFeatures(self)
+        else:
+            self.features = DatabaseFeatures()
+        self.ops = DatabaseOperations(self)
         self.client = DatabaseClient(self)
         self.creation = DatabaseCreation(self)
         self.introspection = DatabaseIntrospection(self)
     def _cursor(self):
         new_conn = False
         settings_dict = self.settings_dict
-        db_str, user_str, passwd_str, port_str = None, None, None, None
+        db_str, user_str, passwd_str, port_str = None, None, "", None
         if _DJANGO_VERSION >= 12:
             options = settings_dict['OPTIONS']
             if settings_dict['NAME']:
                 if os.name == 'nt' or driver == 'FreeTDS' and \
                         options.get('host_is_server', False):
                     if port_str:
-                        host_str += ',%s' % port_str
+                        host_str += ';PORT=%s' % port_str
                     cstr_parts.append('SERVER=%s' % host_str)
                 else:
                     cstr_parts.append('SERVERNAME=%s' % host_str)
             # Django convention for the 'week_day' Django lookup) if the user
             # hasn't told us otherwise
             cursor.execute("SET DATEFORMAT ymd; SET DATEFIRST %s" % self.datefirst)
-            if self.ops._get_sql_server_ver(self.connection) < 2005:
+            if self.ops.sql_server_ver < 2005:
                 self.creation.data_types['TextField'] = 'ntext'
+                self.features.can_return_id_from_insert = False
 
             if self.driver_needs_utf8 is None:
                 self.driver_needs_utf8 = True
                     self.driver_needs_utf8 = False
 
                 # http://msdn.microsoft.com/en-us/library/ms131686.aspx
-                if self.ops._get_sql_server_ver(self.connection) >= 2005 and self.drv_name in ('SQLNCLI.DLL', 'SQLNCLI10.DLL') and self.MARS_Connection:
+                if self.ops.sql_server_ver >= 2005 and self.drv_name in ('SQLNCLI.DLL', 'SQLNCLI10.DLL') and self.MARS_Connection:
                     # How to to activate it: Add 'MARS_Connection': True
                     # to the DATABASE_OPTIONS dictionary setting
                     self.features.can_use_chunked_reads = True

sql_server/pyodbc/compiler.py

         fallback_ordering = '%s.%s' % (qn(meta.db_table), qn(meta.pk.db_column or meta.pk.column))
 
         # SQL Server 2000, offset+limit case
-        if self.connection.ops._get_sql_server_ver(self.connection) < 2005 and self.query.high_mark is not None:
+        if self.connection.ops.sql_server_ver < 2005 and self.query.high_mark is not None:
             orig_sql, params = self._as_sql(USE_TOP_HMARK)
             if self._ord:
                 ord = ', '.join(['%s %s' % pair for pair in self._ord])
             return sql, params
 
         # SQL Server 2005
-        if self.connection.ops._get_sql_server_ver(self.connection) >= 2005:
+        if self.connection.ops.sql_server_ver >= 2005:
             sql, params = self._as_sql(USE_ROW_NUMBER)
 
             # Construct the final SQL clause, using the initial select SQL
 
 
 class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
-    def as_sql(self, *args, **kwargs):
-        # Fix for Django ticket #14019
-        if not hasattr(self, 'return_id'):
-            self.return_id = False
-
-        sql, params = super(SQLInsertCompiler, self).as_sql(*args, **kwargs)
-
+    def as_sql(self):
+        # We don't need quote_name_unless_alias() here, since these are all
+        # going to be column names (so we can avoid the extra overhead).
+        qn = self.connection.ops.quote_name
+        opts = self.query.model._meta
+        result = ['INSERT INTO %s' % qn(opts.db_table)]
+        result.append('(%s)' % ', '.join([qn(c) for c in self.query.columns]))
+        if self.return_id and self.connection.features.can_return_id_from_insert:
+            output = 'OUTPUT inserted.%s' % qn(opts.pk.column)
+            result.append(output)
+        values = [self.placeholder(*v) for v in self.query.values]
+        result.append('VALUES (%s)' % ', '.join(values))
+        params = self.query.params
+        sql = ' '.join(result)
+        
         meta = self.query.get_meta()
 
         if meta.has_auto_field:
 
             if auto_field_column in self.query.columns:
                 quoted_table = self.connection.ops.quote_name(meta.db_table)
-                sql = "SET IDENTITY_INSERT %s ON;%s;SET IDENTITY_INSERT %s OFF" %\
-                    (quoted_table, sql, quoted_table)
-
+                if len(self.query.columns) == 1 and not params:
+                    sql = "INSERT INTO %s DEFAULT VALUES" % quoted_table
+                else:
+                    sql = "SET IDENTITY_INSERT %s ON;\n%s;\nSET IDENTITY_INSERT %s OFF" % \
+                        (quoted_table, sql, quoted_table)
         return sql, params
 
 

sql_server/pyodbc/introspection.py

 AND ix.is_unique_constraint = 0
 AND t.name = %s"""
 
-        if self.connection.ops._get_sql_server_ver(self.connection) >= 2005:
+        if self.connection.ops.sql_server_ver >= 2005:
             cursor.execute(ix_sql, (table_name,))
             for column in [r[0] for r in cursor.fetchall()]:
                 if column not in results:

sql_server/pyodbc/operations.py

 
 class DatabaseOperations(BaseDatabaseOperations):
     compiler_module = "sql_server.pyodbc.compiler"
-    def __init__(self):
+    def __init__(self, connection):
         super(DatabaseOperations, self).__init__()
+        self.connection = connection
         self._ss_ver = None
 
-    def _get_sql_server_ver(self, connection=None):
+    def _get_sql_server_ver(self):
         """
         Returns the version of the SQL Server in use:
         """
         if self._ss_ver is not None:
             return self._ss_ver
+        cur = self.connection.cursor()
+        cur.execute("SELECT CAST(SERVERPROPERTY('ProductVersion') as varchar)")
+        ver_code = int(cur.fetchone()[0].split('.')[0])
+        if ver_code >= 10:
+            self._ss_ver = 2008
+        elif ver_code == 9:
+            self._ss_ver = 2005
         else:
-            if connection:
-                cur = connection.cursor()
-            else:
-                from django.db import connection
-                cur = connection.cursor()
-            cur.execute("SELECT CAST(SERVERPROPERTY('ProductVersion') as varchar)")
-            ver_code = int(cur.fetchone()[0].split('.')[0])
-            if ver_code >= 10:
-                self._ss_ver = 2008
-            elif ver_code == 9:
-                self._ss_ver = 2005
-            else:
-                self._ss_ver = 2000
-            return self._ss_ver
+            self._ss_ver = 2000
+        return self._ss_ver
     sql_server_ver = property(_get_sql_server_ver)
 
     def date_extract_sql(self, lookup_type, field_name):
         cursor.execute("SELECT CAST(IDENT_CURRENT(%s) as bigint)", [table_name])
         return cursor.fetchone()[0]
 
+    def fetch_returned_insert_id(self, cursor):
+        """
+        Given a cursor object that has just performed an INSERT/OUTPUT statement
+        into a table that has an auto-incrementing ID, returns the newly created
+        ID.
+        """
+        return cursor.fetchone()[0]
+
     def lookup_cast(self, lookup_type):
         if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
             return "UPPER(%s)"
         elif value is not None and field and field.get_internal_type() == 'FloatField':
             value = float(value)
         return value
+