Commits

Ants Aasma  committed 1e970c5

Preliminary implementation for the evaluation framework

  • Participants
  • Parent commits 09e5ee1

Comments (0)

Files changed (2)

File lib/sqlalchemy/orm/evaluator.py

+from sqlalchemy.sql import operators, functions
+from sqlalchemy.sql import expression as sql
+from sqlalchemy.util import Set
+import operator
+
+class UnevaluatableError(Exception):
+    pass
+
+_straight_ops = Set([getattr(operators, op) for op in [
+    'add', 'mul', 'sub', 'div', 'mod', 'truediv', 'lt', 'le', 'ne', 'gt', 'ge', 'eq'
+]])
+
+
+_notimplemented_ops = Set([getattr(operators, op) for op in [
+    'like_op', 'notlike_op', 'ilike_op', 'notilike_op', 'between_op', 'in_op', 'notin_op',
+    'endswith_op', 'concat_op',
+]])
+
+class EvaluatorCompiler(object):
+    def process(self, clause):
+        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
+        if not meth:
+            raise UnevaluatableError("Cannot evaluate %s" % type(clause).__name__)
+        return meth(clause)
+    
+    def visit_grouping(self, clause):
+        return self.process(clause.element)
+    
+    def visit_null(self, clause):
+        return lambda obj: None
+    
+    def visit_column(self, clause):
+        if 'parententity' in clause._annotations:
+            key = clause._annotations['parententity']._get_col_to_prop(clause).key
+        else:
+            key = clause.key
+        get_corresponding_attr = operator.attrgetter(key)
+        return lambda obj: get_corresponding_attr(obj)
+    
+    def visit_clauselist(self, clause):
+        evaluators = map(self.process, clause.clauses)
+        if clause.operator is operators.or_:
+            def evaluate(obj):
+                has_null = False
+                for sub_evaluate in evaluators:
+                    value = sub_evaluate(obj)
+                    if value:
+                        return True
+                    has_null = has_null or value is None
+                if has_null:
+                    return None
+                return False
+        if clause.operator is operators.and_:
+            def evaluate(obj):
+                for sub_evaluate in evaluators:
+                    value = sub_evaluate(obj)
+                    if not value:
+                        if value is None:
+                            return None
+                        return False
+                return True
+        
+        return evaluate
+
+    def visit_binary(self, clause):
+        eval_left,eval_right = map(self.process, [clause.left, clause.right])
+        operator = clause.operator
+        if operator is operators.is_:
+            def evaluate(obj):
+                return eval_left(obj) == eval_right(obj)
+        if operator is operators.isnot:
+            def evaluate(obj):
+                return eval_left(obj) != eval_right(obj)
+        elif operator in _straight_ops:
+            def evaluate(obj):
+                left_val = eval_left(obj)
+                right_val = eval_right(obj)
+                if left_val is None or right_val is None:
+                    return None
+                return operator(eval_left(obj), eval_right(obj))
+        return evaluate
+
+    def visit_unary(self, clause):
+        eval_inner = self.process(clause.element)
+        if clause.operator is operators.inv:
+            def evaluate(obj):
+                value = eval_inner(obj)
+                if value is None:
+                    return None
+                return not value
+            return evaluate
+        raise UnevaluatableError("Cannot evaluate %s with operator %s" % (type(clause).__name__, clause.operator))
+    
+    def visit_bindparam(self, clause):
+        val = clause.value
+        return lambda obj: val

File test/orm/evaluator.py

+"""Evluating SQL expressions on ORM objects"""
+import testenv; testenv.configure_for_tests()
+from testlib import sa, testing
+from testlib.sa import Table, Column, String, Integer, select
+from testlib.sa.orm import mapper, create_session
+from testlib.testing import eq_
+from orm import _base
+
+from sqlalchemy import and_, or_, not_
+from sqlalchemy.orm import evaluator
+
+compiler = evaluator.EvaluatorCompiler()
+def eval_eq(clause, testcases=None):
+    evaluator = compiler.process(clause)
+    def testeval(obj=None, expected_result=None):
+        assert evaluator(obj) == expected_result, "%s != %r for %s with %r" % (evaluator(obj), expected_result, clause, obj)
+    if testcases:
+        for an_obj,result in testcases:
+            testeval(an_obj, result)
+    return testeval
+
+class EvaluateTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        Table('users', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('name', String))
+    
+    def setup_classes(self):
+        class User(_base.ComparableEntity):
+            pass
+    
+    @testing.resolve_artifact_names
+    def setup_mappers(self):
+        mapper(User, users)
+    
+    @testing.resolve_artifact_names
+    def test_compare_to_value(self):
+        eval_eq(User.name == 'foo', testcases=[
+            (User(name='foo'), True),
+            (User(name='bar'), False),
+            (User(name=None), None),
+        ])
+        
+        eval_eq(User.id < 5, testcases=[
+            (User(id=3), True),
+            (User(id=5), False),
+            (User(id=None), None),
+        ])
+    
+    @testing.resolve_artifact_names
+    def test_compare_to_none(self):
+        eval_eq(User.name == None, testcases=[
+            (User(name='foo'), False),
+            (User(name=None), True),
+        ])
+   
+    @testing.resolve_artifact_names
+    def test_boolean_ops(self):
+        eval_eq(and_(User.name == 'foo', User.id == 1), testcases=[
+            (User(id=1, name='foo'), True),
+            (User(id=2, name='foo'), False),
+            (User(id=1, name='bar'), False),
+            (User(id=2, name='bar'), False),
+            (User(id=1, name=None), None),
+        ])
+        
+        eval_eq(or_(User.name == 'foo', User.id == 1), testcases=[
+            (User(id=1, name='foo'), True),
+            (User(id=2, name='foo'), True),
+            (User(id=1, name='bar'), True),
+            (User(id=2, name='bar'), False),
+            (User(id=1, name=None), True),
+            (User(id=2, name=None), None),
+        ])
+        
+        eval_eq(not_(User.id == 1), testcases=[
+            (User(id=1), False),
+            (User(id=2), True),
+            (User(id=None), None),
+        ])
+
+    @testing.resolve_artifact_names
+    def test_null_propagation(self):
+        eval_eq((User.name == 'foo') == (User.id == 1), testcases=[
+            (User(id=1, name='foo'), True),
+            (User(id=2, name='foo'), False),
+            (User(id=1, name='bar'), False),
+            (User(id=2, name='bar'), True),
+            (User(id=None, name='foo'), None),
+            (User(id=None, name=None), None),
+        ])
+
+if __name__ == '__main__':
+    testenv.main()