Commits

Mike Bayer committed 8968822

- correct the argument signature for GenericFunction to be more predictable

  • Participants
  • Parent commits f152102

Comments (0)

Files changed (2)

File lib/sqlalchemy/sql/functions.py

 
 from .. import types as sqltypes, schema
 from .expression import (
-    ClauseList, Function, _literal_as_binds, text, _type_from_args
+    ClauseList, Function, _literal_as_binds, literal_column, _type_from_args
     )
 from . import operators
 from .visitors import VisitableType
     __metaclass__ = _GenericMeta
 
     coerce_arguments = True
-    def __init__(self, type_=None, args=(), **kwargs):
-        args = [_literal_as_binds(c) for c in args]
+    def __init__(self, *args, **kwargs):
         self.packagenames = []
         self._bind = kwargs.get('bind', None)
         self.clause_expr = ClauseList(
                 operator=operators.comma_op,
                 group_contents=True, *args).self_group()
         self.type = sqltypes.to_instance(
-            type_ or getattr(self, 'type', None))
+            kwargs.pop("type_", None) or getattr(self, 'type', None))
 
 
 class next_value(GenericFunction):
 
     def __init__(self, *args, **kwargs):
         kwargs.setdefault('type_', _type_from_args(args))
-        GenericFunction.__init__(self, args=args, **kwargs)
+        GenericFunction.__init__(self, *args, **kwargs)
 
 class coalesce(ReturnTypeFromArgs):
     pass
 
 class concat(GenericFunction):
     type = sqltypes.String
-    def __init__(self, *args, **kwargs):
-        GenericFunction.__init__(self, args=args, **kwargs)
 
 class char_length(GenericFunction):
     type = sqltypes.Integer
 
     def __init__(self, arg, **kwargs):
-        GenericFunction.__init__(self, args=[arg], **kwargs)
+        GenericFunction.__init__(self, arg, **kwargs)
 
 class random(GenericFunction):
-    def __init__(self, *args, **kwargs):
-        kwargs.setdefault('type_', None)
-        GenericFunction.__init__(self, args=args, **kwargs)
+    pass
 
 class count(GenericFunction):
     """The ANSI COUNT aggregate function.  With no arguments, emits COUNT \*."""
 
     def __init__(self, expression=None, **kwargs):
         if expression is None:
-            expression = text('*')
-        GenericFunction.__init__(self, args=(expression,), **kwargs)
+            expression = literal_column('*')
+        GenericFunction.__init__(self, expression, **kwargs)
 
 class current_date(AnsiFunction):
     type = sqltypes.Date

File test/sql/test_functions.py

                 __return_type__ = sqltypes.Integer
 
                 def __init__(self, arg, **kwargs):
-                    GenericFunction.__init__(self, args=[arg], **kwargs)
+                    GenericFunction.__init__(self, arg, **kwargs)
 
             self.assert_compile(
                             fake_func('foo'),
         assert isinstance(func.mypackage.myfunc(), f1)
         assert isinstance(func.myotherpackage.myfunc(), f2)
 
+    def test_custom_args(self):
+        class myfunc(GenericFunction):
+            pass
+
+        self.assert_compile(
+            myfunc(1, 2, 3),
+            "myfunc(:param_1, :param_2, :param_3)"
+        )
+
     def test_namespacing_conflicts(self):
         self.assert_compile(func.text('foo'), 'text(:text_1)')