Source

sqlalchemy / test / sql / test_returning.py

Full commit
from sqlalchemy.testing import eq_
from sqlalchemy import *
from sqlalchemy import testing
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy.types import TypeDecorator
from sqlalchemy.testing import fixtures, AssertsExecutionResults, engines, \
        assert_raises_message
from sqlalchemy import exc as sa_exc
import itertools

class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
    """RETURNING behavior for INSERT, UPDATE and DELETE statements.

    setup() creates a fresh 'tables' table for every test and exposes it,
    together with the GoofyType decorator, through module-level globals
    that the test methods read.
    """

    __requires__ = 'returning',

    def setup(self):
        meta = MetaData(testing.db)
        global table, GoofyType

        class GoofyType(TypeDecorator):
            """String type that prepends "FOO" on bind and appends "BAR" on
            result, so tests can tell whether type processing was applied
            to values coming back through RETURNING."""

            impl = String

            def process_bind_param(self, value, dialect):
                if value is None:
                    return None
                return "FOO" + value

            def process_result_value(self, value, dialect):
                if value is None:
                    return None
                return value + "BAR"

        table = Table('tables', meta,
            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
            Column('persons', Integer),
            Column('full', Boolean),
            Column('goofy', GoofyType(50))
        )
        table.create(checkfirst=True)

    def teardown(self):
        table.drop()

    def test_column_targeting(self):
        """Returned rows are addressable both by Column and by name, and
        TypeDecorator result processing is applied to returned values."""
        result = table.insert().returning(table.c.id, table.c.full).execute({'persons': 1, 'full': False})

        row = result.first()
        assert row[table.c.id] == row['id'] == 1
        assert row[table.c.full] == row['full'] == False

        result = table.insert().values(persons=5, full=True, goofy="somegoofy").\
                            returning(table.c.persons, table.c.full, table.c.goofy).execute()
        row = result.first()
        assert row[table.c.persons] == row['persons'] == 5
        assert row[table.c.full] == row['full'] == True

        eq_(row[table.c.goofy], row['goofy'])
        # "FOO" added on the way in, "BAR" added on the way out.
        eq_(row['goofy'], "FOOsomegoofyBAR")

    @testing.fails_on('firebird', "fb can't handle returning x AS y")
    def test_labeling(self):
        """A labeled column in RETURNING is addressable by its label."""
        result = table.insert().values(persons=6).\
                            returning(table.c.persons.label('lala')).execute()
        row = result.first()
        assert row['lala'] == 6

    @testing.fails_on('firebird', "fb/kintersbasdb can't handle the bind params")
    @testing.fails_on('oracle+zxjdbc', "JDBC driver bug")
    def test_anon_expressions(self):
        """Unlabeled SQL expressions in RETURNING are reachable by position."""
        result = table.insert().values(goofy="someOTHERgoofy").\
                            returning(func.lower(table.c.goofy, type_=GoofyType)).execute()
        row = result.first()
        eq_(row[0], "foosomeothergoofyBAR")

        result = table.insert().values(persons=12).\
                            returning(table.c.persons + 18).execute()
        row = result.first()
        eq_(row[0], 30)

    def test_update_returning(self):
        """UPDATE ... RETURNING yields only the rows that were updated."""
        table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])

        result = table.update(table.c.persons > 4, dict(full=True)).returning(table.c.id).execute()
        eq_(result.fetchall(), [(1,)])

        result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
        eq_(result2.fetchall(), [(1, True), (2, False)])

    def test_insert_returning(self):
        """INSERT ... RETURNING yields the generated primary key."""
        result = table.insert().returning(table.c.id).execute({'persons': 1, 'full': False})

        eq_(result.fetchall(), [(1,)])

    @testing.requires.multivalues_inserts
    def test_multirow_returning(self):
        """A multi-VALUES INSERT returns one row per inserted row."""
        ins = table.insert().returning(table.c.id, table.c.persons).values(
                            [
                                {'persons': 1, 'full': False},
                                {'persons': 2, 'full': True},
                                {'persons': 3, 'full': False},
                            ]
                        )
        result = testing.db.execute(ins)
        eq_(
                result.fetchall(),
                 [(1, 1), (2, 2), (3, 3)]
        )

    def test_no_ipk_on_returning(self):
        """inserted_primary_key is unavailable once returning() is used."""
        result = testing.db.execute(
                    table.insert().returning(table.c.id),
                    {'persons': 1, 'full': False}
                )
        assert_raises_message(
            sa_exc.InvalidRequestError,
            # Raw string: the message is a regex, and "\(" is not a valid
            # Python escape sequence in a normal string literal.
            r"Can't call inserted_primary_key when returning\(\) is used.",
            getattr, result, "inserted_primary_key"
        )

    @testing.fails_on_everything_except('postgresql', 'firebird')
    def test_literal_returning(self):
        """A textual INSERT ... RETURNING statement returns its row."""
        if testing.against("postgresql"):
            literal_true = "true"
        else:
            literal_true = "1"

        result4 = testing.db.execute('insert into tables (id, persons, "full") '
                                        'values (5, 10, %s) returning persons' % literal_true)
        eq_([dict(row) for row in result4], [{'persons': 10}])

    def test_delete_returning(self):
        """DELETE ... RETURNING yields only the rows that were deleted."""
        table.insert().execute([{'persons': 5, 'full': False}, {'persons': 3, 'full': False}])

        result = table.delete(table.c.persons > 4).returning(table.c.id).execute()
        eq_(result.fetchall(), [(1,)])

        result2 = select([table.c.id, table.c.full]).order_by(table.c.id).execute()
        eq_(result2.fetchall(), [(2, False),])

class SequenceReturningTest(fixtures.TestBase):
    """RETURNING of a primary key whose values come from an explicit
    Sequence ('tid_seq')."""

    __requires__ = 'returning', 'sequences'

    def setup(self):
        """Create 'tables' with an 'id' column backed by 'tid_seq'; the
        table and sequence are published as module globals."""
        global table, seq
        meta = MetaData(testing.db)
        seq = Sequence('tid_seq')
        id_column = Column('id', Integer, seq, primary_key=True)
        data_column = Column('data', String(50))
        table = Table('tables', meta, id_column, data_column)
        table.create(checkfirst=True)

    def teardown(self):
        """Drop the table created by setup()."""
        table.drop()

    def test_insert(self):
        """The first INSERT returns id 1, and executing the sequence
        directly afterwards yields 2."""
        stmt = table.insert().values(data='hi').returning(table.c.id)
        returned = stmt.execute()
        assert returned.first() == (1, )

        following = seq.execute()
        assert following == 2

class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
    """test returning() works with columns that define 'key'."""

    __requires__ = 'returning',

    def setup(self):
        """Create 'tables' whose 'id' column is addressed by key 'foo_id';
        the table is published as a module global."""
        meta = MetaData(testing.db)
        global table

        table = Table('tables', meta,
            Column('id', Integer, primary_key=True, key='foo_id', test_needs_autoincrement=True),
            Column('data', String(20)),
        )
        table.create(checkfirst=True)

    def teardown(self):
        """Drop the table created by setup()."""
        table.drop()

    @testing.exclude('firebird', '<', (2, 0), '2.0+ feature')
    @testing.exclude('postgresql', '<', (8, 2), '8.2+ feature')
    def test_insert(self):
        """The RETURNING row and a subsequent SELECT row are both reachable
        via the column's key ('foo_id') and its name ('id')."""
        result = table.insert().returning(table.c.foo_id).execute(data='somedata')
        row = result.first()
        assert row[table.c.foo_id] == row['id'] == 1

        # Check the row fetched by SELECT itself; previously this line
        # re-asserted the RETURNING row and never examined the SELECT result.
        selected = table.select().execute().first()
        assert selected[table.c.foo_id] == selected['id'] == 1


class ReturnDefaultsTest(fixtures.TablesTest):
    """return_defaults() on INSERT/UPDATE, both via the generative method
    and the ``return_defaults=[...]`` constructor argument, including
    columns that carry no default at all."""

    __requires__ = ('returning', )
    run_define_tables = 'each'

    @classmethod
    def define_tables(cls, metadata):
        """Define 't1'.  'insdef' and 'upddef' use an IncDefault
        expression that renders the next integer from a counter (starting
        at 0) each time it is compiled."""
        from sqlalchemy.sql import ColumnElement
        from sqlalchemy.ext.compiler import compiles

        ticker = itertools.count()

        class IncDefault(ColumnElement):
            """Placeholder SQL expression rendered by the hook below."""

        @compiles(IncDefault)
        def _render_inc_default(element, compiler, **kw):
            # Every compilation consumes the next value from the counter.
            return str(next(ticker))

        Table(
            "t1", metadata,
            Column("id", Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column("data", String(50)),
            Column("insdef", Integer, default=IncDefault()),
            Column("upddef", Integer, onupdate=IncDefault()),
        )

    def test_chained_insert_pk(self):
        t1 = self.tables.t1
        stmt = t1.insert().values(upddef=1).return_defaults(t1.c.insdef)
        result = testing.db.execute(stmt)
        wanted = (t1.c.id, t1.c.insdef)
        eq_([result.returned_defaults[col] for col in wanted], [1, 0])

    def test_arg_insert_pk(self):
        t1 = self.tables.t1
        stmt = t1.insert(return_defaults=[t1.c.insdef]).values(upddef=1)
        result = testing.db.execute(stmt)
        wanted = (t1.c.id, t1.c.insdef)
        eq_([result.returned_defaults[col] for col in wanted], [1, 0])

    def test_chained_update_pk(self):
        t1 = self.tables.t1
        testing.db.execute(t1.insert().values(upddef=1))
        stmt = t1.update().values(data='d1').return_defaults(t1.c.upddef)
        result = testing.db.execute(stmt)
        wanted = (t1.c.upddef,)
        eq_([result.returned_defaults[col] for col in wanted], [1])

    def test_arg_update_pk(self):
        t1 = self.tables.t1
        testing.db.execute(t1.insert().values(upddef=1))
        stmt = t1.update(return_defaults=[t1.c.upddef]).values(data='d1')
        result = testing.db.execute(stmt)
        wanted = (t1.c.upddef,)
        eq_([result.returned_defaults[col] for col in wanted], [1])

    def test_insert_non_default(self):
        """A column with no default of any kind can still be requested
        through return_defaults() on INSERT."""

        t1 = self.tables.t1
        stmt = t1.insert().values(upddef=1).return_defaults(t1.c.data)
        result = testing.db.execute(stmt)
        wanted = (t1.c.id, t1.c.data,)
        eq_([result.returned_defaults[col] for col in wanted], [1, None])

    def test_update_non_default(self):
        """A column with no default of any kind can still be requested
        through return_defaults() on UPDATE."""

        t1 = self.tables.t1
        testing.db.execute(t1.insert().values(upddef=1))
        stmt = t1.update().values(upddef=2).return_defaults(t1.c.data)
        result = testing.db.execute(stmt)
        wanted = (t1.c.data,)
        eq_([result.returned_defaults[col] for col in wanted], [None])

    # NOTE(review): a @testing.fails_on("oracle+cx_oracle", "seems like a
    # cx_oracle bug") marker was commented out on this test; confirm it is
    # no longer needed.
    def test_insert_non_default_plus_default(self):
        t1 = self.tables.t1
        stmt = t1.insert().values(upddef=1).return_defaults(
            t1.c.data, t1.c.insdef)
        result = testing.db.execute(stmt)
        expected = {"id": 1, "data": None, "insdef": 0}
        eq_(dict(result.returned_defaults), expected)

    @testing.fails_on("oracle+cx_oracle", "seems like a cx_oracle bug")
    def test_update_non_default_plus_default(self):
        t1 = self.tables.t1
        testing.db.execute(t1.insert().values(upddef=1))
        stmt = t1.update().values(insdef=2).return_defaults(
            t1.c.data, t1.c.upddef)
        result = testing.db.execute(stmt)
        expected = {"data": None, 'upddef': 1}
        eq_(dict(result.returned_defaults), expected)

class ImplicitReturningFlag(fixtures.TestBase):
    """The dialect's implicit_returning flag, checked before and after the
    first connection is made.

    Each test closes the connection it opens so no checked-out pooled
    connection is left behind.
    """

    def test_flag_turned_off(self):
        e = engines.testing_engine(options={'implicit_returning': False})
        assert e.dialect.implicit_returning is False
        c = e.connect()
        try:
            assert e.dialect.implicit_returning is False
        finally:
            c.close()

    def test_flag_turned_on(self):
        e = engines.testing_engine(options={'implicit_returning': True})
        assert e.dialect.implicit_returning is True
        c = e.connect()
        try:
            assert e.dialect.implicit_returning is True
        finally:
            c.close()

    def test_flag_turned_default(self):
        # A list cell rather than a plain local so the nested function can
        # record that it ran.
        supports = [False]

        def go():
            supports[0] = True
        testing.requires.returning(go)()
        e = engines.testing_engine()

        # starts as False.  This is because all of Firebird,
        # Postgresql, Oracle, SQL Server started supporting RETURNING
        # as of a certain version, and the flag is not set until
        # version detection occurs.  If some DB comes along that has
        # RETURNING in all cases, this test can be adjusted.
        assert e.dialect.implicit_returning is False

        # version detection on connect sets it
        c = e.connect()
        try:
            assert e.dialect.implicit_returning is supports[0]
        finally:
            c.close()