
Mike Bayer committed 45cec09

- unit tests have been migrated from unittest to nose.
See README.unittests for information on how to run
the tests. [ticket:970]


Files changed (237)

 
 0.5.5
 =======
+- general
+    - unit tests have been migrated from unittest to nose.
+      See README.unittests for information on how to run
+      the tests.  [ticket:970]
 - orm
     - Fixed bug introduced in 0.5.4 whereby Composite types
       fail when default-holding columns are flushed.
 SQLALCHEMY UNIT TESTS
 =====================
 
-SETUP
------
-SQLite support is required.  These instructions assume standard Python 2.4 or
-higher. 
-
-The 'test' directory must be on the PYTHONPATH.
+SQLAlchemy unit tests by default run using Python's built-in sqlite3 
+module.  If running on Python 2.4, pysqlite must be installed.
 
-cd into the SQLAlchemy distribution directory
+As of 0.5.5, unit tests are run using nose.  Documentation and
+downloads for nose are available at:
 
-In bash:
+http://somethingaboutorange.com/mrl/projects/nose/0.11.1/index.html
 
-    $ export PYTHONPATH=./test/
 
-On windows:
+SQLAlchemy implements a nose plugin that must be present when tests are run.
+This plugin is available when SQLAlchemy is installed via setuptools.
 
-    C:\sa\> set PYTHONPATH=test\
+SETUP
+-----
 
-    Adjust any other use Unix-style paths in this README as needed.
+All that's required is for SQLAlchemy to be installed via setuptools.
+For example, to create a local install in a source distribution directory:
 
-The unittest framework will automatically prepend the lib/ directory to
-sys.path.  This forces the local version of SQLAlchemy to be used, bypassing
-any setuptools-installed installations (setuptools places .egg files ahead of
-plain directories, even if on PYTHONPATH, unfortunately).
+    $ export PYTHONPATH=.
+    $ python setup.py develop -d .
 
+The above will create a setuptools "development" distribution in the local
+path, which makes the nose plugin available when nosetests is run.
+The plugin is enabled by the "with-sqlalchemy=True" setting
+in setup.cfg.
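
As a rough sketch, the setup.cfg entry that enables the plugin would look like
the following, assuming nose's standard [nosetests] configuration section (the
section name is an assumption; only the option name is quoted from the commit):

```ini
[nosetests]
; Enable the SQLAlchemy nose plugin for every nosetests run.
with-sqlalchemy = True
```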
 
 RUNNING ALL TESTS
 -----------------
 To run all tests:
 
-    $ python test/alltests.py
+    $ nosetests
+
+Assuming all tests pass, this produces very little output.  To make it more
+interesting:
 
+    $ nosetests -v
 
 RUNNING INDIVIDUAL TESTS
 -------------------------
-Any unittest module can be run directly from the module file:
+Any test module can be run directly by specifying its module name:
 
-    python test/orm/mapper.py
+    $ nosetests test.orm.test_mapper
 
-To run a specific test within the module, specify it as ClassName.methodname:
+To run a specific test within the module, specify it as module:ClassName.methodname:
 
-    python test/orm/mapper.py MapperTest.testget
+    $ nosetests test.orm.test_mapper:MapperTest.test_utils
 
 
 COMMAND LINE OPTIONS
 --------------------
-Help is available via --help
+Help is available via --help:
 
-    $ python test/alltests.py --help
+    $ nosetests --help
 
-    usage: alltests.py [options] [tests...]
-
-    Options:
-      -h, --help            show this help message and exit
-      --verbose             enable stdout echoing/printing
-      --quiet               suppress output
-    [...]
-
-Command line options can applied to alltests.py or any individual test module.
-Many are available.  The most commonly used are '--db' and '--dburi'.
+The --help screen is a combination of common nose options and options which
+the SQLAlchemy nose plugin adds.  The most commonly used SQLAlchemy-specific
+options are '--db' and '--dburi'.
 
 
 DATABASE TARGETS
 ----------------
 If you'll be running the tests frequently, database aliases can save a lot of
 typing.  The --dbs option lists the built-in aliases and their matching URLs:
 
-    $ python test/alltests.py --dbs
+    $ nosetests --dbs
     Available --db options (use --dburi to override)
                mysql    mysql://scott:tiger@127.0.0.1:3306/test
               oracle    oracle://scott:tiger@127.0.0.1:1521
 
 To run tests against an aliased database:
 
-    $ python test/alltests.py --db=postgres
+    $ nosetests --db=postgres
 
 To customize the URLs with your own users or hostnames, make a simple .ini
 file called `test.cfg` at the top level of the SQLAlchemy source distribution
 Any log target can be directed to the console with command line options, such
 as:
 
-    $ python test/orm/unitofwork.py --log-info=sqlalchemy.orm.mapper \
+    $ nosetests test.orm.unitofwork --log-info=sqlalchemy.orm.mapper \
       --log-debug=sqlalchemy.pool --log-info=sqlalchemy.engine
 
 This would log mapper configuration, connection pool checkouts, and SQL
 
 BUILT-IN COVERAGE REPORTING
 ------------------------------
-Coverage is tracked with coverage.py module, included in the './test/'
-directory.  Running the test suite with the --coverage switch will generate a
-local file ".coverage" containing coverage details, and a report will be
-printed to standard output with an overview of the coverage gathered from the
-last unittest run (the file is deleted between runs).
-
-After the suite has been run with --coverage, an annotated version of any
-source file can be generated, marking statements that are executed with > and
-statements that are missed with !, by running the coverage.py utility with the
-"-a" (annotate) option, such as:
-
-    $ python ./test/testlib/coverage.py -a ./lib/sqlalchemy/sql.py
+Coverage is tracked using nose's coverage plugin.  See the nose
+documentation for details.  Basic usage is:
 
-This will create a new annotated file ./lib/sqlalchemy/sql.py,cover. Pretty
-cool!
+    $ nosetests test.sql.test_query --with-coverage
 
 BIG COVERAGE TIP !!!  There is an issue where existing .pyc files may
 store the incorrect filepaths, which will break the coverage system.  If
+import os
+import subprocess
+import re
+
+def walk():
+    for root, dirs, files in os.walk("./test/"):
+        if root.endswith("/perf"):
+            continue
+        
+        for fname in files:
+            if not fname.endswith(".py"):
+                continue
+            if fname == "alltests.py":
+                subprocess.call(["svn", "remove", os.path.join(root, fname)])
+            elif fname.startswith("_") or fname == "__init__.py" or fname == "pickleable.py":
+                convert(os.path.join(root, fname))
+            elif not fname.startswith("test_"):
+                if os.path.exists(os.path.join(root, "test_" + fname)):
+                    os.unlink(os.path.join(root, "test_" + fname))
+                subprocess.call(["svn", "rename", os.path.join(root, fname), os.path.join(root, "test_" + fname)])
+                convert(os.path.join(root, "test_" + fname))
+
+
+def convert(fname):
+    lines = list(file(fname))
+    replaced = []
+    flags = {}
+    
+    while lines:
+        for reg, handler in handlers:
+            m = reg.match(lines[0])
+            if m:
+                handler(lines, replaced, flags)
+                break
+    
+    post_handler(lines, replaced, flags)
+    f = file(fname, 'w')
+    f.write("".join(replaced))
+    f.close()
+
+handlers = []
+
+
+def post_handler(lines, replaced, flags):
+    imports = []
+    if "needs_eq" in flags:
+        imports.append("eq_")
+    if "needs_assert_raises" in flags:
+        imports += ["assert_raises", "assert_raises_message"]
+    if imports:
+        for i, line in enumerate(replaced):
+            if "import" in line:
+                replaced.insert(i, "from sqlalchemy.test.testing import %s\n" % ", ".join(imports))
+                break
+    
+def remove_line(lines, replaced, flags):
+    lines.pop(0)
+    
+handlers.append((re.compile(r"import testenv; testenv\.configure_for_tests"), remove_line))
+handlers.append((re.compile(r"(.*\s)?import sa_unittest"), remove_line))
+
+
+def import_testlib_sa(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("import testlib.sa", "import sqlalchemy")
+    replaced.append(line)
+handlers.append((re.compile("import testlib\.sa"), import_testlib_sa))
+
+def from_testlib_sa(lines, replaced, flags):
+    line = lines.pop(0)
+    while True:
+        if line.endswith("\\\n"):
+            line = line[0:-2] + lines.pop(0)
+        else:
+            break
+    
+    components = re.compile(r'from testlib\.sa import (.*)').match(line)
+    if components:
+        components = re.split(r"\s*,\s*", components.group(1))
+        line = "from sqlalchemy import %s\n" % (", ".join(c for c in components if c not in ("Table", "Column")))
+        replaced.append(line)
+        if "Table" in components:
+            replaced.append("from sqlalchemy.test.schema import Table\n")
+        if "Column" in components:
+            replaced.append("from sqlalchemy.test.schema import Column\n")
+        return
+        
+    line = line.replace("testlib.sa", "sqlalchemy")
+    replaced.append(line)
+handlers.append((re.compile("from testlib\.sa.*import"), from_testlib_sa))
+
+def from_testlib(lines, replaced, flags):
+    line = lines.pop(0)
+    
+    components = re.compile(r'from testlib import (.*)').match(line)
+    if components:
+        components = re.split(r"\s*,\s*", components.group(1))
+        if "sa" in components:
+            replaced.append("import sqlalchemy as sa\n")
+            replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa")))
+            return
+        elif "sa as tsa" in components:
+            replaced.append("import sqlalchemy as tsa\n")
+            replaced.append("from sqlalchemy.test import %s\n" % (", ".join(c for c in components if c != "sa" and c != "sa as tsa")))
+            return
+    
+    line = line.replace("testlib", "sqlalchemy.test")
+    replaced.append(line)
+handlers.append((re.compile(r"from testlib"), from_testlib))
+
+def from_orm(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("from orm import", "from test.orm import")
+    line = line.replace("from orm.", "from test.orm.")
+    replaced.append(line)
+handlers.append((re.compile(r'from orm( import|\.)'), from_orm))
+    
+def assert_equals(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("self.assertEquals", "eq_")
+    line = line.replace("self.assertEqual", "eq_")
+    replaced.append(line)
+    flags["needs_eq"] = True
+handlers.append((re.compile(r"\s*self\.assertEqual(s)?"), assert_equals))
+
+def assert_raises(lines, replaced, flags):
+    line = lines.pop(0)
+    line = line.replace("self.assertRaisesMessage", "assert_raises_message")
+    line = line.replace("self.assertRaises", "assert_raises")
+    replaced.append(line)
+    flags["needs_assert_raises"] = True
+handlers.append((re.compile(r"\s*self\.assertRaises(Message)?"), assert_raises))
+
+def setup_all(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setUpAll\(self\)\:").match(line).group(1)
+    replaced.append("%s@classmethod\n" % whitespace)
+    replaced.append("%sdef setup_class(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setUpAll\(self\)"), setup_all))
+
+def teardown_all(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def tearDownAll\(self\)\:").match(line).group(1)
+    replaced.append("%s@classmethod\n" % whitespace)
+    replaced.append("%sdef teardown_class(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def tearDownAll\(self\)"), teardown_all))
+
+def setup(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setUp\(self\)\:").match(line).group(1)
+    replaced.append("%sdef setup(self):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setUp\(self\)"), setup))
+
+def teardown(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def tearDown\(self\)\:").match(line).group(1)
+    replaced.append("%sdef teardown(self):\n" % whitespace)
+handlers.append((re.compile(r"\s*def tearDown\(self\)"), teardown))
+    
+def define_tables(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def define_tables").match(line).group(1)
+    replaced.append("%s@classmethod\n" % whitespace)
+    replaced.append("%sdef define_tables(cls, metadata):\n" % whitespace)
+handlers.append((re.compile(r"\s*def define_tables\(self, metadata\)"), define_tables))
+
+def setup_mappers(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setup_mappers").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef setup_mappers(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setup_mappers\(self\)"), setup_mappers))
+
+def setup_classes(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def setup_classes").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef setup_classes(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def setup_classes\(self\)"), setup_classes))
+
+def insert_data(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def insert_data").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef insert_data(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def insert_data\(self\)"), insert_data))
+
+def fixtures(lines, replaced, flags):
+    line = lines.pop(0)
+    whitespace = re.compile(r"(\s*)def fixtures").match(line).group(1)
+    
+    i = -1
+    while re.match("\s*@testing", replaced[i]):
+        i -= 1
+        
+    replaced.insert(len(replaced) + i + 1, "%s@classmethod\n" % whitespace)
+    replaced.append("%sdef fixtures(cls):\n" % whitespace)
+handlers.append((re.compile(r"\s*def fixtures\(self\)"), fixtures))
+    
+    
+def call_main(lines, replaced, flags):
+    replaced.pop(-1)
+    lines.pop(0)
+handlers.append((re.compile(r"\s+testenv\.main\(\)"), call_main))
+
+def default(lines, replaced, flags):
+    replaced.append(lines.pop(0))
+handlers.append((re.compile(r".*"), default))
+
+
+if __name__ == '__main__':
+    convert("test/orm/inheritance/abc_inheritance.py")
+#    walk()
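
The script's core is a first-match regex dispatch: each handler consumes one or
more lines from the front of the input, and a catch-all pattern at the end
guarantees progress.  A minimal standalone sketch of that loop (handler and
helper names here are illustrative, not taken verbatim from the script):

```python
import re

# (pattern, handler) pairs, tried in order; the catch-all must come last.
handlers = []

def assert_equals(lines, replaced):
    # Rewrite a unittest-style assertion to the nose-friendly eq_() helper.
    replaced.append(lines.pop(0).replace("self.assertEquals", "eq_"))
handlers.append((re.compile(r"\s*self\.assertEquals"), assert_equals))

def default(lines, replaced):
    # Catch-all: pass the line through unchanged.
    replaced.append(lines.pop(0))
handlers.append((re.compile(r".*"), default))

def convert(source_lines):
    lines = list(source_lines)
    replaced = []
    while lines:
        for reg, handler in handlers:
            if reg.match(lines[0]):
                handler(lines, replaced)
                break
    return replaced

out = convert(["x = 1\n", "        self.assertEquals(x, 1)\n"])
```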

lib/sqlalchemy/test/__init__.py

+"""Testing environment and utilities.
+
+This package contains base classes and routines used by 
+the unit tests.   Tests are based on Nose and bootstrapped
+by noseplugin.NoseSQLAlchemy.
+
+"""
+
+from sqlalchemy.test import testing, engines, requires, profiling, pickleable, config
+from sqlalchemy.test.schema import Column, Table
+from sqlalchemy.test.testing import \
+     AssertsCompiledSQL, \
+     AssertsExecutionResults, \
+     ComparesTables, \
+     TestBase, \
+     rowset
+
+
+__all__ = ('testing',
+            'Column', 'Table',
+           'rowset',
+           'TestBase', 'AssertsExecutionResults',
+           'AssertsCompiledSQL', 'ComparesTables',
+           'engines', 'profiling', 'pickleable')
+
+

lib/sqlalchemy/test/assertsql.py

+
+from sqlalchemy.interfaces import ConnectionProxy
+from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.base import Connection
+from sqlalchemy import util
+import testing
+import re
+
+class AssertRule(object):
+    def process_execute(self, clauseelement, *multiparams, **params):
+        pass
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        pass
+        
+    def is_consumed(self):
+        """Return True if this rule has been consumed, False if not.
+        
+        Should raise an AssertionError if this rule's condition has definitely failed.
+        
+        """
+        raise NotImplementedError()
+    
+    def rule_passed(self):
+        """Return True if the last test of this rule passed, False if failed, None if no test was applied."""
+        
+        raise NotImplementedError()
+        
+    def consume_final(self):
+        """Return True if this rule has been consumed.
+        
+        Should raise an AssertionError if this rule's condition has not been consumed or has failed.
+        
+        """
+        
+        if self._result is None:
+            assert False, "Rule has not been consumed"
+            
+        return self.is_consumed()
+
+class SQLMatchRule(AssertRule):
+    def __init__(self):
+        self._result = None
+        self._errmsg = ""
+    
+    def rule_passed(self):
+        return self._result
+        
+    def is_consumed(self):
+        if self._result is None:
+            return False
+            
+        assert self._result, self._errmsg
+        
+        return True
+    
+class ExactSQL(SQLMatchRule):
+    def __init__(self, sql, params=None):
+        SQLMatchRule.__init__(self)
+        self.sql = sql
+        self.params = params
+    
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        if not context:
+            return
+            
+        _received_statement = _process_engine_statement(statement, context)
+        _received_parameters = context.compiled_parameters
+        
+        # TODO: remove this step once all unit tests
+        # are migrated, as ExactSQL should really be *exact* SQL 
+        sql = _process_assertion_statement(self.sql, context)
+        
+        equivalent = _received_statement == sql
+        if self.params:
+            if util.callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+
+            if not isinstance(params, list):
+                params = [params]
+            equivalent = equivalent and params == context.compiled_parameters
+        else:
+            params = {}
+        
+        
+        self._result = equivalent
+        if not self._result:
+            self._errmsg = "Testing for exact statement %r exact params %r, " \
+                "received %r with params %r" % (sql, params, _received_statement, _received_parameters)
+    
+
+class RegexSQL(SQLMatchRule):
+    def __init__(self, regex, params=None):
+        SQLMatchRule.__init__(self)
+        self.regex = re.compile(regex)
+        self.orig_regex = regex
+        self.params = params
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        if not context:
+            return
+
+        _received_statement = _process_engine_statement(statement, context)
+        _received_parameters = context.compiled_parameters
+
+        equivalent = bool(self.regex.match(_received_statement))
+        if self.params:
+            if util.callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+
+            if not isinstance(params, list):
+                params = [params]
+            
+            # do a positive compare only
+            for param, received in zip(params, _received_parameters):
+                for k, v in param.iteritems():
+                    if k not in received or received[k] != v:
+                        equivalent = False
+                        break
+        else:
+            params = {}
+
+        self._result = equivalent
+        if not self._result:
+            self._errmsg = "Testing for regex %r partial params %r, "\
+                "received %r with params %r" % (self.orig_regex, params, _received_statement, _received_parameters)
+
+class CompiledSQL(SQLMatchRule):
+    def __init__(self, statement, params):
+        SQLMatchRule.__init__(self)
+        self.statement = statement
+        self.params = params
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        if not context:
+            return
+
+        _received_parameters = context.compiled_parameters
+        
+        # recompile from the context, using the default dialect
+        compiled = context.compiled.statement.\
+                compile(dialect=DefaultDialect(), column_keys=context.compiled.column_keys)
+                
+        _received_statement = re.sub(r'\n', '', str(compiled))
+        
+        equivalent = self.statement == _received_statement
+        if self.params:
+            if util.callable(self.params):
+                params = self.params(context)
+            else:
+                params = self.params
+
+            if not isinstance(params, list):
+                params = [params]
+            
+            # do a positive compare only
+            for param, received in zip(params, _received_parameters):
+                for k, v in param.iteritems():
+                    if k not in received or received[k] != v:
+                        equivalent = False
+                        break
+        else:
+            params = {}
+
+        self._result = equivalent
+        if not self._result:
+            self._errmsg = "Testing for compiled statement %r partial params %r, " \
+                    "received %r with params %r" % (self.statement, params, _received_statement, _received_parameters)
+    
+        
+class CountStatements(AssertRule):
+    def __init__(self, count):
+        self.count = count
+        self._statement_count = 0
+        
+    def process_execute(self, clauseelement, *multiparams, **params):
+        self._statement_count += 1
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        pass
+
+    def is_consumed(self):
+        return False
+    
+    def consume_final(self):
+        assert self.count == self._statement_count, "desired statement count %d does not match %d" % (self.count, self._statement_count)
+        return True
+        
+class AllOf(AssertRule):
+    def __init__(self, *rules):
+        self.rules = set(rules)
+        
+    def process_execute(self, clauseelement, *multiparams, **params):
+        for rule in self.rules:
+            rule.process_execute(clauseelement, *multiparams, **params)
+
+    def process_cursor_execute(self, statement, parameters, context, executemany):
+        for rule in self.rules:
+            rule.process_cursor_execute(statement, parameters, context, executemany)
+
+    def is_consumed(self):
+        if not self.rules:
+            return True
+        
+        for rule in list(self.rules):
+            if rule.rule_passed(): # a rule passed, move on
+                self.rules.remove(rule)
+                return len(self.rules) == 0
+
+        assert False, "No assertion rules were satisfied for statement"
+    
+    def consume_final(self):
+        return len(self.rules) == 0
+        
+def _process_engine_statement(query, context):
+    if context.engine.name == 'mssql' and query.endswith('; select scope_identity()'):
+        query = query[:-25]
+    
+    query = re.sub(r'\n', '', query)
+    
+    return query
+    
+def _process_assertion_statement(query, context):
+    paramstyle = context.dialect.paramstyle
+    if paramstyle == 'named':
+        pass
+    elif paramstyle =='pyformat':
+        query = re.sub(r':([\w_]+)', r"%(\1)s", query)
+    else:
+        # positional params
+        repl = None
+        if paramstyle=='qmark':
+            repl = "?"
+        elif paramstyle=='format':
+            repl = r"%s"
+        elif paramstyle=='numeric':
+            repl = None
+        query = re.sub(r':([\w_]+)', repl, query)
+
+    return query
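
_process_assertion_statement above rewrites ':name' bind parameters into the
dialect's paramstyle.  A standalone sketch of the pyformat and qmark cases
(the function name here is illustrative):

```python
import re

def convert_paramstyle(query, paramstyle):
    # ':name' binds are left alone for 'named', rewritten to '%(name)s'
    # for 'pyformat', and collapsed to positional '?' for 'qmark'.
    if paramstyle == 'named':
        return query
    elif paramstyle == 'pyformat':
        return re.sub(r':(\w+)', r'%(\1)s', query)
    elif paramstyle == 'qmark':
        return re.sub(r':(\w+)', '?', query)
    raise ValueError("unsupported paramstyle: %s" % paramstyle)

q = "SELECT * FROM users WHERE name = :name"
```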
+
+class SQLAssert(ConnectionProxy):
+    rules = None
+    
+    def add_rules(self, rules):
+        self.rules = list(rules)
+    
+    def statement_complete(self):
+        for rule in self.rules:
+            if not rule.consume_final():
+                assert False, "All statements are complete, but pending assertion rules remain"
+
+    def clear_rules(self):
+        del self.rules
+        
+    def execute(self, conn, execute, clauseelement, *multiparams, **params):
+        result = execute(clauseelement, *multiparams, **params)
+
+        if self.rules is not None:
+            if not self.rules:
+                assert False, "All rules have been exhausted, but further statements remain"
+            rule = self.rules[0]
+            rule.process_execute(clauseelement, *multiparams, **params)
+            if rule.is_consumed():
+                self.rules.pop(0)
+            
+        return result
+        
+    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+        result = execute(cursor, statement, parameters, context)
+        
+        if self.rules:
+            rule = self.rules[0]
+            rule.process_cursor_execute(statement, parameters, context, executemany)
+
+        return result
+
+asserter = SQLAssert()
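
SQLAssert consumes its rules strictly in order: each execute event goes to the
rule at the head of the list, and a rule that reports itself consumed is popped
so the next statement meets the next rule.  A minimal standalone sketch of that
consumption pattern (simplified classes, not the ones above):

```python
class Rule(object):
    """Matches one expected statement; consumed once it has matched."""
    def __init__(self, expected):
        self.expected = expected
        self.matched = False

    def process(self, statement):
        assert statement == self.expected, (
            "expected %r, got %r" % (self.expected, statement))
        self.matched = True

    def is_consumed(self):
        return self.matched

class Asserter(object):
    def __init__(self, rules):
        self.rules = list(rules)

    def statement(self, statement):
        # Feed the head rule; pop it once it reports itself consumed.
        assert self.rules, "rules exhausted, but further statements remain"
        rule = self.rules[0]
        rule.process(statement)
        if rule.is_consumed():
            self.rules.pop(0)

a = Asserter([Rule("INSERT"), Rule("SELECT")])
a.statement("INSERT")
a.statement("SELECT")
```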
+    

lib/sqlalchemy/test/config.py

+import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
+logging = None
+
+__all__ = 'parser', 'configure', 'options',
+
+db = None
+db_label, db_url, db_opts = None, None, {}
+
+options = None
+file_config = None
+
+base_config = """
+[db]
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgres=postgres://scott:tiger@127.0.0.1:5432/test
+mysql=mysql://scott:tiger@127.0.0.1:3306/test
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
+firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
+maxdb=maxdb://MONA:RED@/maxdb1
+"""
+
+def _log(option, opt_str, value, parser):
+    global logging
+    if not logging:
+        import logging
+        logging.basicConfig()
+
+    if opt_str.endswith('-info'):
+        logging.getLogger(value).setLevel(logging.INFO)
+    elif opt_str.endswith('-debug'):
+        logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+    print "Available --db options (use --dburi to override)"
+    for macro in sorted(file_config.options('db')):
+        print "%20s\t%s" % (macro, file_config.get('db', macro))
+    sys.exit(0)
+
+def _server_side_cursors(options, opt_str, value, parser):
+    db_opts['server_side_cursors'] = True
+
+def _engine_strategy(options, opt_str, value, parser):
+    if value:
+        db_opts['strategy'] = value
+
+class _ordered_map(object):
+    def __init__(self):
+        self._keys = list()
+        self._data = dict()
+
+    def __setitem__(self, key, value):
+        if key not in self._keys:
+            self._keys.append(key)
+        self._data[key] = value
+
+    def __iter__(self):
+        for key in self._keys:
+            yield self._data[key]
+
+# at one point in refactoring, modules were injecting into the config
+# process.  this could probably just become a list now.
+post_configure = _ordered_map()
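
_ordered_map iterates values in the order keys were first registered while
still allowing dict-style overwrite by key, which matters because the
post_configure steps must run in sequence (and Python 2.4 had no ordered
dict).  A standalone check of that behavior:

```python
class OrderedMap(object):
    # Same idea as config._ordered_map: insertion-ordered iteration,
    # dict-style overwrite by key.
    def __init__(self):
        self._keys = []
        self._data = {}

    def __setitem__(self, key, value):
        if key not in self._keys:
            self._keys.append(key)
        self._data[key] = value

    def __iter__(self):
        for key in self._keys:
            yield self._data[key]

m = OrderedMap()
m['a'] = 1
m['b'] = 2
m['a'] = 10   # overwriting keeps the original position of 'a'
result = list(m)
```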
+
+def _engine_uri(options, file_config):
+    global db_label, db_url
+    db_label = 'sqlite'
+    if options.dburi:
+        db_url = options.dburi
+        db_label = db_url[:db_url.index(':')]
+    elif options.db:
+        db_label = options.db
+        db_url = None
+
+    if db_url is None:
+        if db_label not in file_config.options('db'):
+            raise RuntimeError(
+                "Unknown engine.  Specify --dbs for known engines.")
+        db_url = file_config.get('db', db_label)
+post_configure['engine_uri'] = _engine_uri
+
+def _require(options, file_config):
+    if not(options.require or
+           (file_config.has_section('require') and
+            file_config.items('require'))):
+        return
+
+    try:
+        import pkg_resources
+    except ImportError:
+        raise RuntimeError("setuptools is required for version requirements")
+
+    cmdline = []
+    for requirement in options.require:
+        pkg_resources.require(requirement)
+        cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
+
+    if file_config.has_section('require'):
+        for label, requirement in file_config.items('require'):
+            if not label == db_label or label.startswith('%s.' % db_label):
+                continue
+            seen = [c for c in cmdline if requirement.startswith(c)]
+            if seen:
+                continue
+            pkg_resources.require(requirement)
+post_configure['require'] = _require
+
+def _engine_pool(options, file_config):
+    if options.mockpool:
+        from sqlalchemy import pool
+        db_opts['poolclass'] = pool.AssertionPool
+post_configure['engine_pool'] = _engine_pool
+
+def _create_testing_engine(options, file_config):
+    from sqlalchemy.test import engines, testing
+    global db
+    db = engines.testing_engine(db_url, db_opts)
+    testing.db = db
+post_configure['create_engine'] = _create_testing_engine
+
+def _prep_testing_database(options, file_config):
+    from sqlalchemy.test import engines
+    from sqlalchemy import schema
+
+    try:
+        # also create alt schemas etc. here?
+        if options.dropfirst:
+            e = engines.utf8_engine()
+            existing = e.table_names()
+            if existing:
+                print "Dropping existing tables in database: " + db_url
+                try:
+                    print "Tables: %s" % ', '.join(existing)
+                except:
+                    pass
+                print "Abort within 5 seconds..."
+                time.sleep(5)
+                md = schema.MetaData(e, reflect=True)
+                md.drop_all()
+            e.dispose()
+    except (KeyboardInterrupt, SystemExit):
+        raise
+    except Exception, e:
+        warnings.warn(RuntimeWarning(
+            "Error checking for existing tables in testing "
+            "database: %s" % e))
+post_configure['prep_db'] = _prep_testing_database
+
+def _set_table_options(options, file_config):
+    from sqlalchemy.test import schema
+
+    table_options = schema.table_options
+    for spec in options.tableopts:
+        key, value = spec.split('=')
+        table_options[key] = value
+
+    if options.mysql_engine:
+        table_options['mysql_engine'] = options.mysql_engine
+post_configure['table_options'] = _set_table_options
+
+def _reverse_topological(options, file_config):
+    if options.reversetop:
+        from sqlalchemy.orm import unitofwork
+        from sqlalchemy import topological
+        class RevQueueDepSort(topological.QueueDependencySorter):
+            def __init__(self, tuples, allitems):
+                self.tuples = list(tuples)
+                self.allitems = list(allitems)
+                self.tuples.reverse()
+                self.allitems.reverse()
+        topological.QueueDependencySorter = RevQueueDepSort
+        unitofwork.DependencySorter = RevQueueDepSort
+post_configure['topological'] = _reverse_topological
+

lib/sqlalchemy/test/engines.py

+import sys, types, weakref
+from collections import deque
+import config
+from sqlalchemy.util import function_named, callable
+
+class ConnectionKiller(object):
+    def __init__(self):
+        self.proxy_refs = weakref.WeakKeyDictionary()
+
+    def checkout(self, dbapi_con, con_record, con_proxy):
+        self.proxy_refs[con_proxy] = True
+
+    def _apply_all(self, methods):
+        for rec in self.proxy_refs:
+            if rec is not None and rec.is_valid:
+                try:
+                    for name in methods:
+                        if callable(name):
+                            name(rec)
+                        else:
+                            getattr(rec, name)()
+                except (SystemExit, KeyboardInterrupt):
+                    raise
+                except Exception, e:
+                    # fixme
+                    sys.stderr.write("\n" + str(e) + "\n")
+
+    def rollback_all(self):
+        self._apply_all(('rollback',))
+
+    def close_all(self):
+        self._apply_all(('rollback', 'close'))
+
+    def assert_all_closed(self):
+        for rec in self.proxy_refs:
+            if rec.is_valid:
+                assert False, "connection left open: %r" % rec
+
+testing_reaper = ConnectionKiller()
+
+def assert_conns_closed(fn):
+    def decorated(*args, **kw):
+        try:
+            fn(*args, **kw)
+        finally:
+            testing_reaper.assert_all_closed()
+    return function_named(decorated, fn.__name__)
+
+def rollback_open_connections(fn):
+    """Decorator that rolls back all open connections after fn execution."""
+
+    def decorated(*args, **kw):
+        try:
+            fn(*args, **kw)
+        finally:
+            testing_reaper.rollback_all()
+    return function_named(decorated, fn.__name__)
+
+def close_open_connections(fn):
+    """Decorator that closes all connections after fn execution."""
+
+    def decorated(*args, **kw):
+        try:
+            fn(*args, **kw)
+        finally:
+            testing_reaper.close_all()
+    return function_named(decorated, fn.__name__)
+
+def all_dialects():
+    import sqlalchemy.databases as d
+    for name in d.__all__:
+        mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+        yield mod.dialect()
+        
+class ReconnectFixture(object):
+    def __init__(self, dbapi):
+        self.dbapi = dbapi
+        self.connections = []
+
+    def __getattr__(self, key):
+        return getattr(self.dbapi, key)
+
+    def connect(self, *args, **kwargs):
+        conn = self.dbapi.connect(*args, **kwargs)
+        self.connections.append(conn)
+        return conn
+
+    def shutdown(self):
+        for c in list(self.connections):
+            c.close()
+        self.connections = []
+
+def reconnecting_engine(url=None, options=None):
+    url = url or config.db_url
+    dbapi = config.db.dialect.dbapi
+    if not options:
+        options = {}
+    options['module'] = ReconnectFixture(dbapi)
+    engine = testing_engine(url, options)
+    engine.test_shutdown = engine.dialect.dbapi.shutdown
+    return engine
+
+def testing_engine(url=None, options=None):
+    """Produce an engine configured by --options with optional overrides."""
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.test.assertsql import asserter
+
+    url = url or config.db_url
+    options = options or config.db_opts
+
+    options.setdefault('proxy', asserter)
+    
+    listeners = options.setdefault('listeners', [])
+    listeners.append(testing_reaper)
+
+    engine = create_engine(url, **options)
+
+    return engine
+
+def utf8_engine(url=None, options=None):
+    """Hook for dialects or drivers that don't handle utf8 by default."""
+
+    from sqlalchemy.engine import url as engine_url
+
+    if config.db.name == 'mysql':
+        dbapi_ver = config.db.dialect.dbapi.version_info
+        if (dbapi_ver < (1, 2, 1) or
+            dbapi_ver in ((1, 2, 1, 'gamma', 1), (1, 2, 1, 'gamma', 2),
+                          (1, 2, 1, 'gamma', 3), (1, 2, 1, 'gamma', 5))):
+            raise RuntimeError('Character set support unavailable with this '
+                               'driver version: %s' % repr(dbapi_ver))
+        else:
+            url = url or config.db_url
+            url = engine_url.make_url(url)
+            url.query['charset'] = 'utf8'
+            url.query['use_unicode'] = '0'
+            url = str(url)
+
+    return testing_engine(url, options)
+
+def mock_engine(db=None):
+    """Provides a mocking engine based on the current testing.db."""
+    
+    from sqlalchemy import create_engine
+    
+    dbi = db or config.db
+    buffer = []
+    def executor(sql, *a, **kw):
+        buffer.append(sql)
+    engine = create_engine(dbi.name + '://',
+                           strategy='mock', executor=executor)
+    assert not hasattr(engine, 'mock')
+    engine.mock = buffer
+    return engine
+
+class ReplayableSession(object):
+    """A simple record/playback tool.
+
+    This is *not* a mock testing class.  It only records a session for later
+    playback and makes no assertions on call consistency whatsoever.  It's
+    unlikely to be suitable for anything other than DB-API recording.
+
+    """
+
+    Callable = object()
+    NoAttribute = object()
+    Natives = set([getattr(types, t)
+                   for t in dir(types) if not t.startswith('_')]). \
+                   difference([getattr(types, t)
+                               for t in ('FunctionType', 'BuiltinFunctionType',
+                                         'MethodType', 'BuiltinMethodType',
+                                         'LambdaType', 'UnboundMethodType',)])
+    def __init__(self):
+        self.buffer = deque()
+
+    def recorder(self, base):
+        return self.Recorder(self.buffer, base)
+
+    def player(self):
+        return self.Player(self.buffer)
+
+    class Recorder(object):
+        def __init__(self, buffer, subject):
+            self._buffer = buffer
+            self._subject = subject
+
+        def __call__(self, *args, **kw):
+            subject, buffer = [object.__getattribute__(self, x)
+                               for x in ('_subject', '_buffer')]
+
+            result = subject(*args, **kw)
+            if type(result) not in ReplayableSession.Natives:
+                buffer.append(ReplayableSession.Callable)
+                return type(self)(buffer, result)
+            else:
+                buffer.append(result)
+                return result
+
+        def __getattribute__(self, key):
+            try:
+                return object.__getattribute__(self, key)
+            except AttributeError:
+                pass
+
+            subject, buffer = [object.__getattribute__(self, x)
+                               for x in ('_subject', '_buffer')]
+            try:
+                result = type(subject).__getattribute__(subject, key)
+            except AttributeError:
+                buffer.append(ReplayableSession.NoAttribute)
+                raise
+            else:
+                if type(result) not in ReplayableSession.Natives:
+                    buffer.append(ReplayableSession.Callable)
+                    return type(self)(buffer, result)
+                else:
+                    buffer.append(result)
+                    return result
+
+    class Player(object):
+        def __init__(self, buffer):
+            self._buffer = buffer
+
+        def __call__(self, *args, **kw):
+            buffer = object.__getattribute__(self, '_buffer')
+            result = buffer.popleft()
+            if result is ReplayableSession.Callable:
+                return self
+            else:
+                return result
+
+        def __getattribute__(self, key):
+            try:
+                return object.__getattribute__(self, key)
+            except AttributeError:
+                pass
+            buffer = object.__getattribute__(self, '_buffer')
+            result = buffer.popleft()
+            if result is ReplayableSession.Callable:
+                return self
+            elif result is ReplayableSession.NoAttribute:
+                raise AttributeError(key)
+            else:
+                return result
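The Recorder/Player pair above funnels every call and attribute result through a shared deque: record once against the real subject, then replay offline. A condensed, self-contained sketch of the same idea (classes here are hypothetical miniatures, not the full attribute-proxying versions):

```python
from collections import deque

class MiniRecorder(object):
    """Record return values of calls on a real subject into a buffer."""
    def __init__(self, buffer, subject):
        self._buffer = buffer
        self._subject = subject

    def __call__(self, *args, **kw):
        result = self._subject(*args, **kw)   # talk to the real thing
        self._buffer.append(result)           # remember what it said
        return result

class MiniPlayer(object):
    """Replay recorded results without touching the real subject."""
    def __init__(self, buffer):
        self._buffer = buffer

    def __call__(self, *args, **kw):
        return self._buffer.popleft()         # served purely from the buffer

buffer = deque()
recorded = MiniRecorder(buffer, lambda x: x * 2)
live = [recorded(2), recorded(5)]      # hits the real callable: [4, 10]
replayed = MiniPlayer(buffer)
offline = [replayed(2), replayed(5)]   # same answers, no subject needed
```

As the docstring warns, nothing checks that playback calls match the recorded ones: the player just dequeues whatever comes next.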

lib/sqlalchemy/test/noseplugin.py

+import logging
+import os
+import re
+import sys
+import time
+import warnings
+import ConfigParser
+import StringIO
+from config import db, db_label, db_url, file_config, base_config, \
+                           post_configure, \
+                           _list_dbs, _server_side_cursors, _engine_strategy, \
+                           _engine_uri, _require, _engine_pool, \
+                           _create_testing_engine, _prep_testing_database, \
+                           _set_table_options, _reverse_topological, _log
+from sqlalchemy.test import testing, config, requires
+from nose.plugins import Plugin
+from nose.util import tolist
+import nose.case
+
+log = logging.getLogger('nose.plugins.sqlalchemy')
+
+class NoseSQLAlchemy(Plugin):
+    """
+    Handles the setup and extra properties required for testing SQLAlchemy.
+    """
+    enabled = True
+    name = 'sqlalchemy'
+    score = 100
+
+    def options(self, parser, env=os.environ):
+        Plugin.options(self, parser, env)
+        opt = parser.add_option
+        #opt("--verbose", action="store_true", dest="verbose",
+            #help="enable stdout echoing/printing")
+        #opt("--quiet", action="store_true", dest="quiet", help="suppress output")
+        opt("--log-info", action="callback", type="string", callback=_log,
+            help="turn on info logging for <LOG> (multiple OK)")
+        opt("--log-debug", action="callback", type="string", callback=_log,
+            help="turn on debug logging for <LOG> (multiple OK)")
+        opt("--require", action="append", dest="require", default=[],
+            help="require a particular driver or module version (multiple OK)")
+        opt("--db", action="store", dest="db", default="sqlite",
+            help="Use prefab database uri")
+        opt('--dbs', action='callback', callback=_list_dbs,
+            help="List available prefab dbs")
+        opt("--dburi", action="store", dest="dburi",
+            help="Database uri (overrides --db)")
+        opt("--dropfirst", action="store_true", dest="dropfirst",
+            help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")
+        opt("--mockpool", action="store_true", dest="mockpool",
+            help="Use mock pool (asserts only one connection used)")
+        opt("--enginestrategy", action="callback", type="string",
+            callback=_engine_strategy,
+            help="Engine strategy (plain or threadlocal, defaults to plain)")
+        opt("--reversetop", action="store_true", dest="reversetop", default=False,
+            help="Reverse the collection ordering for topological sorts (helps "
+                  "reveal dependency issues)")
+        opt("--unhashable", action="store_true", dest="unhashable", default=False,
+            help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
+        opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
+            help="Disallow SQLAlchemy from performing == on mapped test objects.")
+        opt("--truthless", action="store_true", dest="truthless", default=False,
+            help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
+        opt("--serverside", action="callback", callback=_server_side_cursors,
+            help="Turn on server side cursors for PG")
+        opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
+            help="Use the specified MySQL storage engine for all tables, default is "
+                 "a db-default/InnoDB combo.")
+        opt("--table-option", action="append", dest="tableopts", default=[],
+            help="Add a dialect-specific table option, key=value")
+
+        global file_config
+        file_config = ConfigParser.ConfigParser()
+        file_config.readfp(StringIO.StringIO(base_config))
+        file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
+        config.file_config = file_config
+        
+    def configure(self, options, conf):
+        Plugin.configure(self, options, conf)
+
+        import testing, requires
+        testing.db = db
+        testing.requires = requires
+
+        # Lazy setup of other options (post coverage)
+        for fn in post_configure:
+            fn(options, file_config)
+        
+    def describeTest(self, test):
+        return ""
+        
+    def wantClass(self, cls):
+        """Return true if you want the main test selector to collect
+        tests from this class, false if you don't, and None if you don't
+        care.
+
+        :Parameters:
+           cls : class
+             The class being examined by the selector
+
+        """
+
+        if not issubclass(cls, testing.TestBase):
+            return False
+        else:
+            if (hasattr(cls, '__whitelist__') and
+                testing.db.name in cls.__whitelist__):
+                return True
+            else:
+                return not self.__should_skip_for(cls)
+    
+    def __should_skip_for(self, cls):
+        if hasattr(cls, '__requires__'):
+            def test_suite(): return 'ok'
+            for requirement in cls.__requires__:
+                check = getattr(requires, requirement)
+                if check(test_suite)() != 'ok':
+                    # The requirement will perform messaging.
+                    return True
+        if (hasattr(cls, '__unsupported_on__') and
+            testing.db.name in cls.__unsupported_on__):
+            print "'%s' unsupported on DB implementation '%s'" % (
+                cls.__name__, testing.db.name)
+            return True
+        if (getattr(cls, '__only_on__', None) not in (None, testing.db.name)):
+            print "'%s' unsupported on DB implementation '%s'" % (
+                cls.__name__, testing.db.name)
+            return True
+        if (getattr(cls, '__skip_if__', False)):
+            for c in getattr(cls, '__skip_if__'):
+                if c():
+                    print "'%s' skipped by %s" % (
+                        cls.__name__, c.__name__)
+                    return True
+        for rule in getattr(cls, '__excluded_on__', ()):
+            if testing._is_excluded(*rule):
+                print "'%s' unsupported on DB %s version %s" % (
+                    cls.__name__, testing.db.name,
+                    _server_version())
+                return True
+        return False
+
+    #def begin(self):
+        #pass
+
+    def beforeTest(self, test):
+        testing.resetwarnings()
+
+    def afterTest(self, test):
+        testing.resetwarnings()
+        
+    #def handleError(self, test, err):
+        #pass
+
+    #def finalize(self, result=None):
+        #pass
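`__should_skip_for` probes each entry of `__requires__` by passing a dummy function through the requirement decorator and seeing whether the result still runs: an unmet requirement replaces the function with a skipper, so the probe no longer returns 'ok'. A self-contained sketch of that probe pattern (the two decorators are hypothetical stand-ins for entries in `requires.py`):

```python
def supported(fn):
    """A requirement that is met: the function passes through untouched."""
    return fn

def unsupported(fn):
    """A requirement that fails: calls are swallowed with a message."""
    def skipper(*args, **kw):
        return "skipped: %s requires an unavailable feature" % fn.__name__
    skipper.__name__ = fn.__name__
    return skipper

def should_skip(requirements):
    """Mirror the plugin's check: run a dummy probe through each decorator."""
    def probe():            # analogous to the throwaway `test_suite`
        return 'ok'
    for check in requirements:
        if check(probe)() != 'ok':
            # the requirement intercepted the call, so skip the class
            return True
    return False
```

Usage: `should_skip([supported])` is False, while `should_skip([supported, unsupported])` is True because the second decorator swallows the probe.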

lib/sqlalchemy/test/orm.py

+import inspect, re
+import config, testing
+from sqlalchemy import orm
+
+__all__ = ('mapper',)
+
+
+_whitespace = re.compile(r'^(\s+)')
+
+def _find_pragma(lines, current):
+    m = _whitespace.match(lines[current])
+    basis = m and m.group() or ''
+
+    for line in reversed(lines[0:current]):
+        if 'testlib.pragma' in line:
+            return line
+        m = _whitespace.match(line)
+        indent = m and m.group() or ''
+
+        # simplistic detection:
+
+        # >> # testlib.pragma foo
+        # >> center_line()
+        if indent == basis:
+            break
+        # >> # testlib.pragma foo
+        # >> if fleem:
+        # >>     center_line()
+        if line.endswith(':'):
+            break
+    return None
+
+def _make_blocker(method_name, fallback):
+    """Creates tripwired variant of a method, raising when called.
+
+    To exempt an invocation from blockage, there are two options.
+
+    1) add a pragma in a comment::
+
+        # testlib.pragma exempt:methodname
+        offending_line()
+
+    2) add a magic cookie to the function's namespace::
+
+        __sa_baremethodname_exempt__ = True
+        ...
+        offending_line()
+        another_offending_line()
+
+    The second is useful for testing and development.
+    """
+
+    if method_name.startswith('__') and method_name.endswith('__'):
+        frame_marker = '__sa_%s_exempt__' % method_name[2:-2]
+    else:
+        frame_marker = '__sa_%s_exempt__' % method_name
+    pragma_marker = 'exempt:' + method_name
+
+    def method(self, *args, **kw):
+        frame_r = None
+        try:
+            frame = inspect.stack()[1][0]
+            frame_r = inspect.getframeinfo(frame, 9)
+
+            module = frame.f_globals.get('__name__', '')
+
+            type_ = type(self)
+
+            pragma = _find_pragma(*frame_r[3:5])
+
+            exempt = (
+                (not module.startswith('sqlalchemy')) or
+                (pragma and pragma_marker in pragma) or
+                (frame_marker in frame.f_locals) or
+                ('self' in frame.f_locals and
+                 getattr(frame.f_locals['self'], frame_marker, False)))
+
+            if exempt:
+                supermeth = getattr(super(type_, self), method_name, None)
+                if (supermeth is None or
+                    getattr(supermeth, 'im_func', None) is method):
+                    return fallback(self, *args, **kw)
+                else:
+                    return supermeth(*args, **kw)
+            else:
+                raise AssertionError(
+                    "%s.%s called in %s, line %s in %s" % (
+                    type_.__name__, method_name, module, frame_r[1], frame_r[2]))
+        finally:
+            del frame
+    method.__name__ = method_name
+    return method
+
+def mapper(type_, *args, **kw):
+    forbidden = [
+        ('__hash__', 'unhashable', lambda s: id(s)),
+        ('__eq__', 'noncomparable', lambda s, o: s is o),
+        ('__ne__', 'noncomparable', lambda s, o: s is not o),
+        ('__cmp__', 'noncomparable', lambda s, o: object.__cmp__(s, o)),
+        ('__le__', 'noncomparable', lambda s, o: object.__le__(s, o)),
+        ('__lt__', 'noncomparable', lambda s, o: object.__lt__(s, o)),
+        ('__ge__', 'noncomparable', lambda s, o: object.__ge__(s, o)),
+        ('__gt__', 'noncomparable', lambda s, o: object.__gt__(s, o)),
+        ('__nonzero__', 'truthless', lambda s: 1), ]
+
+    if isinstance(type_, type) and type_.__bases__ == (object,):
+        for method_name, option, fallback in forbidden:
+            if (getattr(config.options, option, False) and
+                method_name not in type_.__dict__):
+                setattr(type_, method_name, _make_blocker(method_name, fallback))
+
+    return orm.mapper(type_, *args, **kw)
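Stripped of the frame inspection and exemption machinery, the tripwire that `_make_blocker` installs reduces to a method that raises when touched. A simplified sketch of the core idea (not the exemption-aware original):

```python
def make_blocker(method_name):
    """Return a method that refuses to be called."""
    def blocked(self, *args, **kw):
        raise AssertionError("%s.%s was called" %
                             (type(self).__name__, method_name))
    blocked.__name__ = method_name
    return blocked

class Plain(object):
    pass

# With --unhashable in effect, the mapper() wrapper above would install
# a blocking __hash__ on mapped classes, roughly like this:
Plain.__hash__ = make_blocker('__hash__')

try:
    hash(Plain())       # any hash() on a mapped instance now trips
    tripped = False
except AssertionError:
    tripped = True
```

Any SQLAlchemy-internal code path that hashes a mapped object then fails the test immediately, instead of silently depending on hashability.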

lib/sqlalchemy/test/pickleable.py

+"""Some objects used for pickle tests, declared in their own module so that
+they are easily pickleable.
+
+"""
+
+
+class Foo(object):
+    def __init__(self, moredata):
+        self.data = 'im data'
+        self.stuff = 'im stuff'
+        self.moredata = moredata
+    __hash__ = object.__hash__
+    def __eq__(self, other):
+        return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
+
+
+class Bar(object):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    __hash__ = object.__hash__
+    def __eq__(self, other):
+        return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
+    def __str__(self):
+        return "Bar(%d, %d)" % (self.x, self.y)
+
+class OldSchool:
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    def __eq__(self, other):
+        return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
+
+class OldSchoolWithoutCompare:    
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    
+class BarWithoutCompare(object):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    def __str__(self):
+        return "Bar(%d, %d)" % (self.x, self.y)
+
+
+class NotComparable(object):
+    def __init__(self, data):
+        self.data = data
+
+    def __hash__(self):
+        return id(self)
+
+    def __eq__(self, other):
+        return NotImplemented
+
+    def __ne__(self, other):
+        return NotImplemented
+
+
+class BrokenComparable(object):
+    def __init__(self, data):
+        self.data = data
+
+    def __hash__(self):
+        return id(self)
+
+    def __eq__(self, other):
+        raise NotImplementedError
+
+    def __ne__(self, other):
+        raise NotImplementedError
+

lib/sqlalchemy/test/profiling.py

+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
+
+import os, sys
+from sqlalchemy.util import function_named
+import config
+
+__all__ = 'profiled', 'function_call_count', 'conditional_call_count'
+
+all_targets = set()
+profile_config = { 'targets': set(),
+                   'report': True,
+                   'sort': ('time', 'calls'),
+                   'limit': None }
+profiler = None
+
+def profiled(target=None, **target_opts):
+    """Optional function profiling.
+
+    @profiled('label')
+    or
+    @profiled('label', report=True, sort=('calls',), limit=20)
+
+    Enables profiling for a function when 'label' is targeted for
+    profiling.  Report options can be supplied, and override the global
+    configuration and command-line options.
+    """
+
+    # manual or automatic namespacing by module would remove conflict issues
+    if target is None:
+        target = 'anonymous_target'
+    elif target in all_targets:
+        print "Warning: redefining profile target '%s'" % target
+    all_targets.add(target)
+
+    filename = "%s.prof" % target
+
+    def decorator(fn):
+        def profiled(*args, **kw):
+            if (target not in profile_config['targets'] and
+                not target_opts.get('always', None)):
+                return fn(*args, **kw)
+
+            elapsed, load_stats, result = _profile(
+                filename, fn, *args, **kw)
+
+            report = target_opts.get('report', profile_config['report'])
+            if report:
+                sort_ = target_opts.get('sort', profile_config['sort'])
+                limit = target_opts.get('limit', profile_config['limit'])
+                print "Profile report for target '%s' (%s)" % (
+                    target, filename)
+
+                stats = load_stats()
+                stats.sort_stats(*sort_)
+                if limit:
+                    stats.print_stats(limit)
+                else:
+                    stats.print_stats()
+                #stats.print_callers()
+            os.unlink(filename)
+            return result
+        return function_named(profiled, fn.__name__)
+    return decorator
+
+def function_call_count(count=None, versions={}, variance=0.05):
+    """Assert a target for a test case's function call count.
+
+    count
+      Optional, general target function call count.
+
+    versions
+      Optional, a dictionary of Python version strings to counts,
+      for example::
+
+        { '2.5.1': 110,
+          '2.5': 100,
+          '2.4': 150 }
+
+      The best match for the current running python will be used.
+      If none match, 'count' will be used as the fallback.
+
+    variance
+      An +/- deviation percentage, defaults to 5%.
+    """
+
+    # this could easily dump the profile report if --verbose is in effect
+
+    version_info = list(sys.version_info)
+    py_version = '.'.join([str(v) for v in sys.version_info])
+
+    while version_info:
+        version = '.'.join([str(v) for v in version_info])
+        if version in versions:
+            count = versions[version]
+            break
+        version_info.pop()
+
+    if count is None:
+        return lambda fn: fn
+
+    def decorator(fn):
+        def counted(*args, **kw):
+            try:
+                filename = "%s.prof" % fn.__name__
+
+                elapsed, stat_loader, result = _profile(
+                    filename, fn, *args, **kw)
+
+                stats = stat_loader()
+                calls = stats.total_calls
+
+                stats.sort_stats('calls', 'cumulative')
+                stats.print_stats()
+                #stats.print_callers()
+                deviance = int(count * variance)
+                if (calls < (count - deviance) or
+                    calls > (count + deviance)):
+                    raise AssertionError(
+                        "Function call count %s not within %s%% "
+                        "of expected %s. (Python version %s)" % (
+                        calls, (variance * 100), count, py_version))
+
+                return result
+            finally:
+                if os.path.exists(filename):
+                    os.unlink(filename)
+        return function_named(counted, fn.__name__)
+    return decorator
+
+def conditional_call_count(discriminator, categories):
+    """Apply a function call count conditionally at runtime.
+
+    Takes two arguments, a callable that returns a key value, and a dict
+    mapping key values to a tuple of arguments to function_call_count.
+
+    The callable is not evaluated until the decorated function is actually
+    invoked.  If the `discriminator` returns a key not present in the
+    `categories` dictionary, no call count assertion is applied.
+
+    Useful for integration tests, where running a named test in isolation may
+    have a function count penalty not seen in the full suite, due to lazy
+    initialization in the DB-API, SA, etc.
+    """
+
+    def decorator(fn):
+        def at_runtime(*args, **kw):
+            criteria = categories.get(discriminator(), None)
+            if criteria is None:
+                return fn(*args, **kw)
+
+            rewrapped = function_call_count(*criteria)(fn)
+            return rewrapped(*args, **kw)
+        return function_named(at_runtime, fn.__name__)
+    return decorator
+
+
+def _profile(filename, fn, *args, **kw):
+    global profiler
+    if not profiler:
+        profiler = 'hotshot'
+        if sys.version_info > (2, 5):
+            try:
+                import cProfile
+                profiler = 'cProfile'
+            except ImportError:
+                pass
+
+    if profiler == 'cProfile':
+        return _profile_cProfile(filename, fn, *args, **kw)
+    else:
+        return _profile_hotshot(filename, fn, *args, **kw)
+
+def _profile_cProfile(filename, fn, *args, **kw):
+    import cProfile, gc, pstats, time
+
+    load_stats = lambda: pstats.Stats(filename)
+    gc.collect()
+
+    began = time.time()
+    cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
+                    filename=filename)
+    ended = time.time()
+
+    return ended - began, load_stats, locals()['result']
+
+def _profile_hotshot(filename, fn, *args, **kw):
+    import gc, hotshot, hotshot.stats, time
+    load_stats = lambda: hotshot.stats.load(filename)
+
+    gc.collect()
+    prof = hotshot.Profile(filename)
+    began = time.time()
+    prof.start()
+    try:
+        result = fn(*args, **kw)
+    finally:
+        prof.stop()
+        ended = time.time()
+        prof.close()
+
+    return ended - began, load_stats, result
+
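`function_call_count` boils down to profiling the function, reading `total_calls` from the stats, and asserting it sits inside a +/- variance band. A minimal sketch of that flow using only `cProfile`/`pstats` (helper names are hypothetical; the original also supports hotshot on older Pythons):

```python
import cProfile
import os
import pstats
import tempfile

def profiled_call_count(fn, *args, **kw):
    """Run fn under cProfile and return (result, total function calls)."""
    fd, filename = tempfile.mkstemp(suffix='.prof')
    os.close(fd)
    scope = {'fn': fn, 'args': args, 'kw': kw}
    try:
        # runctx stores the return value into the shared namespace
        cProfile.runctx('result = fn(*args, **kw)', scope, scope,
                        filename=filename)
        stats = pstats.Stats(filename)
    finally:
        os.unlink(filename)
    return scope['result'], stats.total_calls

def within_variance(calls, expected, variance=0.05):
    """The same +/- deviance check function_call_count applies."""
    deviance = int(expected * variance)
    return (expected - deviance) <= calls <= (expected + deviance)

result, calls = profiled_call_count(sum, [1, 2, 3])
```

With the default 5% variance, 104 observed calls pass against an expected 100, while 90 fail.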

lib/sqlalchemy/test/requires.py

+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+
+from testing import \
+     _block_unconditionally as no_support, \
+     _chain_decorators_on, \
+     exclude, \
+     emits_warning_on
+
+
+def deferrable_constraints(fn):
+    """Target database must support deferrable constraints."""
+    return _chain_decorators_on(
+        fn,
+        no_support('firebird', 'not supported by database'),
+        no_support('mysql', 'not supported by database'),
+        no_support('mssql', 'not supported by database'),
+        )
+
+def foreign_keys(fn):
+    """Target database must support foreign keys."""
+    return _chain_decorators_on(
+        fn,
+        no_support('sqlite', 'not supported by database'),
+        )
+
+def identity(fn):
+    """Target database must support GENERATED AS IDENTITY or a facsimile.
+
+    Includes GENERATED AS IDENTITY, AUTOINCREMENT, AUTO_INCREMENT, or other
+    column DDL feature that fills in a DB-generated identifier at INSERT-time
+    without requiring pre-execution of a SEQUENCE or other artifact.
+
+    """
+    return _chain_decorators_on(
+        fn,
+        no_support('firebird', 'not supported by database'),
+        no_support('oracle', 'not supported by database'),
+        no_support('postgres', 'not supported by database'),
+        no_support('sybase', 'not supported by database'),
+        )
+
+def independent_connections(fn):
+    """Target must support simultaneous, independent database connections."""
+
+    # This is also true of some configurations of UnixODBC and probably win32
+    # ODBC as well.
+    return _chain_decorators_on(
+        fn,
+        no_support('sqlite', 'no driver support')
+        )
+
+def row_triggers(fn):
+    """Target must support standard statement-running EACH ROW triggers."""
+    return _chain_decorators_on(
+        fn,
+        # no access to same table
+        no_support('mysql', 'requires SUPER priv'),
+        exclude('mysql', '<', (5, 0, 10), 'not supported by database'),
+        no_support('postgres', 'not supported by database: no statements'),
+        )
+
+def savepoints(fn):
+    """Target database must support savepoints."""
+    return _chain_decorators_on(
+        fn,
+        emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'),
+        no_support('access', 'not supported by database'),
+        no_support('sqlite', 'not supported by database'),
+        no_support('sybase', 'FIXME: guessing, needs confirmation'),
+        exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
+        )
+
+def sequences(fn):
+    """Target database must support SEQUENCEs."""
+    return _chain_decorators_on(
+        fn,
+        no_support('access', 'no SEQUENCE support'),
+        no_support('mssql', 'no SEQUENCE support'),
+        no_support('mysql', 'no SEQUENCE support'),
+        no_support('sqlite', 'no SEQUENCE support'),
+        no_support('sybase', 'no SEQUENCE support'),
+        )
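Each predicate above hands its list of exclusions to `_chain_decorators_on`, which simply folds the decorators over the test function. A sketch of that fold under the assumption that decorators apply innermost-first (the decorators shown are illustrative stand-ins, not real exclusion rules):

```python
def chain_decorators_on(fn, *decorators):
    """Fold a sequence of decorators over fn, innermost applied first."""
    for decorator in reversed(decorators):
        fn = decorator(fn)
    return fn

def shout(fn):
    return lambda: fn().upper()

def exclaim(fn):
    return lambda: fn() + '!'

# equivalent to:  greet = shout(exclaim(lambda: 'hi'))
greet = chain_decorators_on(lambda: 'hi', shout, exclaim)
```

Calling `greet()` yields 'HI!': `exclaim` wraps first, then `shout` wraps the result, matching the order the decorators are listed in each `requires` predicate.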
+
+def subqueries(fn):
+    """Target database must support subqueries."""
+    return _chain_decorators_on(
+        fn,
+        exclude('mysql', '<', (4, 1, 1), 'no subquery support'),