Mike Bayer avatar Mike Bayer committed d5b63e5

- test updates

Comments (0)

Files changed (6)

lib/sqlalchemy/testing/exclusions.py

 from ..util import decorator
 from . import config
 from .. import util
+import contextlib
 
 
-class fails_if(object):
-    def __init__(self, predicate, reason=None):
-        self.predicate = _as_predicate(predicate)
-        self.reason = reason
-
-    @property
-    def enabled(self):
-        return not self.predicate()
-
-    def __call__(self, fn):
-        @decorator
-        def decorate(fn, *args, **kw):
-            if not self.predicate():
-                return fn(*args, **kw)
-            else:
-                try:
-                    fn(*args, **kw)
-                except Exception, ex:
-                    print ("'%s' failed as expected (%s): %s " % (
-                        fn.__name__, self.predicate, str(ex)))
-                    return True
-                else:
-                    raise AssertionError(
-                        "Unexpected success for '%s' (%s)" %
-                        (fn.__name__, self.predicate))
-        return decorate(fn)
-
 class skip_if(object):
     def __init__(self, predicate, reason=None):
         self.predicate = _as_predicate(predicate)
         self.reason = reason
 
+    _fails_on = None
+
     @property
     def enabled(self):
         return not self.predicate()
 
+    @contextlib.contextmanager
+    def fail_if(self, name='block'):
+        try:
+            yield
+        except Exception, ex:
+            if self.predicate():
+                print ("%s failed as expected (%s): %s " % (
+                    name, self.predicate, str(ex)))
+            else:
+                raise
+        else:
+            if self.predicate():
+                raise AssertionError(
+                    "Unexpected success for '%s' (%s)" %
+                    (name, self.predicate))
+
     def __call__(self, fn):
         @decorator
         def decorate(fn, *args, **kw):
                         )
                 raise SkipTest(msg)
             else:
+                if self._fails_on:
+                    with self._fails_on.fail_if(name=fn.__name__):
+                        return fn(*args, **kw)
+                else:
+                    return fn(*args, **kw)
+        return decorate(fn)
+
+    def fails_on(self, other, reason=None):
+        self._fails_on = skip_if(other, reason)
+        return self
+
+class fails_if(skip_if):
+    def __call__(self, fn):
+        @decorator
+        def decorate(fn, *args, **kw):
+            with self.fail_if(name=fn.__name__):
                 return fn(*args, **kw)
         return decorate(fn)
 
         )
 
 def open():
-    return skip_if(BooleanPredicate(False))
+    return skip_if(BooleanPredicate(False, "mark as execute"))
 
 def closed():
-    return skip_if(BooleanPredicate(True))
+    return skip_if(BooleanPredicate(True, "marked as skip"))
 
 @decorator
 def future(fn, *args, **kw):

lib/sqlalchemy/testing/requirements.py

         return exclusions.open()
 
     @property
+    def self_referential_foreign_keys(self):
+        """Target database must support self-referential foreign keys."""
+
+        return exclusions.open()
+
+    @property
     def foreign_key_ddl(self):
         """Target database must support the DDL phrases for FOREIGN KEY."""
 
         return exclusions.open()
 
     @property
+    def named_constraints(self):
+        """target database must support names for constraints."""
+
+        return exclusions.open()
+
+    @property
     def autoincrement_insert(self):
         """target platform generates new surrogate integer primary key values
         when insert() is executed, excluding the pk column."""

lib/sqlalchemy/testing/suite/test_insert.py

         assert r.is_insert
         assert not r.returns_rows
 
+class ReturningTest(fixtures.TablesTest):
+    run_deletes = 'each'
+    __requires__ = 'returning',
 
-__all__ = ('InsertSequencingTest', 'InsertBehaviorTest')
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('autoinc_pk', metadata,
+                Column('id', Integer, primary_key=True, \
+                                test_needs_autoincrement=True),
+                Column('data', String(50))
+            )
 
+    def test_explicit_returning_pk(self):
+        engine = config.db
+        table = self.tables.autoinc_pk
+        r = engine.execute(
+            table.insert().returning(
+                            table.c.id),
+            data="some data"
+        )
+        pk = r.first()[0]
+        fetched_pk = config.db.scalar(select([table.c.id]))
+        eq_(fetched_pk, pk)
 
+
+
+__all__ = ('InsertSequencingTest', 'InsertBehaviorTest', 'ReturningTest')
+
+

lib/sqlalchemy/testing/suite/test_reflection.py

         else:
             schema_prefix = ""
 
-        users = Table('users', metadata,
-            Column('user_id', sa.INT, primary_key=True),
-            Column('test1', sa.CHAR(5), nullable=False),
-            Column('test2', sa.Float(5), nullable=False),
-            Column('parent_user_id', sa.Integer,
-                        sa.ForeignKey('%susers.user_id' % schema_prefix)),
-            schema=schema,
-            test_needs_fk=True,
-        )
+        if testing.requires.self_referential_foreign_keys.enabled:
+            users = Table('users', metadata,
+                Column('user_id', sa.INT, primary_key=True),
+                Column('test1', sa.CHAR(5), nullable=False),
+                Column('test2', sa.Float(5), nullable=False),
+                Column('parent_user_id', sa.Integer,
+                            sa.ForeignKey('%susers.user_id' % schema_prefix)),
+                schema=schema,
+                test_needs_fk=True,
+            )
+        else:
+            users = Table('users', metadata,
+                Column('user_id', sa.INT, primary_key=True),
+                Column('test1', sa.CHAR(5), nullable=False),
+                Column('test2', sa.Float(5), nullable=False),
+                schema=schema,
+                test_needs_fk=True,
+            )
+
         Table("dingalings", metadata,
                   Column('dingaling_id', sa.Integer, primary_key=True),
                   Column('address_id', sa.Integer,
         addr_pkeys = addr_cons['constrained_columns']
         eq_(addr_pkeys,  ['address_id'])
 
-        @testing.requires.reflects_pk_names
-        def go():
+        with testing.requires.reflects_pk_names.fail_if():
             eq_(addr_cons['name'], 'email_ad_pk')
-        go()
 
     @testing.requires.primary_key_constraint_reflection
     def test_get_pk_constraint(self):
         self._test_get_pk_constraint()
 
     @testing.requires.table_reflection
-    @testing.fails_on('sqlite', 'no schemas')
+    @testing.requires.schemas
     def test_get_pk_constraint_with_schema(self):
         self._test_get_pk_constraint(schema='test_schema')
 
                                             schema=schema)
         fkey1 = users_fkeys[0]
 
-        @testing.fails_on('sqlite', 'no support for constraint names')
-        def go():
+        with testing.requires.named_constraints.fail_if():
             self.assert_(fkey1['name'] is not None)
-        go()
 
         eq_(fkey1['referred_schema'], expected_schema)
         eq_(fkey1['referred_table'], users.name)
         eq_(fkey1['referred_columns'], ['user_id', ])
-        eq_(fkey1['constrained_columns'], ['parent_user_id'])
+        if testing.requires.self_referential_foreign_keys.enabled:
+            eq_(fkey1['constrained_columns'], ['parent_user_id'])
+
         #addresses
         addr_fkeys = insp.get_foreign_keys(addresses.name,
                                            schema=schema)
         fkey1 = addr_fkeys[0]
-        @testing.fails_on('sqlite', 'no support for constraint names')
-        def go():
+
+        with testing.requires.named_constraints.fail_if():
             self.assert_(fkey1['name'] is not None)
-        go()
+
         eq_(fkey1['referred_schema'], expected_schema)
         eq_(fkey1['referred_table'], users.name)
         eq_(fkey1['referred_columns'], ['user_id', ])

test/engine/test_reflection.py

     @testing.provide_metadata
     def test_two_foreign_keys(self):
         meta = self.metadata
-        t1 = Table(
+        Table(
             't1',
             meta,
             Column('id', sa.Integer, primary_key=True),
             Column('t3id', sa.Integer, sa.ForeignKey('t3.id')),
             test_needs_fk=True,
             )
-        t2 = Table('t2', meta, Column('id', sa.Integer,
-                   primary_key=True), test_needs_fk=True)
-        t3 = Table('t3', meta, Column('id', sa.Integer,
-                   primary_key=True), test_needs_fk=True)
+        Table('t2', meta,
+                    Column('id', sa.Integer, primary_key=True),
+                    test_needs_fk=True)
+        Table('t3', meta,
+                    Column('id', sa.Integer, primary_key=True),
+                    test_needs_fk=True)
         meta.create_all()
         meta2 = MetaData()
         t1r, t2r, t3r = [Table(x, meta2, autoload=True,

test/requirements.py

      skip_if,\
      only_if,\
      only_on,\
-     fails_on,\
      fails_on_everything_except,\
      fails_if,\
      SpecPredicate,\
 def exclude(db, op, spec, description=None):
     return SpecPredicate(db, op, spec, description=description)
 
-
-crashes = skip
-
-def _chain_decorators_on(*decorators):
-    def decorate(fn):
-        for decorator in reversed(decorators):
-            fn = decorator(fn)
-        return fn
-    return decorate
-
 class DefaultRequirements(SuiteRequirements):
     @property
     def deferrable_or_no_constraints(self):
             ])
 
     @property
+    def named_constraints(self):
+        """target database must support names for constraints."""
+
+        return skip_if([
+            no_support('sqlite', 'not supported by database'),
+            ])
+
+    @property
     def foreign_keys(self):
         """Target database must support foreign keys."""
 
                 no_support('sqlite', 'not supported by database')
             )
 
-
     @property
     def unbounded_varchar(self):
         """Target database must support VARCHAR with no length"""
 
     @property
     def isolation_level(self):
-        return _chain_decorators_on(
-            only_on(('postgresql', 'sqlite', 'mysql'),
-                        "DBAPI has no isolation level support"),
-            fails_on('postgresql+pypostgresql',
+        return only_on(
+                    ('postgresql', 'sqlite', 'mysql'),
+                    "DBAPI has no isolation level support"
+                ).fails_on('postgresql+pypostgresql',
                           'pypostgresql bombs on multiple isolation level calls')
-        )
 
     @property
     def row_triggers(self):
     @property
     def nullsordering(self):
         """Target backends that support nulls ordering."""
-        return _chain_decorators_on(
-            fails_on_everything_except('postgresql', 'oracle', 'firebird')
-        )
+        return fails_on_everything_except('postgresql', 'oracle', 'firebird')
 
     @property
     def reflects_pk_names(self):
         """Target driver reflects the name of primary key constraints."""
-        return _chain_decorators_on(
-            fails_on_everything_except('postgresql', 'oracle')
-        )
+
+        return fails_on_everything_except('postgresql', 'oracle')
 
     @property
     def python2(self):
-        return _chain_decorators_on(
-            skip_if(
+        return skip_if(
                 lambda: sys.version_info >= (3,),
                 "Python version 2.xx is required."
                 )
-        )
 
     @property
     def python3(self):
-        return _chain_decorators_on(
-            skip_if(
+        return skip_if(
                 lambda: sys.version_info < (3,),
                 "Python version 3.xx is required."
                 )
-        )
 
     @property
     def python26(self):
-        return _chain_decorators_on(
-            skip_if(
+        return skip_if(
                 lambda: sys.version_info < (2, 6),
                 "Python version 2.6 or greater is required"
             )
-        )
 
     @property
     def python25(self):
-        return _chain_decorators_on(
-            skip_if(
+        return skip_if(
                 lambda: sys.version_info < (2, 5),
                 "Python version 2.5 or greater is required"
             )
-        )
 
     @property
     def cpython(self):
-        return _chain_decorators_on(
-             only_if(lambda: util.cpython,
+        return only_if(lambda: util.cpython,
                "cPython interpreter needed"
              )
-        )
 
     @property
     def predictable_gc(self):
 
     @property
     def sqlite(self):
-        return _chain_decorators_on(
-            skip_if(lambda: not self._has_sqlite())
-        )
+        return skip_if(lambda: not self._has_sqlite())
 
     @property
     def ad_hoc_engines(self):
         as not present.
 
         """
-        return _chain_decorators_on(
-            skip_if(lambda: self.config.options.low_connections)
-        )
+        return skip_if(lambda: self.config.options.low_connections)
 
     @property
     def skip_mysql_on_windows(self):
         """Catchall for a large variety of MySQL on Windows failures"""
 
-        return _chain_decorators_on(
-            skip_if(self._has_mysql_on_windows,
+        return skip_if(self._has_mysql_on_windows,
                 "Not supported on MySQL + Windows"
             )
-        )
 
     @property
     def english_locale_on_postgresql(self):
-        return _chain_decorators_on(
-            skip_if(lambda: against('postgresql') \
+        return skip_if(lambda: against('postgresql') \
                     and not self.db.scalar('SHOW LC_COLLATE').startswith('en'))
-        )
 
     @property
     def selectone(self):
         """target driver must support the literal statement 'select 1'"""
-        return _chain_decorators_on(
-            skip_if(lambda: against('oracle'),
-                "non-standard SELECT scalar syntax"),
-            skip_if(lambda: against('firebird'),
-                "non-standard SELECT scalar syntax")
-        )
+        return skip_if(["oracle", "firebird"], "non-standard SELECT scalar syntax")
 
     def _has_cextensions(self):
         try:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.