Michael Manfre avatar Michael Manfre committed 12fed69

SQL containing AVG(...) will be modified to cast the contained values to float to allow the function to behave the same as for other databases. E.g. AVG([1,2]) will return 1.5, instead of 1. Casting can be disabled by 'disable_avg_cast' database option. Fixes #7

Comments (0)

Files changed (4)

docs/changelog.txt

 - Implemented ``DatabaseOperations.date_interval_sql`` to allow using expressions like ``end__lte=F('start')+delta``.
 - Fixed date part extraction for ``week_day``.
 - DatabaseWrapper reports vendor as 'microsoft'.
+- AVG function now matches core backend behaviors and will auto-cast to ``float``, instead of maintaining datatype. 
+  Set database ``OPTIONS`` setting ``disable_avg_cast`` to turn off the auto-cast behavior.

docs/settings.txt

 -------
 
 ``Django-mssql`` provides a few extra ``OPTIONS`` that are specific to this
-backend.
+backend. Please note that while the main database settings are UPPERCASE
+keys, the ``OPTIONS`` dictionary keys are expected to be lowercase (due to
+legacy reasons).
 
 use_mars
 ~~~~~~~~
 to enable MARS to avoid seeing the "Cannot create new connection because 
 in manual or distributed transaction mode" error.
 
-.. note:
+.. note::
     This will only set the appropriate connection string value for 
     the "SQLOLEDB" provider. If you are using a different provider, you 
     will need to add the appropriate connection string value to 
 
 Default: ``'SQLOLEDB'``
 
-The SQL provider to use when connecting to the database.
+The SQL provider to use when connecting to the database.
+
+
+disable_avg_cast
+~~~~~~~~~~~~~~~~
+
+Default: ``False``
+
+This backend will automatically ``CAST`` fields used by the `AVG function` 
+as ``FLOAT`` to match the behavior of the core database backends. Set this
+to ``True`` if you need SQL server to retain the datatype of fields used
+with ``AVG``.
+
+.. versionadded:: 1.1
+
+.. _`AVG function`: http://msdn.microsoft.com/en-us/library/ms177677.aspx
+
+.. note::
+    SQL server maintains the datatype of the values used in ``AVG``. The
+    average of an ``int`` column will be an ``int``. With this option set
+    to ``True``, ``AVG([1,2])`` == 1, not 1.5.

sqlserver_ado/base.py

         except ValueError:   
             self.command_timeout = 30
         
+        try:
+            options = self.settings_dict.get('OPTIONS', {})
+            self.cast_avg_to_float = not bool(options.get('disable_avg_cast', False))
+        except ValueError:
+            self.cast_avg_to_float = False
+        
         self.ops.is_sql2005 = self.is_sql2005
         self.ops.is_sql2008 = self.is_sql2008
 

sqlserver_ado/compiler.py

 def _remove_order_limit_offset(sql):
     return _re_order_limit_offset.sub('',sql).split(None, 1)[1]
 
-
 class SQLCompiler(compiler.SQLCompiler):
     def resolve_columns(self, row, fields=()):
         # If the results are sliced, the resultset will have an initial 
 
         return row[:index_extra_select] + tuple(values)
 
+    def _fix_aggregates(self):
+        """
+        MSSQL doesn't match the behavior of the other backends on a few of
+        the aggregate functions; different return type behavior, different
+        function names, etc.
+        
+        MSSQL's implementation of AVG maintains datatype without proding. To
+        match behavior of other django backends, it needs to not drop remainders.
+        E.g. AVG([1, 2]) needs to yield 1.5, not 1
+        """
+        if self.connection.cast_avg_to_float:
+            for alias, aggregate in self.query.aggregate_select.items():
+                if aggregate.sql_function == 'AVG':
+                    # Embed the CAST in the template on this query to
+                    # maintain multi-db support.
+                    self.query.aggregate_select[alias].sql_template = \
+                        '%(function)s(CAST(%(field)s AS FLOAT))'
+
     def as_sql(self, with_limits=True, with_col_aliases=False):
+        self._fix_aggregates()
+        
         self._using_row_number = False
         
         # Get out of the way if we're not a select query or there's no limiting involved.
     pass
 
 class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
-    pass
+    def as_sql(self, qn=None):
+        self._fix_aggregates()
+        return super(SQLAggregateCompiler, self).as_sql(qn=qn)
 
 class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
     pass
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.