Source

sqlalchemy / lib / sqlalchemy / testing / assertions.py

from __future__ import absolute_import

from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, create_engine
from sqlalchemy import exc as sa_exc
from sqlalchemy.util import decorator
from sqlalchemy import types as sqltypes, schema
import warnings
import re
from .warnings import resetwarnings
from .exclusions import db_spec, _is_excluded
from . import assertsql
from . import config
import itertools
from .util import fail
import contextlib


def emits_warning(*messages):
    """Decorate a test so that expected SAWarnings are ignored.

    Called without arguments, every ``SAWarning`` is ignored while the test
    runs.  Called with one or more strings, each string is installed as the
    ``message`` pattern of an "ignore" filter for ``SAWarning``;
    ``warnings.filterwarnings()`` matches that pattern against the start of
    the warning text.

    ``SAPendingDeprecationWarning`` is always ignored, and
    ``resetwarnings()`` is called once the test returns or raises.
    """
    # TODO: it would be nice to assert that a named warning was
    # emitted. should work with some monkeypatching of warnings,
    # and may work on non-CPython if they keep to the spirit of
    # warnings.showwarning's docstring.
    # - update: jython looks ok, it uses cpython's module

    @decorator
    def decorate(fn, *args, **kw):
        # todo: should probably be strict about this, too
        warnings.filterwarnings(
            action='ignore',
            category=sa_exc.SAPendingDeprecationWarning)

        if messages:
            # one filter per requested message pattern, in the given order
            for text in messages:
                warnings.filterwarnings(
                    action='ignore',
                    message=text,
                    category=sa_exc.SAWarning)
        else:
            # no patterns given: blanket-ignore the whole category
            warnings.filterwarnings(
                action='ignore',
                category=sa_exc.SAWarning)

        try:
            return fn(*args, **kw)
        finally:
            resetwarnings()
    return decorate


def emits_warning_on(db, *warnings):
    """Apply :func:`emits_warning` only for a particular database.

    ``db`` is either a string, in which case the predicate built by
    ``db_spec(db)`` is evaluated against ``config.db``, or a sequence that
    is unpacked into ``_is_excluded()``.  When that check is true the test
    runs wrapped in ``emits_warning(*warnings)``; otherwise it runs
    unmodified.

    The remaining positional arguments are forwarded to
    :func:`emits_warning` untouched: none squelches all SAWarning failures,
    strings are used as message patterns.
    """
    spec = db_spec(db)

    @decorator
    def decorate(fn, *args, **kw):
        if isinstance(db, basestring):
            applies = spec(config.db)
        else:
            applies = _is_excluded(*db)

        # note: ``warnings`` is the varargs tuple here, not the stdlib module
        target = emits_warning(*warnings)(fn) if applies else fn
        return target(*args, **kw)
    return decorate


def uses_deprecated(*messages):
    """Decorate a test so that deprecation warnings do not fail it.

    Called without arguments, every ``SADeprecationWarning`` is ignored.
    Called with one or more strings, each string becomes the ``message``
    pattern of an "ignore" filter for ``SADeprecationWarning``;
    ``warnings.filterwarnings()`` matches it against the start of the
    warning text.

    A string beginning with ``//`` is shorthand: the two slashes are
    replaced by ``'Call to deprecated function '`` so that the pattern
    lines up with the standard wording emitted by the
    ``sqlalchemy.util.deprecated`` decorator.

    ``SAPendingDeprecationWarning`` is always ignored, and
    ``resetwarnings()`` is called once the test returns or raises.
    """

    @decorator
    def decorate(fn, *args, **kw):
        # expand the '//name' shorthand up front, before touching the
        # warnings registry
        patterns = []
        for text in messages:
            if text.startswith('//'):
                text = 'Call to deprecated function ' + text[2:]
            patterns.append(text)

        # todo: should probably be strict about this, too
        warnings.filterwarnings(
            action='ignore',
            category=sa_exc.SAPendingDeprecationWarning)

        if patterns:
            for pattern in patterns:
                warnings.filterwarnings(
                    action='ignore',
                    message=pattern,
                    category=sa_exc.SADeprecationWarning)
        else:
            warnings.filterwarnings(
                action='ignore',
                category=sa_exc.SADeprecationWarning)

        try:
            return fn(*args, **kw)
        finally:
            resetwarnings()
    return decorate


def global_cleanup_assertions():
    """Run the checks that must hold once a test suite has finished.

    The set of checks is hardcoded for now: after ``testutil.lazy_gc()``
    has run, ``pool._refs`` is asserted to be empty.  A modular system
    could be built here to cover things like PG prepared transactions,
    all tables having been dropped, etc.

    """

    testutil.lazy_gc()
    outstanding = pool._refs
    assert not outstanding, str(outstanding)


def eq_(a, b, msg=None):
    """Assert that ``a == b``.

    On failure the AssertionError carries ``msg`` when one is given,
    otherwise the ``repr()`` of both operands.
    """
    assert a == b, (msg if msg else "%r != %r" % (a, b))


def ne_(a, b, msg=None):
    """Assert that ``a != b``.

    On failure the AssertionError carries ``msg`` when one is given,
    otherwise the ``repr()`` of both operands.
    """
    assert a != b, (msg if msg else "%r == %r" % (a, b))


def is_(a, b, msg=None):
    """Assert that ``a`` and ``b`` are the same object.

    On failure the AssertionError carries ``msg`` when one is given,
    otherwise the ``repr()`` of both operands.
    """
    assert a is b, (msg if msg else "%r is not %r" % (a, b))


def is_not_(a, b, msg=None):
    """Assert that ``a`` and ``b`` are distinct objects.

    On failure the AssertionError carries ``msg`` when one is given,
    otherwise the ``repr()`` of both operands.
    """
    assert a is not b, (msg if msg else "%r is %r" % (a, b))


def startswith_(a, fragment, msg=None):
    """Assert that ``a.startswith(fragment)`` is true.

    On failure the AssertionError carries ``msg`` when one is given,
    otherwise the ``repr()`` of ``a`` and ``fragment``.
    """
    assert a.startswith(fragment), (
        msg if msg else "%r does not start with %r" % (a, fragment))


def assert_raises(except_cls, callable_, *args, **kw):
    """Assert that ``callable_(*args, **kw)`` raises ``except_cls``.

    An exception of any other type propagates to the caller unchanged.
    """
    raised = True
    try:
        callable_(*args, **kw)
    except except_cls:
        pass
    else:
        raised = False

    # the assertion lives outside the try block so that it is not swallowed
    # when except_cls is AssertionError itself
    assert raised, "Callable did not raise an exception"


def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert that a call raises ``except_cls`` with a matching message.

    ``callable_(*args, **kwargs)`` is invoked.  If it raises an instance of
    ``except_cls``, ``msg`` is applied as a regular expression with
    ``re.search()`` (``re.UNICODE``) to the text of the exception; a failed
    match raises AssertionError.  On a match the exception text is printed,
    encoded as UTF-8.

    If the call returns normally an AssertionError is raised.  An exception
    of any other type propagates to the caller unchanged.
    """
    # text type of the running interpreter: ``unicode`` where it exists,
    # ``str`` otherwise
    try:
        text_type = unicode
    except NameError:
        text_type = str

    try:
        callable_(*args, **kwargs)
        success = False
    except except_cls as e:
        assert re.search(msg, text_type(e), re.UNICODE), \
            u"%r !~ %s" % (msg, e)
        print(text_type(e).encode('utf-8'))
        success = True

    # assert outside the block so it works for AssertionError too; inside
    # the try, the failure would be caught by the handler above whenever
    # except_cls is AssertionError or one of its bases
    assert success, "Callable did not raise an exception"


class AssertsCompiledSQL(object):
    """Mixin providing ``assert_compile()`` for checking compiled SQL."""

    def assert_compile(self, clause, result, params=None,
                        checkparams=None, dialect=None,
                        checkpositional=None,
                        use_default_dialect=False,
                        allow_dialect_select=False):
        """Compile ``clause`` and compare the SQL string to ``result``.

        The dialect is chosen as follows:

        * ``use_default_dialect`` true: a new ``default.DefaultDialect()``.
        * otherwise, if ``dialect`` is None and ``allow_dialect_select`` is
          false: the ``__dialect__`` attribute of ``self`` is consulted.
          The string ``'default'`` selects ``default.DefaultDialect()``,
          a missing/None attribute selects ``config.db.dialect``, and any
          other string is used as the scheme of a URL passed to
          ``create_engine()``, whose dialect is taken.
        * otherwise ``dialect`` is passed to ``compile()`` as given.

        When ``params`` is given, its keys are passed to ``compile()`` as
        ``column_keys``.  An ``orm.Query`` is converted to its statement,
        with ``use_labels`` switched on, before compiling.

        Newlines and tabs are removed from the compiled string before it
        is compared with ``result``.  ``checkparams``, when not None, is
        compared with ``construct_params(params)``; ``checkpositional``,
        when not None, is compared with those parameters arranged in
        ``positiontup`` order.
        """
        if use_default_dialect:
            dialect = default.DefaultDialect()
        elif dialect is None and not allow_dialect_select:
            dialect = getattr(self, '__dialect__', None)
            if dialect == 'default':
                dialect = default.DefaultDialect()
            elif dialect is None:
                dialect = config.db.dialect
            elif isinstance(dialect, basestring):
                dialect = create_engine("%s://" % dialect).dialect

        kw = {}
        if params is not None:
            kw['column_keys'] = params.keys()

        if isinstance(clause, orm.Query):
            context = clause._compile_context()
            context.statement.use_labels = True
            clause = context.statement

        c = clause.compile(dialect=dialect, **kw)

        param_str = repr(getattr(c, 'params', {}))
        # Py3K
        #param_str = param_str.encode('utf-8').decode('ascii', 'ignore')

        # single parenthesized argument: same output whether print is a
        # statement or a function
        print("\nSQL String:\n" + str(c) + param_str)

        cc = re.sub(r'[\n\t]', '', str(c))

        eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))

        if checkparams is not None:
            eq_(c.construct_params(params), checkparams)
        if checkpositional is not None:
            p = c.construct_params(params)
            eq_(tuple(p[x] for x in c.positiontup), checkpositional)


class ComparesTables(object):
    """Mixin for comparing a table with its reflected counterpart."""

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        """Assert that ``reflected_table`` agrees with ``table``.

        Both tables must have the same number of columns.  Columns are
        paired positionally and checked for name, primary-key flag,
        nullability, type, string length (when the original type is a
        ``sqltypes.String``) and the names of the columns targeted by their
        foreign keys.  When the original column has a server default, the
        reflected one must carry a ``schema.FetchedValue``.

        With ``strict_types`` the two column types must be of exactly the
        same class; otherwise :meth:`assert_types_base` decides.

        Finally the primary keys must have equal length, and each primary
        key column of ``table`` is looked up by name in the reflected
        primary key.
        """
        assert len(table.c) == len(reflected_table.c)

        for col, reflected_col in zip(table.c, reflected_table.c):
            eq_(col.name, reflected_col.name)
            assert reflected_col is reflected_table.c[col.name]
            eq_(col.primary_key, reflected_col.primary_key)
            eq_(col.nullable, reflected_col.nullable)

            if not strict_types:
                self.assert_types_base(reflected_col, col)
            else:
                assert type(reflected_col.type) is type(col.type), (
                    "Type '%s' doesn't correspond to type '%s'"
                    % (reflected_col.type, col.type))

            if isinstance(col.type, sqltypes.String):
                eq_(col.type.length, reflected_col.type.length)

            # foreign keys are compared by the name of the referenced column
            original_targets = set(
                fk.column.name for fk in col.foreign_keys)
            reflected_targets = set(
                fk.column.name for fk in reflected_col.foreign_keys)
            eq_(original_targets, reflected_targets)

            if col.server_default:
                assert isinstance(reflected_col.server_default,
                                  schema.FetchedValue)

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for pk_col in table.primary_key:
            assert reflected_table.primary_key.columns[pk_col.name] \
                is not None

    def assert_types_base(self, c1, c2):
        """Assert that the type of ``c1`` has the type affinity of ``c2``'s.

        The decision is delegated to ``c1.type._compare_type_affinity()``.
        """
        assert c1.type._compare_type_affinity(c2.type), (
            "On column %r, type '%s' doesn't correspond to type '%s'"
            % (c1.name, c1.type, c2.type))


class AssertsExecutionResults(object):
    """Mixin of assertions about loaded objects and executed SQL.

    The object-comparison helpers report through ``self.assert_()``, which
    is not defined here.  NOTE(review): presumably supplied by the test
    case class this mixin is combined with -- confirm.
    """

    def assert_result(self, result, class_, *objects):
        """Assert that ``result`` matches ``objects`` position by position.

        ``result`` is materialised with ``list()`` and its ``repr()`` is
        printed before the comparison is delegated to :meth:`assert_list`.
        """
        result = list(result)
        # single parenthesized argument: same output whether print is a
        # statement or a function
        print(repr(result))
        self.assert_list(result, class_, objects)

    def assert_list(self, result, class_, list):
        """Assert that each item of ``result`` matches its description.

        ``list`` holds one attribute description per expected item (see
        :meth:`assert_row`); both sequences must have the same length.
        """
        self.assert_(len(result) == len(list),
                     "result list is not the same size as test list, " +
                     "for class " + class_.__name__)
        for rowobj, desc in zip(result, list):
            self.assert_row(class_, rowobj, desc)

    def assert_row(self, class_, rowobj, desc):
        """Assert that ``rowobj`` is a ``class_`` with the given attributes.

        ``desc`` maps attribute names to expected values.  A tuple value
        ``(cls, spec)`` describes a related attribute: when ``spec`` is a
        list the attribute is checked with :meth:`assert_list`, otherwise
        with :meth:`assert_row`.  Any other value is compared with ``==``.
        """
        self.assert_(rowobj.__class__ is class_,
                     "item class is not " + repr(class_))
        for key, value in desc.items():
            if isinstance(value, tuple):
                if isinstance(value[1], list):
                    self.assert_list(getattr(rowobj, key), value[0], value[1])
                else:
                    self.assert_row(value[0], getattr(rowobj, key), value[1])
            else:
                self.assert_(getattr(rowobj, key) == value,
                             "attribute %s value %s does not match %s" % (
                             key, getattr(rowobj, key), value))

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        Every object in ``result`` must be of exactly type ``cls``, and the
        number of objects must equal the number of ``expected``
        descriptions.  Each description must then match some not yet
        matched object; a tuple value ``(cls, specs)`` inside a description
        is checked recursively against the attribute's contents.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.

        Returns True; failures are reported through ``fail()``.
        """

        class immutabledict(dict):
            # hashed by identity so that equal-valued descriptions remain
            # separate members of the ``expected`` set
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = set(immutabledict(e) for e in expected)

        for wrong in found:
            # exact type match: instances of subclasses are rejected too
            if type(wrong) is not cls:
                fail('Unexpected type "%s", expected "%s"' % (
                    type(wrong).__name__, cls.__name__))

        if len(found) != len(expected):
            fail('Unexpected object count "%s", expected "%s"' % (
                len(found), len(expected)))

        # sentinel for "attribute missing", which never equals a real value
        NOVALUE = object()

        def _compare_item(obj, spec):
            for key, value in spec.items():
                if isinstance(value, tuple):
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1])
                    except AssertionError:
                        return False
                else:
                    if getattr(obj, key, NOVALUE) != value:
                        return False
            return True

        for expected_item in expected:
            for found_item in found:
                if _compare_item(found_item, expected_item):
                    # safe to mutate: iteration stops right after removal
                    found.remove(found_item)
                    break
            else:
                fail(
                    "Expected %s instance with attributes %s not found." % (
                    cls.__name__, repr(expected_item)))
        return True

    def assert_sql_execution(self, db, callable_, *rules):
        """Run ``callable_`` with ``rules`` registered on the SQL asserter.

        The rules are always cleared afterwards; ``statement_complete()``
        is only reached when ``callable_`` returns normally.  ``db`` is
        accepted but not used by this method.
        """
        assertsql.asserter.add_rules(rules)
        try:
            callable_()
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_sql(self, db, callable_, list_, with_sequences=None):
        """Assert the exact SQL executed by ``callable_``.

        ``with_sequences`` replaces ``list_`` as the expected rules when it
        is not None and ``config.db.dialect.supports_sequences`` is true.
        Each rule is either a dict, every ``key, value`` pair of which
        becomes an ``assertsql.ExactSQL`` grouped under one
        ``assertsql.AllOf``, or a sequence unpacked into a single
        ``assertsql.ExactSQL``.
        """
        if with_sequences is not None and config.db.dialect.supports_sequences:
            rules = with_sequences
        else:
            rules = list_

        newrules = []
        for rule in rules:
            if isinstance(rule, dict):
                newrule = assertsql.AllOf(*[
                    assertsql.ExactSQL(k, v) for k, v in rule.items()
                ])
            else:
                newrule = assertsql.ExactSQL(*rule)
            newrules.append(newrule)

        self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        """Run ``callable_`` under an ``assertsql.CountStatements(count)``
        rule."""
        self.assert_sql_execution(
            db, callable_, assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_execution(self, *rules):
        """Context manager form of :meth:`assert_sql_execution`.

        ``rules`` are registered on entry and always cleared on exit;
        ``statement_complete()`` is only reached when the managed block
        finishes without raising.
        """
        assertsql.asserter.add_rules(rules)
        try:
            yield
            assertsql.asserter.statement_complete()
        finally:
            assertsql.asserter.clear_rules()

    def assert_statement_count(self, count):
        """Return an :meth:`assert_execution` context manager using an
        ``assertsql.CountStatements(count)`` rule."""
        return self.assert_execution(assertsql.CountStatements(count))