Commits

Mike Bayer committed 417836b

- [feature] Enhanced GenericFunction and func.*
to allow for user-defined GenericFunction
subclasses to be available via the func.*
namespace automatically by classname,
optionally using a package name as well.

Comments (0)

Files changed (5)

     used by combining operators.custom_op() with
     UnaryExpression().
 
+  - [feature] Enhanced GenericFunction and func.*
+    to allow for user-defined GenericFunction
+    subclasses to be available via the func.*
+    namespace automatically by classname,
+    optionally using a package name as well.
+
   - [changed] Most classes in expression.sql
     are no longer preceded with an underscore,
     i.e. Label, SelectBase, Generative, CompareMixin.

doc/build/core/expression_api.rst

    :members:
    :show-inheritance:
 
+.. autoclass:: sqlalchemy.sql.functions.GenericFunction
+   :members:
+   :show-inheritance:
+
 .. autoclass:: Insert
    :members:
    :show-inheritance:

lib/sqlalchemy/sql/expression.py

     def __call__(self, *c, **kwargs):
         o = self.opts.copy()
         o.update(kwargs)
-        if len(self.__names) == 1:
-            func = getattr(functions, self.__names[-1].lower(), None)
-            if func is not None and \
-                    isinstance(func, type) and \
-                    issubclass(func, Function):
-                return func(*c, **o)
+
+        tokens = len(self.__names)
+
+        if tokens == 2:
+            package, fname = self.__names
+        elif tokens == 1:
+            package, fname = "_default", self.__names[0]
+        else:
+            package = None
+
+        if package is not None and \
+            package in functions._registry and \
+            fname in functions._registry[package]:
+            func = functions._registry[package][fname]
+            return func(*c, **o)
 
         return Function(self.__names[-1],
                         packagenames=self.__names[0:-1], *c, **o)
                     self.get_children()]))
 
 class FunctionElement(Executable, ColumnElement, FromClause):
-    """Base for SQL function-oriented constructs."""
+    """Base for SQL function-oriented constructs.
+
+    See also:
+
+    :class:`.Function` - named SQL function.
+
+    :data:`.func` - namespace which produces registered or ad-hoc
+    :class:`.Function` instances.
+
+    :class:`.GenericFunction` - allows creation of registered function
+    types.
+
+    """
 
     packagenames = ()
 
     See the superclass :class:`.FunctionElement` for a description
     of public methods.
 
+    See also:
+
+    See also:
+
+    :data:`.func` - namespace which produces registered or ad-hoc
+    :class:`.Function` instances.
+
+    :class:`.GenericFunction` - allows creation of registered function
+    types.
+
     """
 
     __visit_name__ = 'function'

lib/sqlalchemy/sql/functions.py

     )
 from . import operators
 from .visitors import VisitableType
+from .. import util
+
+_registry = util.defaultdict(dict)
 
 class _GenericMeta(VisitableType):
-    def __call__(self, *args, **kwargs):
-        args = [_literal_as_binds(c) for c in args]
-        return type.__call__(self, *args, **kwargs)
+    def __init__(cls, clsname, bases, clsdict):
+        cls.name = name = clsdict.get('name', clsname)
+        package = clsdict.pop('package', '_default')
+        # legacy
+        if '__return_type__' in clsdict:
+            cls.type = clsdict['__return_type__']
+        reg = _registry[package]
+        reg[name] = cls
+        super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
+
+    def __call__(cls, *args, **kwargs):
+        if cls.coerce_arguments:
+            args = [_literal_as_binds(c) for c in args]
+        return type.__call__(cls, *args, **kwargs)
 
 class GenericFunction(Function):
+    """Define a 'generic' function.
+
+    A generic function is a pre-established :class:`.Function`
+    class that is instantiated automatically when called
+    by name from the :data:`.func` attribute.    Note that
+    calling any name from :data:`.func` has the effect that
+    a new :class:`.Function` instance is created automatically,
+    given that name.  The primary use case for defining
+    a :class:`.GenericFunction` class is so that a function
+    of a particular name may be given a fixed return type.
+    It can also include custom argument parsing schemes as well
+    as additional methods.
+
+    Subclasses of :class:`.GenericFunction` are automatically
+    registered under the name of the class.  For
+    example, a user-defined function ``as_utc()`` would
+    be available immediately::
+
+        from sqlalchemy.sql.functions import GenericFunction
+        from sqlalchemy.types import DateTime
+
+        class as_utc(GenericFunction):
+            type = DateTime
+
+        print select([func.as_utc()])
+
+    User-defined generic functions can be organized into
+    packages by specifying the "package" attribute when defining
+    :class:`.GenericFunction`.   Third party libraries
+    containing many functions may want to use this in order
+    to avoid name conflicts with other systems.   For example,
+    if our ``as_utc()`` function were part of a package
+    "time"::
+
+        class as_utc(GenericFunction):
+            type = DateTime
+            package = "time"
+
+    The above function would be available from :data:`.func`
+    using the package name ``time``::
+
+        print select([func.time.as_utc()])
+
+    .. versionadded:: 0.8 :class:`.GenericFunction` now supports
+       automatic registration of new functions as well as package
+       support.
+
+    .. versionchanged:: 0.8 The attribute name ``type`` is used
+       to specify the function's return type at the class level.
+       Previously, the name ``__return_type__`` was used.  This
+       name is still recognized for backwards-compatibility.
+
+    """
     __metaclass__ = _GenericMeta
 
+    coerce_arguments = True
     def __init__(self, type_=None, args=(), **kwargs):
+        args = [_literal_as_binds(c) for c in args]
         self.packagenames = []
-        self.name = self.__class__.__name__
         self._bind = kwargs.get('bind', None)
         self.clause_expr = ClauseList(
                 operator=operators.comma_op,
                 group_contents=True, *args).self_group()
         self.type = sqltypes.to_instance(
-            type_ or getattr(self, '__return_type__', None))
+            type_ or getattr(self, 'type', None))
 
 
-class next_value(Function):
+class next_value(GenericFunction):
     """Represent the 'next value', given a :class:`.Sequence`
     as it's single argument.
 
     """
     type = sqltypes.Integer()
     name = "next_value"
+    coerce_arguments = False
 
     def __init__(self, seq, **kw):
         assert isinstance(seq, schema.Sequence), \
 
 
 class now(GenericFunction):
-    __return_type__ = sqltypes.DateTime
+    type = sqltypes.DateTime
 
 class concat(GenericFunction):
-    __return_type__ = sqltypes.String
+    type = sqltypes.String
     def __init__(self, *args, **kwargs):
         GenericFunction.__init__(self, args=args, **kwargs)
 
 class char_length(GenericFunction):
-    __return_type__ = sqltypes.Integer
+    type = sqltypes.Integer
 
     def __init__(self, arg, **kwargs):
         GenericFunction.__init__(self, args=[arg], **kwargs)
 class count(GenericFunction):
     """The ANSI COUNT aggregate function.  With no arguments, emits COUNT \*."""
 
-    __return_type__ = sqltypes.Integer
+    type = sqltypes.Integer
 
     def __init__(self, expression=None, **kwargs):
         if expression is None:
         GenericFunction.__init__(self, args=(expression,), **kwargs)
 
 class current_date(AnsiFunction):
-    __return_type__ = sqltypes.Date
+    type = sqltypes.Date
 
 class current_time(AnsiFunction):
-    __return_type__ = sqltypes.Time
+    type = sqltypes.Time
 
 class current_timestamp(AnsiFunction):
-    __return_type__ = sqltypes.DateTime
+    type = sqltypes.DateTime
 
 class current_user(AnsiFunction):
-    __return_type__ = sqltypes.String
+    type = sqltypes.String
 
 class localtime(AnsiFunction):
-    __return_type__ = sqltypes.DateTime
+    type = sqltypes.DateTime
 
 class localtimestamp(AnsiFunction):
-    __return_type__ = sqltypes.DateTime
+    type = sqltypes.DateTime
 
 class session_user(AnsiFunction):
-    __return_type__ = sqltypes.String
+    type = sqltypes.String
 
 class sysdate(AnsiFunction):
-    __return_type__ = sqltypes.DateTime
+    type = sqltypes.DateTime
 
 class user(AnsiFunction):
-    __return_type__ = sqltypes.String
+    type = sqltypes.String
 

test/sql/test_functions.py

 import datetime
 from sqlalchemy import *
 from sqlalchemy.sql import table, column
-from sqlalchemy import databases, sql, util
+from sqlalchemy import sql, util
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
-from sqlalchemy.engine import default
 from test.lib.engines import all_dialects
 from sqlalchemy import types as sqltypes
-from test.lib import *
+from sqlalchemy.sql import functions
 from sqlalchemy.sql.functions import GenericFunction
-from test.lib.testing import eq_
 from sqlalchemy.util.compat import decimal
-from test.lib import testing
-from sqlalchemy.databases import *
+from test.lib import testing, fixtures, AssertsCompiledSQL, engines
+from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle
 
 
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
+    def tear_down(self):
+        functions._registry.clear()
+
     def test_compile(self):
-        for dialect in all_dialects(exclude=('sybase', 'access', 'informix', 'maxdb')):
+        for dialect in all_dialects(exclude=('sybase', 'access',
+                                                'informix', 'maxdb')):
             bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
-            self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect)
+            self.assert_compile(func.current_timestamp(),
+                                        "CURRENT_TIMESTAMP", dialect=dialect)
             self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
-            if isinstance(dialect, (firebird.dialect, maxdb.dialect)):
-                self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect)
+            if dialect.name in ('firebird', 'maxdb'):
+                self.assert_compile(func.nosuchfunction(),
+                                            "nosuchfunction", dialect=dialect)
             else:
-                self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect)
+                self.assert_compile(func.nosuchfunction(),
+                                        "nosuchfunction()", dialect=dialect)
 
             # test generic function compile
             class fake_func(GenericFunction):
             self.assert_compile(
                             fake_func('foo'),
                             "fake_func(%s)" %
-                            bindtemplate % {'name':'param_1', 'position':1},
+                            bindtemplate % {'name': 'param_1', 'position': 1},
                             dialect=dialect)
 
     def test_use_labels(self):
         ]:
             self.assert_compile(func.random(), ret, dialect=dialect)
 
+    def test_custom_default_namespace(self):
+        class myfunc(GenericFunction):
+            pass
+
+        assert isinstance(func.myfunc(), myfunc)
+
+    def test_custom_type(self):
+        class myfunc(GenericFunction):
+            type = DateTime
+
+        assert isinstance(func.myfunc().type, DateTime)
+
+    def test_custom_legacy_type(self):
+        # in case someone was using this system
+        class myfunc(GenericFunction):
+            __return_type__ = DateTime
+
+        assert isinstance(func.myfunc().type, DateTime)
+
+    def test_custom_w_custom_name(self):
+        class myfunc(GenericFunction):
+            name = "notmyfunc"
+
+        assert isinstance(func.notmyfunc(), myfunc)
+        assert not isinstance(func.myfunc(), myfunc)
+
+    def test_custom_package_namespace(self):
+        def cls1(pk_name):
+            class myfunc(GenericFunction):
+                package = pk_name
+            return myfunc
+
+        f1 = cls1("mypackage")
+        f2 = cls1("myotherpackage")
+
+        assert isinstance(func.mypackage.myfunc(), f1)
+        assert isinstance(func.myotherpackage.myfunc(), f2)
+
     def test_namespacing_conflicts(self):
         self.assert_compile(func.text('foo'), 'text(:text_1)')
 
                             ((datetime.date(2007, 10, 5),
                                 datetime.date(2005, 10, 15)), sqltypes.Date),
                             ((3, 5), sqltypes.Integer),
-                            ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric),
+                            ((decimal.Decimal(3), decimal.Decimal(5)),
+                                                        sqltypes.Numeric),
                             (("foo", "bar"), sqltypes.String),
                             ((datetime.datetime(2007, 10, 5, 8, 3, 34),
-                                datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime)
+                                datetime.datetime(2005, 10, 15, 14, 45, 33)),
+                                                        sqltypes.DateTime)
                         ]:
-                assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_)
+                assert isinstance(fn(*args).type, type_), \
+                            "%s / %s" % (fn(), type_)
 
         assert isinstance(func.concat("foo", "bar").type, sqltypes.String)
 
         )
 
         # test an expression with a function
-        self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid,
-            "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid")
+        self.assert_compile(func.lala(3, 4, literal("five"),
+                                        table1.c.myid) * table2.c.otherid,
+            "lala(:lala_1, :lala_2, :param_1, mytable.myid) * "
+            "myothertable.otherid")
 
         # test it in a SELECT
         self.assert_compile(select([func.count(table1.c.myid)]),
         self.assert_compile(select([func.foo.bar.lala(table1.c.myid)]),
             "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable")
 
-        # test the bind parameter name with a "dotted" function name is only the name
-        # (limits the length of the bind param name)
+        # test the bind parameter name with a "dotted" function name is
+        # only the name (limits the length of the bind param name)
         self.assert_compile(select([func.foo.bar.lala(12)]),
             "SELECT foo.bar.lala(:lala_2) AS lala_1")
 
         self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)")
 
         # test None becomes NULL
-        self.assert_compile(func.my_func(1,2,None,3),
+        self.assert_compile(func.my_func(1, 2, None, 3),
                         "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)")
 
         # test pickling
         self.assert_compile(
-                util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))),
+                util.pickle.loads(util.pickle.dumps(
+                                        func.my_func(1, 2, None, 3))),
                 "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)")
 
-        # assert func raises AttributeError for __bases__ attribute, since its not a class
-        # fixes pydoc
+        # assert func raises AttributeError for __bases__ attribute, since
+        # its not a class fixes pydoc
         try:
             func.__bases__
             assert False
             "FROM users, (SELECT q, z, r "
             "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r "
             "FROM calculate(:x_2, :y_2)) AS c2 "
-            "WHERE users.id BETWEEN c1.z AND c2.z"
-            , checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5})
+            "WHERE users.id BETWEEN c1.z AND c2.z",
+            checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5})
 
 
 class ExecuteTest(fixtures.TestBase):
         eq_(f._execution_options, {})
 
         f = f.execution_options(foo='bar')
-        eq_(f._execution_options, {'foo':'bar'})
+        eq_(f._execution_options, {'foo': 'bar'})
         s = f.select()
-        eq_(s._execution_options, {'foo':'bar'})
+        eq_(s._execution_options, {'foo': 'bar'})
 
         ret = testing.db.execute(func.now().execution_options(foo='bar'))
-        eq_(ret.context.execution_options, {'foo':'bar'})
+        eq_(ret.context.execution_options, {'foo': 'bar'})
         ret.close()
 
 
 
         meta = MetaData(testing.db)
         t = Table('t1', meta,
-            Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True),
+            Column('id', Integer, Sequence('t1idseq', optional=True),
+                                                            primary_key=True),
             Column('value', Integer)
         )
         t2 = Table('t2', meta,
-            Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True),
+            Column('id', Integer, Sequence('t2idseq', optional=True),
+                                                            primary_key=True),
             Column('value', Integer, default=7),
             Column('stuff', String(20), onupdate="thisisstuff")
         )
 
             r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
             id = r.inserted_primary_key[0]
-            assert t.select(t.c.id==id).execute().first()['value'] == 9
-            t.update(values={t.c.value:func.length("asdf")}).execute()
+            assert t.select(t.c.id == id).execute().first()['value'] == 9
+            t.update(values={t.c.value: func.length("asdf")}).execute()
             assert t.select().execute().first()['value'] == 4
             print "--------------------------"
             t2.insert().execute()
             t2.insert(values=dict(value=func.length("one"))).execute()
-            t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi")
+            t2.insert(values=dict(value=func.length("asfda") + -19)).\
+                            execute(stuff="hi")
 
             res = exec_sorted(select([t2.c.value, t2.c.stuff]))
             eq_(res, [(-14, 'hi'), (3, None), (7, None)])
 
-            t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff")
+            t2.update(values=dict(value=func.length("asdsafasd"))).\
+                        execute(stuff="some stuff")
             assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == \
-                        [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")]
+                        [(9, "some stuff"), (9, "some stuff"),
+                            (9, "some stuff")]
 
             t2.delete().execute()
 
             assert t2.select().execute().first()['value'] == 11
 
             t2.update(values=dict(value=func.length("asfda"))).execute()
-            assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff")
+            eq_(
+                select([t2.c.value, t2.c.stuff]).execute().first(),
+                (5, "thisisstuff")
+            )
 
-            t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute()
+            t2.update(values={t2.c.value: func.length("asfdaasdf"),
+                                        t2.c.stuff: "foo"}).execute()
             print "HI", select([t2.c.value, t2.c.stuff]).execute().first()
-            assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo")
+            eq_(select([t2.c.value, t2.c.stuff]).execute().first(),
+                    (9, "foo")
+                )
         finally:
             meta.drop_all()
 
         x = func.current_date(bind=testing.db).execute().scalar()
         y = func.current_date(bind=testing.db).select().execute().scalar()
         z = func.current_date(bind=testing.db).scalar()
-        w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).scalar()
+        w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).\
+                    scalar()
 
-        # construct a column-based FROM object out of a function, like in [ticket:172]
-        s = select([sql.column('date', type_=DateTime)], from_obj=[func.current_date(bind=testing.db)])
+        # construct a column-based FROM object out of a function,
+        # like in [ticket:172]
+        s = select([sql.column('date', type_=DateTime)],
+                            from_obj=[func.current_date(bind=testing.db)])
         q = s.execute().first()[s.c.date]
         r = s.alias('datequery').select().scalar()
 
         try:
             table.insert().execute(
                 {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10),
-                 'd': datetime.date(2010, 5, 1) })
+                 'd': datetime.date(2010, 5, 1)})
             rs = select([extract('year', table.c.dt),
                          extract('month', table.c.d)]).execute()
             row = rs.first()