adamv avatar adamv committed d908f55

* Took sql_flush from django-pyodbc (for unit tests)
* Using Django's unit test base class instead of the unittest one
* Added fixture for paging tests
* Column aliasing code is now more robust
* Separated DatabaseOperations from base.py

Comments (0)

Files changed (6)

source/sqlserver_ado/base.py

 """Microsoft SQL Server database backend for Django."""
-from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseOperations, BaseDatabaseValidation, BaseDatabaseClient
+from django.db.backends import BaseDatabaseWrapper, BaseDatabaseFeatures, BaseDatabaseValidation, BaseDatabaseClient
 
 from django.core.exceptions import ImproperlyConfigured
 
 import dbapi as Database
-import query
+
 from introspection import DatabaseIntrospection
 from creation import DatabaseCreation
+from operations import DatabaseOperations
 
 DatabaseError = Database.DatabaseError
 IntegrityError = Database.IntegrityError
 class DatabaseFeatures(BaseDatabaseFeatures):
     uses_custom_query_class = True
 
-class DatabaseOperations(BaseDatabaseOperations):
-    def date_extract_sql(self, lookup_type, field_name):
-        return "DATEPART(%s, %s)" % (lookup_type, self.quote_name(field_name))
-
-    def date_trunc_sql(self, lookup_type, field_name):
-    	quoted_field_name = self.quote_name(field_name)
-
-        if lookup_type == 'year':
-            return "Convert(datetime, Convert(varchar, DATEPART(year, %s)) + '/01/01')" % quoted_field_name
-        if lookup_type == 'month':
-            return "Convert(datetime, Convert(varchar, DATEPART(year, %s)) + '/' + Convert(varchar, DATEPART(month, %s)) + '/01')" %\
-                (quoted_field_name, quoted_field_name)
-        if lookup_type == 'day':
-            return "Convert(datetime, Convert(varchar(12), %s))" % quoted_field_name
-
-    def last_insert_id(self, cursor, table_name, pk_name):
-        cursor.execute("SELECT CAST(IDENT_CURRENT(%s) as bigint)", [self.quote_name(table_name)])
-        return cursor.fetchone()[0]
-
-    def query_class(self, DefaultQueryClass):
-        return query.query_class(DefaultQueryClass, Database)
-
-    def quote_name(self, name):
-        if name.startswith('[') and name.endswith(']'):
-            return name # already quoted
-        return '[%s]' % name
-
-    def random_function_sql(self):
-        return 'RAND()'
-
-    def regex_lookup(self, lookup_type):
-		# Case sensitivity
-		match_option = {'iregex':0, 'regex':1}[lookup_type]
-		return "dbo.REGEXP_LIKE(%%s, %%s, %s)=1" % (match_option,)
-
-    def tablespace_sql(self, tablespace, inline=False):
-        return "ON %s" % self.quote_name(tablespace)
-        
-    def no_limit_value(self):
-        return None
-
-    def value_to_db_datetime(self, value):
-        # MS SQL 2005 doesn't support microseconds
-        if value is None:
-            return None
-        return value.replace(microsecond=0)
-    
-    def value_to_db_time(self, value):
-        # MS SQL 2005 doesn't support microseconds
-        #...but it also doesn't really suport bare times
-        if value is None:
-            return None
-        return value.replace(microsecond=0)
-	        
-    def value_to_db_decimal(self, value, max_digits, decimal_places):
-        if value is None or value == '':
-            return None
-        return value # Should be a decimal type (or string)
-
-    def prep_for_like_query(self, x):
-        """Prepares a value for use in a LIKE query."""
-        from django.utils.encoding import smart_unicode
-        return (
-            smart_unicode(x).\
-                replace("\\", "\\\\").\
-                replace("%", "\%").\
-                replace("_", "\_").\
-                replace("[", "\[").\
-                replace("]", "\]")
-            )
-
 # IP Address recognizer taken from:
 # http://mail.python.org/pipermail/python-list/2006-March/375505.html
 def _looks_like_ipaddress(address):

source/sqlserver_ado/operations.py

+from django.db.backends import BaseDatabaseOperations
+import datetime
+import time
+
+import query
+
+class DatabaseOperations(BaseDatabaseOperations):
+    def date_extract_sql(self, lookup_type, field_name):
+        return "DATEPART(%s, %s)" % (lookup_type, self.quote_name(field_name))
+
+    def date_trunc_sql(self, lookup_type, field_name):
+    	quoted_field_name = self.quote_name(field_name)
+
+        if lookup_type == 'year':
+            return "Convert(datetime, Convert(varchar, DATEPART(year, %s)) + '/01/01')" % quoted_field_name
+        if lookup_type == 'month':
+            return "Convert(datetime, Convert(varchar, DATEPART(year, %s)) + '/' + Convert(varchar, DATEPART(month, %s)) + '/01')" %\
+                (quoted_field_name, quoted_field_name)
+        if lookup_type == 'day':
+            return "Convert(datetime, Convert(varchar(12), %s))" % quoted_field_name
+
+    def last_insert_id(self, cursor, table_name, pk_name):
+        cursor.execute("SELECT CAST(IDENT_CURRENT(%s) as bigint)", [self.quote_name(table_name)])
+        return cursor.fetchone()[0]
+
+    def no_limit_value(self):
+        return None
+
+    def prep_for_like_query(self, x):
+        """Prepares a value for use in a LIKE query."""
+        from django.utils.encoding import smart_unicode
+        return (
+            smart_unicode(x).\
+                replace("\\", "\\\\").\
+                replace("%", "\%").\
+                replace("_", "\_").\
+                replace("[", "\[").\
+                replace("]", "\]")
+            )
+
+    def query_class(self, DefaultQueryClass):
+        return query.query_class(DefaultQueryClass)
+
+    def quote_name(self, name):
+        if name.startswith('[') and name.endswith(']'):
+            return name # already quoted
+        return '[%s]' % name
+
+    def random_function_sql(self):
+        return 'RAND()'
+
+    def regex_lookup(self, lookup_type):
+        # Case sensitivity
+        match_option = {'iregex':0, 'regex':1}[lookup_type]
+        return "dbo.REGEXP_LIKE(%%s, %%s, %s)=1" % (match_option,)
+
+    def sql_flush(self, style, tables, sequences):
+        """
+        Returns a list of SQL statements required to remove all data from
+        the given database tables (without actually removing the tables
+        themselves).
+
+        The `style` argument is a Style object as returned by either
+        color_style() or no_style() in django.core.management.color.
+        
+        Nicked from django-pyodbc
+        """
+        if tables:
+            # Cannot use TRUNCATE on tables that are referenced by a FOREIGN KEY
+            # So must use the much slower DELETE
+            from django.db import connection
+            cursor = connection.cursor()
+            # Try to minimize the risks of the braindeaded inconsistency in
+            # DBCC CHEKIDENT(table, RESEED, n) behavior.
+            seqs = []
+            for seq in sequences:
+                cursor.execute("SELECT COUNT(*) FROM %s" % self.quote_name(seq["table"]))
+                rowcnt = cursor.fetchone()[0]
+                elem = {}
+
+                if rowcnt:
+                    elem['start_id'] = 0
+                else:
+                    elem['start_id'] = 1
+
+                elem.update(seq)
+                seqs.append(elem)
+            cursor.execute("SELECT TABLE_NAME, CONSTRAINT_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS")
+            fks = cursor.fetchall()
+            sql_list = ['ALTER TABLE %s NOCHECK CONSTRAINT %s;' % \
+                    (self.quote_name(fk[0]), self.quote_name(fk[1])) for fk in fks]
+            sql_list.extend(['%s %s %s;' % (style.SQL_KEYWORD('DELETE'), style.SQL_KEYWORD('FROM'),
+                             style.SQL_FIELD(self.quote_name(table)) ) for table in tables])
+            # Then reset the counters on each table.
+            sql_list.extend(['%s %s (%s, %s, %s) %s %s;' % (
+                style.SQL_KEYWORD('DBCC'),
+                style.SQL_KEYWORD('CHECKIDENT'),
+                style.SQL_FIELD(self.quote_name(seq["table"])),
+                style.SQL_KEYWORD('RESEED'),
+                style.SQL_FIELD('%d' % seq['start_id']),
+                style.SQL_KEYWORD('WITH'),
+                style.SQL_KEYWORD('NO_INFOMSGS'),
+                ) for seq in seqs])
+            sql_list.extend(['ALTER TABLE %s CHECK CONSTRAINT %s;' % \
+                    (self.quote_name(fk[0]), self.quote_name(fk[1])) for fk in fks])
+            return sql_list
+        else:
+            return []
+
+    def tablespace_sql(self, tablespace, inline=False):
+        return "ON %s" % self.quote_name(tablespace)
+        
+    def value_to_db_datetime(self, value):
+        # MS SQL 2005 doesn't support microseconds
+        if value is None:
+            return None
+        return value.replace(microsecond=0)
+    
+    def value_to_db_time(self, value):
+        # MS SQL 2005 doesn't support microseconds
+        #...but it also doesn't really suport bare times
+        if value is None:
+            return None
+        return value.replace(microsecond=0)
+	        
+    def value_to_db_decimal(self, value, max_digits, decimal_places):
+        if value is None or value == '':
+            return None
+        return value # Should be a decimal type (or string)

source/sqlserver_ado/query.py

 def _remove_order_limit_offset(sql):
     return _re_order_limit_offset.sub('',sql)
 
-def query_class(QueryClass, Database):
+def query_class(QueryClass):
     """Return a custom Query subclass for SQL Server."""
     class SqlServerQuery(QueryClass):
         def __init__(self, *args, **kwargs):
         def _alias_columns(self, sql):
             """Return tuple of SELECT and FROM clauses, aliasing duplicate column names."""
             qn = self.connection.ops.quote_name
-
+            
+            # Pattern to find the quoted column name at the end of a field specification
+            _pat_col = r"\[([^[]+)\]$"  
+            #]) Funky comment to get e's syntax highlighting back on track. 
+        
             outer = list()
             inner = list()
-            
+
             names_seen = list()
-            original_names = sql[0:sql.find(' FROM [')].split(',')
+            original_names = [x.strip() for x in sql[:sql.find(' FROM [')].split(',')]
             for col in original_names:
-                # Col looks like: "[app_table].[column]"; strip out just "column"
-                col_name = col.split('].[')[1][:-1]
+                col_name = re.search(_pat_col, col).group(1)
                 
                 # If column name was already seen, alias it.
                 if col_name in names_seen:

tests/test_main/make_data.py

+from paging.models import *
+
+a1 = FirstTable(b='A1')
+a1.save()
+
+a2 = FirstTable(b='A2')
+a2.save()
+
+b1 = SecondTable(a=a1, b='B1')
+b1.save()
+
+b2 = SecondTable(a=a1, b='B2')
+b2.save()
+
+b3 = SecondTable(a=a1, b='B3')
+b3.save()

tests/test_main/paging/fixtures/paging.json

+[{"pk": 1, "model": "paging.firsttable", "fields": {"c": "test", "b": "A1"}}, {"pk": 2, "model": "paging.firsttable", "fields": {"c": "test", "b": "A2"}}, {"pk": 1, "model": "paging.secondtable", "fields": {"a": 1, "b": "B1"}}, {"pk": 2, "model": "paging.secondtable", "fields": {"a": 1, "b": "B2"}}, {"pk": 3, "model": "paging.secondtable", "fields": {"a": 1, "b": "B3"}}]

tests/test_main/paging/models.py

 from django.db import models
 from django.core.paginator import Paginator
 
-import unittest
+from django.test import TestCase
+#import unittest
 
 class FirstTable(models.Model):
     b = models.CharField(max_length=100)
     
     def __repr__(self):
         return '<FirstTable %s: %s, %s>' % (self.pk, self.b, self.c)
-    
+
 class SecondTable(models.Model):
     a = models.ForeignKey(FirstTable)
     b = models.CharField(max_length=100)
         return '<FirstTable %s: %s, %s>' % (self.pk, self.a_id, self.b)
 
 
-class PagingTestCase(unittest.TestCase):
-    def setupPagingData(self):
-        a1 = FirstTable(b='A1')
-        a1.save()
-        
-        a2 = FirstTable(b='A2')
-        a2.save()
-        
-        b1 = SecondTable(a=a1, b='B1')
-        b1.save()
+class PagingTestCase(TestCase):
+    fixtures = ['paging.json']
+    
+    def get_q(self, a1_pk):
+        return SecondTable.objects.filter(a=a1_pk).order_by('b').select_related(depth=1)
 
-        b2 = SecondTable(a=a1, b='B2')
-        b2.save()
-
-        b3 = SecondTable(a=a1, b='B3')
-        b3.save()
-        
-        return a1.pk
-        
-    def try_page(self, page_number, a1_pk):
-        # Select related data so we get two 'b' columns and two 'id' columns per row.
-        data = SecondTable.objects.filter(a=a1_pk).order_by('b').select_related(depth=1)
-        
+    def try_page(self, page_number, q):
         # Use a single item per page, to get multiple pages.
-        pager = Paginator(data, 1)
+        pager = Paginator(q, 1)
         self.assertEquals(pager.count, 3)
 
         on_this_page = list(pager.page(page_number).object_list)
-        self.assertEquals(len(on_this_page), 1, 'Too many results on this page.')
+        self.assertEquals(len(on_this_page), 1)
         self.assertEquals(on_this_page[0].b, 'B'+str(page_number))
     
-    def testPagingWithDuplicateColumnNames(self):
-        a1_pk = self.setupPagingData()
+    def testWithDuplicateColumnNames(self):
+        a1_pk = FirstTable.objects.get(b='A1').pk
+        q = self.get_q(a1_pk)
         
         for i in (1,2,3):
-            self.try_page(i, a1_pk)
+            self.try_page(i, q)
+            
+    def testPerRowSelect(self):
+        a1_pk = FirstTable.objects.get(b='A1').pk
+
+        q = SecondTable.objects.filter(a=a1_pk).order_by('b').select_related(depth=1).extra(select=
+        {
+            'extra_column': "select paging_FirstTable.id from paging_FirstTable where paging_FirstTable.id=%s" % (a1_pk,)
+        })
+        
+        for i in (1,2,3):
+            self.try_page(i, q)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.