Michael Manfre  committed 12fed69

SQL containing AVG(...) will be modified to cast the contained values to float to allow the function to behave the same as for other databases. E.g. AVG([1,2]) will return 1.5, instead of 1. Casting can be disabled by 'disable_avg_cast' database option. Fixes #7

  • Participants
  • Parent commits f6b5859
  • Branches default

Comments (0)

Files changed (4)

File docs/changelog.txt

 - Implemented ``DatabaseOperations.date_interval_sql`` to allow using expressions like ``end__lte=F('start')+delta``.
 - Fixed date part extraction for ``week_day``.
 - DatabaseWrapper reports vendor as 'microsoft'.
+- AVG function now matches core backend behaviors and will auto-cast to ``float``, instead of maintaining datatype. 
+  Set database ``OPTIONS`` setting ``disable_avg_cast`` to turn off the auto-cast behavior.

File docs/settings.txt

 ``Django-mssql`` provides a few extra ``OPTIONS`` that are specific to this
+backend. Please note that while the main database settings are UPPERCASE
+keys, the ``OPTIONS`` dictionary keys are expected to be lowercase (due to
+legacy reasons).
 to enable MARS to avoid seeing the "Cannot create new connection because 
 in manual or distributed transaction mode" error.
-.. note:
+.. note::
     This will only set the appropriate connection string value for 
     the "SQLOLEDB" provider. If you are using a different provider, you 
     will need to add the appropriate connection string value to 
 Default: ``'SQLOLEDB'``
-The SQL provider to use when connecting to the database.
+The SQL provider to use when connecting to the database.
+Default: ``False``
+This backend will automatically ``CAST`` fields used by the `AVG function` 
+as ``FLOAT`` to match the behavior of the core database backends. Set this
+to ``True`` if you need SQL server to retain the datatype of fields used
+with ``AVG``.
+.. versionadded:: 1.1
+.. _`AVG function`:
+.. note::
+    SQL server maintains the datatype of the values used in ``AVG``. The
+    average of an ``int`` column will be an ``int``. With this option set
+    to ``True``, ``AVG([1,2])`` == 1, not 1.5.

File sqlserver_ado/

         except ValueError:   
             self.command_timeout = 30
+        try:
+            options = self.settings_dict.get('OPTIONS', {})
+            self.cast_avg_to_float = not bool(options.get('disable_avg_cast', False))
+        except ValueError:
+            self.cast_avg_to_float = False
         self.ops.is_sql2005 = self.is_sql2005
         self.ops.is_sql2008 = self.is_sql2008

File sqlserver_ado/

 def _remove_order_limit_offset(sql):
     return _re_order_limit_offset.sub('',sql).split(None, 1)[1]
 class SQLCompiler(compiler.SQLCompiler):
     def resolve_columns(self, row, fields=()):
         # If the results are sliced, the resultset will have an initial 
         return row[:index_extra_select] + tuple(values)
+    def _fix_aggregates(self):
+        """
+        MSSQL doesn't match the behavior of the other backends on a few of
+        the aggregate functions; different return type behavior, different
+        function names, etc.
+        MSSQL's implementation of AVG maintains datatype without proding. To
+        match behavior of other django backends, it needs to not drop remainders.
+        E.g. AVG([1, 2]) needs to yield 1.5, not 1
+        """
+        if self.connection.cast_avg_to_float:
+            for alias, aggregate in self.query.aggregate_select.items():
+                if aggregate.sql_function == 'AVG':
+                    # Embed the CAST in the template on this query to
+                    # maintain multi-db support.
+                    self.query.aggregate_select[alias].sql_template = \
+                        '%(function)s(CAST(%(field)s AS FLOAT))'
     def as_sql(self, with_limits=True, with_col_aliases=False):
+        self._fix_aggregates()
         self._using_row_number = False
         # Get out of the way if we're not a select query or there's no limiting involved.
 class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
-    pass
+    def as_sql(self, qn=None):
+        self._fix_aggregates()
+        return super(SQLAggregateCompiler, self).as_sql(qn=qn)
 class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):