Commits

Mike Bayer committed 2cbc04b

_adapt_expression() moves fully to _DefaultColumnComparator which resumes
its original role as stateful, forms the basis of TypeEngine.Comparator. lots
of code goes back mostly as it was just with cleaner typing behavior, such
as simple flow in _binary_operate now.

Comments (0)

Files changed (5)

lib/sqlalchemy/dialects/postgresql/base.py

             affinity = None
 
         casts = {
-                    sqltypes.Date:'date',
-                    sqltypes.DateTime:'timestamp',
-                    sqltypes.Interval:'interval', sqltypes.Time:'time'
+                    sqltypes.Date: 'date',
+                    sqltypes.DateTime: 'timestamp',
+                    sqltypes.Interval: 'interval',
+                    sqltypes.Time: 'time'
                 }
         cast = casts.get(affinity, None)
         if isinstance(extract.expr, sql.ColumnElement) and cast is not None:

lib/sqlalchemy/sql/expression.py

         return self
 
 
-class _DefaultColumnComparator(object):
+class _DefaultColumnComparator(operators.ColumnOperators):
     """Defines comparison and math operations.
 
     See :class:`.ColumnOperators` and :class:`.Operators` for descriptions
 
     """
 
+    @util.memoized_property
+    def type(self):
+        return self.expr.type
+
+    def operate(self, op, *other, **kwargs):
+        o = self.operators[op.__name__]
+        return o[0](self, self.expr, op, *(other + o[1:]), **kwargs)
+
+    def reverse_operate(self, op, other, **kwargs):
+        o = self.operators[op.__name__]
+        return o[0](self, self.expr, op, other, reverse=True, *o[1:], **kwargs)
+
+    def _adapt_expression(self, op, other_comparator):
+        """evaluate the return type of <self> <op> <othertype>,
+        and apply any adaptations to the given operator.
+
+        This method determines the type of a resulting binary expression
+        given two source types and an operator.   For example, two
+        :class:`.Column` objects, both of the type :class:`.Integer`, will
+        produce a :class:`.BinaryExpression` that also has the type
+        :class:`.Integer` when compared via the addition (``+``) operator.
+        However, using the addition operator with an :class:`.Integer`
+        and a :class:`.Date` object will produce a :class:`.Date`, assuming
+        "days delta" behavior by the database (in reality, most databases
+        other than Postgresql don't accept this particular operation).
+
+        The method returns a tuple of the form <operator>, <type>.
+        The resulting operator and type will be those applied to the
+        resulting :class:`.BinaryExpression` as the final operator and the
+        right-hand side of the expression.
+
+        Note that only a subset of operators make usage of
+        :meth:`._adapt_expression`,
+        including math operators and user-defined operators, but not
+        boolean comparison or special SQL keywords like MATCH or BETWEEN.
+
+        """
+        return op, other_comparator.type
+
     def _boolean_compare(self, expr, op, obj, negate=None, reverse=False,
                         **kwargs
         ):
                             type_=sqltypes.BOOLEANTYPE,
                             negate=negate, modifiers=kwargs)
 
-    def _binary_operate(self, expr, op, obj, result_type, reverse=False):
+    def _binary_operate(self, expr, op, obj, reverse=False):
         obj = self._check_literal(expr, op, obj)
 
         if reverse:
         else:
             left, right = expr, obj
 
+        op, result_type = left.comparator._adapt_expression(op, right.comparator)
+
         return BinaryExpression(left, right, op, type_=result_type)
 
     def _scalar(self, expr, op, fn, **kw):
             expr,
             operators.like_op,
             literal_column("'%'", type_=sqltypes.String).__radd__(
-                                self._check_literal(expr, operators.like_op, other)
+                                self._check_literal(expr,
+                                        operators.like_op, other)
                             ),
             escape=escape)
 
         "neg": (_neg_impl,),
     }
 
-    def operate(self, expr, op, *other, **kwargs):
-        o = self.operators[op.__name__]
-        return o[0](self, expr, op, *(other + o[1:]), **kwargs)
-
-    def reverse_operate(self, expr, op, other, **kwargs):
-        o = self.operators[op.__name__]
-        return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs)
 
     def _check_literal(self, expr, operator, other):
-        if isinstance(other, BindParameter) and \
-            isinstance(other.type, sqltypes.NullType):
-            # TODO: perhaps we should not mutate the incoming bindparam()
-            # here and instead make a copy of it.  this might
-            # be the only place that we're mutating an incoming construct.
-            other.type = expr.type
+        if isinstance(other, (ColumnElement, TextClause)):
+            if isinstance(other, BindParameter) and \
+                isinstance(other.type, sqltypes.NullType):
+                # TODO: perhaps we should not mutate the incoming
+                # bindparam() here and instead make a copy of it.
+                # this might be the only place that we're mutating
+                # an incoming construct.
+                other.type = expr.type
             return other
         elif hasattr(other, '__clause_element__'):
             other = other.__clause_element__()
         else:
             return other
 
-_DEFAULT_COMPARATOR = _DefaultColumnComparator()
-
 
 class ColumnElement(ClauseElement, ColumnOperators):
     """Represent an element that is usable within the "column clause" portion
     def comparator(self):
         return self.type.comparator_factory(self)
 
-    #def _assert_comparator(self):
-    #    assert self.comparator.expr is self
-
     def __getattr__(self, key):
-        #self._assert_comparator()
         try:
             return getattr(self.comparator, key)
         except AttributeError:
             )
 
     def operate(self, op, *other, **kwargs):
-        #self._assert_comparator()
         return op(self.comparator, *other, **kwargs)
 
     def reverse_operate(self, op, other, **kwargs):
-        #self._assert_comparator()
         return op(other, self.comparator, **kwargs)
 
     def _bind_param(self, operator, obj):
         else:
             return sqltypes.NULLTYPE
 
+    @property
+    def comparator(self):
+        return self.type.comparator_factory(self)
+
     def self_group(self, against=None):
         if against is operators.in_op:
             return Grouping(self)

lib/sqlalchemy/types.py

 For more information see the SQLAlchemy documentation on types.
 
 """
-__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
-            'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text',
+__all__ = ['TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
+            'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text',
             'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME',
             'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', 'SMALLINT',
             'INTEGER', 'DATE', 'TIME', 'String', 'Integer', 'SmallInteger',
             'BigInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time',
             'LargeBinary', 'Binary', 'Boolean', 'Unicode', 'Concatenable',
-            'UnicodeText','PickleType', 'Interval', 'Enum' ]
+            'UnicodeText', 'PickleType', 'Interval', 'Enum']
 
 import datetime as dt
 import codecs
 
 from . import exc, schema, util, processors, events, event
 from .sql import operators
-from .sql.expression import _DEFAULT_COMPARATOR
+from .sql.expression import _DefaultColumnComparator
 from .util import pickle
 from .util.compat import decimal
 from .sql.visitors import Visitable
 class TypeEngine(AbstractType):
     """Base for built-in types."""
 
-    class Comparator(operators.ColumnOperators):
+    class Comparator(_DefaultColumnComparator):
         """Base class for custom comparison operations defined at the
         type level.  See :attr:`.TypeEngine.comparator_factory`.
 
         def __reduce__(self):
             return _reconstitute_comparator, (self.expr, )
 
-        def operate(self, op, *other, **kwargs):
-            if len(other) == 1:
-                obj = other[0]
-                obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, obj)
-                op, adapt_type = self.expr.type._adapt_expression(op,
-                    obj.type)
-                kwargs['result_type'] = adapt_type
-
-            return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
-
-        def reverse_operate(self, op, other, **kwargs):
-
-            obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, other)
-            op, adapt_type = obj.type._adapt_expression(op, self.expr.type)
-            kwargs['result_type'] = adapt_type
-
-            return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, obj,
-                                                **kwargs)
 
     comparator_factory = Comparator
     """A :class:`.TypeEngine.Comparator` class which will apply
         >>> (c1 == c2).type
         Boolean()
 
-    The propagation of :class:`.TypeEngine.Comparator` throughout an expression
-    will follow with how the :class:`.TypeEngine` itself is propagated.  To
-    customize the behavior of most operators in this regard, see the
-    :meth:`._adapt_expression` method.
-
     .. versionadded:: 0.8  The expression system was reworked to support
       user-defined comparator objects specified at the type level.
 
         .. versionadded:: 0.7.2
 
         """
-        return Variant(self, {dialect_name:type_})
-
-    def _adapt_expression(self, op, othertype):
-        """evaluate the return type of <self> <op> <othertype>,
-        and apply any adaptations to the given operator.
-
-        This method determines the type of a resulting binary expression
-        given two source types and an operator.   For example, two
-        :class:`.Column` objects, both of the type :class:`.Integer`, will
-        produce a :class:`.BinaryExpression` that also has the type
-        :class:`.Integer` when compared via the addition (``+``) operator.
-        However, using the addition operator with an :class:`.Integer`
-        and a :class:`.Date` object will produce a :class:`.Date`, assuming
-        "days delta" behavior by the database (in reality, most databases
-        other than Postgresql don't accept this particular operation).
-
-        The method returns a tuple of the form <operator>, <type>.
-        The resulting operator and type will be those applied to the
-        resulting :class:`.BinaryExpression` as the final operator and the
-        right-hand side of the expression.
-
-        Note that only a subset of operators make usage of
-        :meth:`._adapt_expression`,
-        including math operators and user-defined operators, but not
-        boolean comparison or special SQL keywords like MATCH or BETWEEN.
-
-        """
-        return op, self
+        return Variant(self, {dialect_name: type_})
 
     @util.memoized_property
     def _type_affinity(self):
                 impl = self.adapt(type(self))
             # this can't be self, else we create a cycle
             assert impl is not self
-            dialect._type_memos[self] = d = {'impl':impl}
+            dialect._type_memos[self] = d = {'impl': impl}
             return d
 
     def _gen_dialect_impl(self, dialect):
     """
     __visit_name__ = "user_defined"
 
-    def _adapt_expression(self, op, othertype):
-        """evaluate the return type of <self> <op> <othertype>,
-        and apply any adaptations to the given operator.
-
-        """
-        return self.adapt_operator(op), self
-
-    def adapt_operator(self, op):
-        """A hook which allows the given operator to be adapted
-        to something new.
-
-        See also UserDefinedType._adapt_expression(), an as-yet-
-        semi-public method with greater capability in this regard.
-
-        """
-        return op
+    class Comparator(TypeEngine.Comparator):
+        def _adapt_expression(self, op, other_comparator):
+            if hasattr(self.type, 'adapt_operator'):
+                util.warn_deprecated(
+                    "UserDefinedType.adapt_operator is deprecated.  Create "
+                     "a UserDefinedType.Comparator subclass instead which "
+                     "generates the desired expression constructs, given a "
+                     "particular operator."
+                    )
+                return self.type.adapt_operator(op), self.type
+            else:
+                return op, self.type
+
+    comparator_factory = Comparator
+
 
 class TypeDecorator(TypeEngine):
     """Allows the creation of types which add additional functionality
         """
         return self.impl.compare_values(x, y)
 
-    def _adapt_expression(self, op, othertype):
-        op, typ = self.impl._adapt_expression(op, othertype)
-        typ = to_instance(typ)
-        if typ._compare_type_affinity(self.impl):
-            return op, self
-        else:
-            return op, typ
 
 class Variant(TypeDecorator):
     """A wrapping type that selects among a variety of
     return typeobj.adapt(impltype)
 
 
-
-
 class NullType(TypeEngine):
     """An unknown type.
 
     """
     __visit_name__ = 'null'
 
-    def _adapt_expression(self, op, othertype):
-        if isinstance(othertype, NullType) or not operators.is_commutative(op):
-            return op, self
-        else:
-            return othertype._adapt_expression(op, self)
+    class Comparator(TypeEngine.Comparator):
+        def _adapt_expression(self, op, other_comparator):
+            if isinstance(other_comparator, NullType.Comparator) or \
+                not operators.is_commutative(op):
+                return op, self.expr.type
+            else:
+                return other_comparator._adapt_expression(op, self)
+    comparator_factory = Comparator
 
 NullTypeEngine = NullType
 
     """A mixin that marks a type as supporting 'concatenation',
     typically strings."""
 
-    def _adapt_expression(self, op, othertype):
-        if op is operators.add and issubclass(othertype._type_affinity,
-                (Concatenable, NullType)):
-            return operators.concat_op, self
-        else:
-            return op, self
+    class Comparator(TypeEngine.Comparator):
+        def _adapt_expression(self, op, other_comparator):
+            if op is operators.add and isinstance(other_comparator,
+                    (Concatenable.Comparator, NullType.Comparator)):
+                return operators.concat_op, self.expr.type
+            else:
+                return op, self.expr.type
+
+    comparator_factory = Comparator
+
 
 class _DateAffinity(object):
     """Mixin date/time specific expression adaptations.
     def _expression_adaptations(self):
         raise NotImplementedError()
 
-    _blank_dict = util.immutabledict()
-    def _adapt_expression(self, op, othertype):
-        othertype = othertype._type_affinity
-        return op, \
-                self._expression_adaptations.get(op, self._blank_dict).\
-                get(othertype, NULLTYPE)
+    class Comparator(TypeEngine.Comparator):
+        _blank_dict = util.immutabledict()
+        def _adapt_expression(self, op, other_comparator):
+            othertype = other_comparator.type._type_affinity
+            return op, \
+                    self.type._expression_adaptations.get(op, self._blank_dict).\
+                    get(othertype, NULLTYPE)
+    comparator_factory = Comparator
 
 class String(Concatenable, TypeEngine):
     """The base for all string and character types.

test/sql/test_operators.py

 class DefaultColumnComparatorTest(fixtures.TestBase):
 
     def _do_scalar_test(self, operator, compare_to):
-        cc = _DefaultColumnComparator()
         left = column('left')
-        assert cc.operate(left, operator).compare(
+        assert left.comparator.operate(operator).compare(
             compare_to(left)
         )
 
     def _do_operate_test(self, operator):
-        cc = _DefaultColumnComparator()
         left = column('left')
         right = column('right')
 
-        assert cc.operate(left, operator, right, result_type=Integer).compare(
+        assert left.comparator.operate(operator, right).compare(
             BinaryExpression(left, right, operator)
         )
 
         self._do_operate_test(operators.add)
 
     def test_in(self):
-        cc = _DefaultColumnComparator()
         left = column('left')
-        assert cc.operate(left, operators.in_op, [1, 2, 3]).compare(
+        assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare(
                 BinaryExpression(
                     left,
                     Grouping(ClauseList(
             )
 
     def test_collate(self):
-        cc = _DefaultColumnComparator()
         left = column('left')
         right = "some collation"
-        cc.operate(left, operators.collate, right).compare(
+        left.comparator.operate(operators.collate, right).compare(
             collate(left, right)
         )
 
         self._assert_add_override(6 - c1)
 
     def test_binary_multi_propagate(self):
-        c1 = Column('foo', self._add_override_factory(True))
+        c1 = Column('foo', self._add_override_factory())
         self._assert_add_override((c1 - 6) + 5)
 
-    def test_no_binary_multi_propagate_wo_adapt(self):
-        c1 = Column('foo', self._add_override_factory())
-        self._assert_not_add_override((c1 - 6) + 5)
-
     def test_no_boolean_propagate(self):
         c1 = Column('foo', self._add_override_factory())
         self._assert_not_add_override(c1 == 56)
         )
 
 class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
 
         class MyInteger(Integer):
             class comparator_factory(TypeEngine.Comparator):
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
 
         return MyInteger
 
 
 class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
 
         class MyInteger(TypeDecorator):
             impl = Integer
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
 
         return MyInteger
 
 
 class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
         class MyInteger(Integer):
             class comparator_factory(TypeEngine.Comparator):
                 def __init__(self, expr):
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
 
         class MyDecInteger(TypeDecorator):
             impl = MyInteger
         return MyDecInteger
 
 class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
         class MyInteger(Integer):
             class comparator_factory(TypeEngine.Comparator):
                 def __init__(self, expr):
 
                 def foob(self, other):
                     return self.expr.op("foob")(other)
-
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
-
         return MyInteger
 
     def _assert_add_override(self, expr):
     def _assert_not_add_override(self, expr):
         assert not hasattr(expr, "foob")
 
-    def test_no_binary_multi_propagate_wo_adapt(self):
-        pass

test/sql/test_types.py

         eq_(expr.right.type.__class__, CHAR)
 
 
+    @testing.uses_deprecated
     @testing.fails_on('firebird', 'Data type unknown on the parameter')
     @testing.fails_on('mssql', 'int is unsigned ?  not clear')
     def test_operator_adapt(self):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.