Mike Bayer avatar Mike Bayer committed 4d2aa4b

break out _save_obj(), _delete_obj(), _post_update() into a new module
persistence.py - Mapper loses awareness of how to emit INSERT/UPDATE/DELETE,
persistence.py is only used by unitofwork.py. Then break each method out
into a top level with almost no logic, calling into _organize_states_for_XYZ(),
_collect_XYZ_commands(), _emit_XYZ_statements().

Comments (0)

Files changed (5)

lib/sqlalchemy/orm/mapper.py

             self._memoized_values[key] = value = callable_()
             return value
 
-    def _post_update(self, states, uowtransaction, post_update_cols):
-        """Issue UPDATE statements on behalf of a relationship() which
-        specifies post_update.
-
-        """
-        cached_connections = util.PopulateDict(
-            lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache
-        ))
-
-        # if session has a connection callable,
-        # organize individual states with the connection 
-        # to use for update
-        if uowtransaction.session.connection_callable:
-            connection_callable = \
-                    uowtransaction.session.connection_callable
-        else:
-            connection = uowtransaction.transaction.connection(self)
-            connection_callable = None
-
-        tups = []
-        for state in _sort_states(states):
-            if connection_callable:
-                conn = connection_callable(self, state.obj())
-            else:
-                conn = connection
-
-            mapper = _state_mapper(state)
-
-            tups.append((state, state.dict, mapper, conn))
-
-        table_to_mapper = self._sorted_tables
-
-        for table in table_to_mapper:
-            update = []
-
-            for state, state_dict, mapper, connection in tups:
-                if table not in mapper._pks_by_table:
-                    continue
-
-                pks = mapper._pks_by_table[table]
-                params = {}
-                hasdata = False
-
-                for col in mapper._cols_by_table[table]:
-                    if col in pks:
-                        params[col._label] = \
-                                mapper._get_state_attr_by_column(
-                                                state,
-                                                state_dict, col)
-                    elif col in post_update_cols:
-                        prop = mapper._columntoproperty[col]
-                        history = attributes.get_state_history(
-                                    state, prop.key, 
-                                    attributes.PASSIVE_NO_INITIALIZE)
-                        if history.added:
-                            value = history.added[0]
-                            params[col.key] = value
-                            hasdata = True
-                if hasdata:
-                    update.append((state, state_dict, params, mapper, 
-                                    connection))
-
-            if update:
-                mapper = table_to_mapper[table]
-
-                def update_stmt():
-                    clause = sql.and_()
-
-                    for col in mapper._pks_by_table[table]:
-                        clause.clauses.append(col == sql.bindparam(col._label,
-                                                        type_=col.type))
-
-                    return table.update(clause)
-
-                statement = self._memo(('post_update', table), update_stmt)
-
-                # execute each UPDATE in the order according to the original
-                # list of states to guarantee row access order, but
-                # also group them into common (connection, cols) sets 
-                # to support executemany().
-                for key, grouper in groupby(
-                    update, lambda rec: (rec[4], rec[2].keys())
-                ):
-                    multiparams = [params for state, state_dict, 
-                                            params, mapper, conn in grouper]
-                    cached_connections[connection].\
-                                        execute(statement, multiparams)
-
-    def _save_obj(self, states, uowtransaction, single=False):
-        """Issue ``INSERT`` and/or ``UPDATE`` statements for a list 
-        of objects.
-
-        This is called within the context of a UOWTransaction during a
-        flush operation, given a list of states to be flushed.  The
-        base mapper in an inheritance hierarchy handles the inserts/
-        updates for all descendant mappers.
-
-        """
-
-        # if batch=false, call _save_obj separately for each object
-        if not single and not self.batch:
-            for state in _sort_states(states):
-                self._save_obj([state], 
-                                uowtransaction, 
-                                single=True)
-            return
-
-        # if session has a connection callable,
-        # organize individual states with the connection 
-        # to use for insert/update
-        if uowtransaction.session.connection_callable:
-            connection_callable = \
-                    uowtransaction.session.connection_callable
-        else:
-            connection = uowtransaction.transaction.connection(self)
-            connection_callable = None
-
-        tups = []
-
-        for state in _sort_states(states):
-            if connection_callable:
-                conn = connection_callable(self, state.obj())
-            else:
-                conn = connection
-
-            has_identity = bool(state.key)
-            mapper = _state_mapper(state)
-            instance_key = state.key or mapper._identity_key_from_state(state)
-
-            row_switch = None
-
-            # call before_XXX extensions
-            if not has_identity:
-                mapper.dispatch.before_insert(mapper, conn, state)
-            else:
-                mapper.dispatch.before_update(mapper, conn, state)
-
-            # detect if we have a "pending" instance (i.e. has 
-            # no instance_key attached to it), and another instance 
-            # with the same identity key already exists as persistent. 
-            # convert to an UPDATE if so.
-            if not has_identity and \
-                instance_key in uowtransaction.session.identity_map:
-                instance = \
-                    uowtransaction.session.identity_map[instance_key]
-                existing = attributes.instance_state(instance)
-                if not uowtransaction.is_deleted(existing):
-                    raise orm_exc.FlushError(
-                        "New instance %s with identity key %s conflicts "
-                        "with persistent instance %s" % 
-                        (state_str(state), instance_key,
-                         state_str(existing)))
-
-                self._log_debug(
-                    "detected row switch for identity %s.  "
-                    "will update %s, remove %s from "
-                    "transaction", instance_key, 
-                    state_str(state), state_str(existing))
-
-                # remove the "delete" flag from the existing element
-                uowtransaction.remove_state_actions(existing)
-                row_switch = existing
-
-            tups.append(
-                (state, state.dict, mapper, conn, 
-                has_identity, instance_key, row_switch)
-            )
-
-        # dictionary of connection->connection_with_cache_options.
-        cached_connections = util.PopulateDict(
-            lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache
-        ))
-
-        table_to_mapper = self._sorted_tables
-
-        for table in table_to_mapper:
-            insert = []
-            update = []
-
-            for state, state_dict, mapper, connection, has_identity, \
-                            instance_key, row_switch in tups:
-                if table not in mapper._pks_by_table:
-                    continue
-
-                pks = mapper._pks_by_table[table]
-
-                isinsert = not has_identity and not row_switch
-
-                params = {}
-                value_params = {}
-
-                if isinsert:
-                    has_all_pks = True
-                    for col in mapper._cols_by_table[table]:
-                        if col is mapper.version_id_col:
-                            params[col.key] = \
-                              mapper.version_id_generator(None)
-                        else:
-                            # pull straight from the dict for 
-                            # pending objects
-                            prop = mapper._columntoproperty[col]
-                            value = state_dict.get(prop.key, None)
-
-                            if value is None:
-                                if col in pks:
-                                    has_all_pks = False
-                                elif col.default is None and \
-                                     col.server_default is None:
-                                    params[col.key] = value
-
-                            elif isinstance(value, sql.ClauseElement):
-                                value_params[col] = value
-                            else:
-                                params[col.key] = value
-
-                    insert.append((state, state_dict, params, mapper, 
-                                    connection, value_params, has_all_pks))
-                else:
-                    hasdata = hasnull = False
-                    for col in mapper._cols_by_table[table]:
-                        if col is mapper.version_id_col:
-                            params[col._label] = \
-                                mapper._get_committed_state_attr_by_column(
-                                            row_switch or state, 
-                                            row_switch and row_switch.dict 
-                                                        or state_dict,
-                                            col)
-
-                            prop = mapper._columntoproperty[col]
-                            history = attributes.get_state_history(
-                                state, prop.key, 
-                                attributes.PASSIVE_NO_INITIALIZE
-                            )
-                            if history.added:
-                                params[col.key] = history.added[0]
-                                hasdata = True
-                            else:
-                                params[col.key] = \
-                                            mapper.version_id_generator(
-                                                params[col._label])
-
-                                # HACK: check for history, in case the 
-                                # history is only
-                                # in a different table than the one 
-                                # where the version_id_col is.
-                                for prop in mapper._columntoproperty.\
-                                                        itervalues():
-                                    history = attributes.get_state_history(
-                                            state, prop.key, 
-                                            attributes.PASSIVE_NO_INITIALIZE)
-                                    if history.added:
-                                        hasdata = True
-                        else:
-                            prop = mapper._columntoproperty[col]
-                            history = attributes.get_state_history(
-                                            state, prop.key, 
-                                            attributes.PASSIVE_NO_INITIALIZE)
-                            if history.added:
-                                if isinstance(history.added[0],
-                                                sql.ClauseElement):
-                                    value_params[col] = history.added[0]
-                                else:
-                                    value = history.added[0]
-                                    params[col.key] = value
-
-                                if col in pks:
-                                    if history.deleted and \
-                                        not row_switch:
-                                        # if passive_updates and sync detected
-                                        # this was a  pk->pk sync, use the new
-                                        # value to locate the row, since the
-                                        # DB would already have set this
-                                        if ("pk_cascaded", state, col) in \
-                                                        uowtransaction.\
-                                                        attributes:
-                                            value = history.added[0]
-                                            params[col._label] = value
-                                        else:
-                                            # use the old value to 
-                                            # locate the row
-                                            value = history.deleted[0]
-                                            params[col._label] = value
-                                        hasdata = True
-                                    else:
-                                        # row switch logic can reach us here
-                                        # remove the pk from the update params
-                                        # so the update doesn't
-                                        # attempt to include the pk in the
-                                        # update statement
-                                        del params[col.key]
-                                        value = history.added[0]
-                                        params[col._label] = value
-                                    if value is None:
-                                        hasnull = True
-                                else:
-                                    hasdata = True
-                            elif col in pks:
-                                value = state.manager[prop.key].\
-                                            impl.get(state, state_dict)
-                                if value is None:
-                                    hasnull = True
-                                params[col._label] = value
-                    if hasdata:
-                        if hasnull:
-                            raise sa_exc.FlushError(
-                                        "Can't update table "
-                                        "using NULL for primary "
-                                        "key value")
-                        update.append((state, state_dict, params, mapper, 
-                                        connection, value_params))
-
-            if update:
-                mapper = table_to_mapper[table]
-
-                needs_version_id = mapper.version_id_col is not None and \
-                            table.c.contains_column(mapper.version_id_col)
-
-                def update_stmt():
-                    clause = sql.and_()
-
-                    for col in mapper._pks_by_table[table]:
-                        clause.clauses.append(col == sql.bindparam(col._label,
-                                                        type_=col.type))
-
-                    if needs_version_id:
-                        clause.clauses.append(mapper.version_id_col ==\
-                                sql.bindparam(mapper.version_id_col._label,
-                                                type_=col.type))
-
-                    return table.update(clause)
-
-                statement = self._memo(('update', table), update_stmt)
-
-                rows = 0
-                for state, state_dict, params, mapper, \
-                            connection, value_params in update:
-
-                    if value_params:
-                        c = connection.execute(
-                                            statement.values(value_params),
-                                            params)
-                    else:
-                        c = cached_connections[connection].\
-                                            execute(statement, params)
-
-                    mapper._postfetch(
-                            uowtransaction, 
-                            table, 
-                            state, 
-                            state_dict, 
-                            c.context.prefetch_cols, 
-                            c.context.postfetch_cols,
-                            c.context.compiled_parameters[0], 
-                            value_params)
-                    rows += c.rowcount
-
-                if connection.dialect.supports_sane_rowcount:
-                    if rows != len(update):
-                        raise orm_exc.StaleDataError(
-                                "UPDATE statement on table '%s' expected to update %d row(s); "
-                                "%d were matched." %
-                                (table.description, len(update), rows))
-
-                elif needs_version_id:
-                    util.warn("Dialect %s does not support updated rowcount "
-                            "- versioning cannot be verified." % 
-                            c.dialect.dialect_description,
-                            stacklevel=12)
-
-            if insert:
-                statement = self._memo(('insert', table), table.insert)
-
-                for (connection, pkeys, hasvalue, has_all_pks), \
-                    records in groupby(insert, 
-                                        lambda rec: (rec[4], 
-                                                rec[2].keys(), 
-                                                bool(rec[5]), 
-                                                rec[6])
-                ):
-                    if has_all_pks and not hasvalue:
-                        records = list(records)
-                        multiparams = [rec[2] for rec in records]
-                        c = cached_connections[connection].\
-                                            execute(statement, multiparams)
-
-                        for (state, state_dict, params, mapper, 
-                                conn, value_params, has_all_pks), \
-                                last_inserted_params in \
-                                zip(records, c.context.compiled_parameters):
-                            mapper._postfetch(
-                                    uowtransaction, 
-                                    table,
-                                    state, 
-                                    state_dict,
-                                    c.context.prefetch_cols,
-                                    c.context.postfetch_cols,
-                                    last_inserted_params, 
-                                    value_params)
-
-                    else:
-                        for state, state_dict, params, mapper, \
-                                    connection, value_params, \
-                                    has_all_pks in records:
-
-                            if value_params:
-                                result = connection.execute(
-                                            statement.values(value_params),
-                                            params)
-                            else:
-                                result = cached_connections[connection].\
-                                                    execute(statement, params)
-
-                            primary_key = result.context.inserted_primary_key
-
-                            if primary_key is not None:
-                                # set primary key attributes
-                                for pk, col in zip(primary_key, 
-                                                mapper._pks_by_table[table]):
-                                    prop = mapper._columntoproperty[col]
-                                    if state_dict.get(prop.key) is None:
-                                        # TODO: would rather say:
-                                        #state_dict[prop.key] = pk
-                                        mapper._set_state_attr_by_column(
-                                                    state, 
-                                                    state_dict, 
-                                                    col, pk)
-
-                            mapper._postfetch(
-                                    uowtransaction, 
-                                    table, 
-                                    state, 
-                                    state_dict,
-                                    result.context.prefetch_cols, 
-                                    result.context.postfetch_cols,
-                                    result.context.compiled_parameters[0], 
-                                    value_params)
-
-
-        for state, state_dict, mapper, connection, has_identity, \
-                        instance_key, row_switch in tups:
-
-            if mapper._readonly_props:
-                readonly = state.unmodified_intersection(
-                    [p.key for p in mapper._readonly_props 
-                        if p.expire_on_flush or p.key not in state.dict]
-                )
-                if readonly:
-                    state.expire_attributes(state.dict, readonly)
-
-            # if eager_defaults option is enabled,
-            # refresh whatever has been expired.
-            if self.eager_defaults and state.unloaded:
-                state.key = self._identity_key_from_state(state)
-                uowtransaction.session.query(self)._load_on_ident(
-                    state.key, refresh_state=state,
-                    only_load_props=state.unloaded)
-
-            # call after_XXX extensions
-            if not has_identity:
-                mapper.dispatch.after_insert(mapper, connection, state)
-            else:
-                mapper.dispatch.after_update(mapper, connection, state)
-
-    def _postfetch(self, uowtransaction, table, 
-                    state, dict_, prefetch_cols, postfetch_cols,
-                                params, value_params):
-        """During a flush, expire attributes in need of newly 
-        persisted database state."""
-
-        if self.version_id_col is not None:
-            prefetch_cols = list(prefetch_cols) + [self.version_id_col]
-
-        for c in prefetch_cols:
-            if c.key in params and c in self._columntoproperty:
-                self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
-        if postfetch_cols:
-            state.expire_attributes(state.dict, 
-                                [self._columntoproperty[c].key 
-                                for c in postfetch_cols if c in 
-                                self._columntoproperty]
-                            )
-
-        # synchronize newly inserted ids from one table to the next
-        # TODO: this still goes a little too often.  would be nice to
-        # have definitive list of "columns that changed" here
-        for m, equated_pairs in self._table_to_equated[table]:
-            sync.populate(state, m, state, m, 
-                                            equated_pairs, 
-                                            uowtransaction,
-                                            self.passive_updates)
-
     @util.memoized_property
     def _table_to_equated(self):
         """memoized map of tables to collections of columns to be 
 
         return result
 
-    def _delete_obj(self, states, uowtransaction):
-        """Issue ``DELETE`` statements for a list of objects.
-
-        This is called within the context of a UOWTransaction during a
-        flush operation.
-
-        """
-        if uowtransaction.session.connection_callable:
-            connection_callable = \
-                    uowtransaction.session.connection_callable
-        else:
-            connection = uowtransaction.transaction.connection(self)
-            connection_callable = None
-
-        tups = []
-        cached_connections = util.PopulateDict(
-            lambda conn:conn.execution_options(
-            compiled_cache=self._compiled_cache
-        ))
-
-        for state in _sort_states(states):
-            mapper = _state_mapper(state)
-
-            if connection_callable:
-                conn = connection_callable(self, state.obj())
-            else:
-                conn = connection
-
-            mapper.dispatch.before_delete(mapper, conn, state)
-
-            tups.append((state, 
-                    state.dict,
-                    _state_mapper(state), 
-                    bool(state.key),
-                    conn))
-
-        table_to_mapper = self._sorted_tables
-
-        for table in reversed(table_to_mapper.keys()):
-            delete = util.defaultdict(list)
-            for state, state_dict, mapper, has_identity, connection in tups:
-                if not has_identity or table not in mapper._pks_by_table:
-                    continue
-
-                params = {}
-                delete[connection].append(params)
-                for col in mapper._pks_by_table[table]:
-                    params[col.key] = \
-                            value = \
-                            mapper._get_state_attr_by_column(
-                                            state, state_dict, col)
-                    if value is None:
-                        raise sa_exc.FlushError(
-                                    "Can't delete from table "
-                                    "using NULL for primary "
-                                    "key value")
-
-                if mapper.version_id_col is not None and \
-                            table.c.contains_column(mapper.version_id_col):
-                    params[mapper.version_id_col.key] = \
-                                mapper._get_committed_state_attr_by_column(
-                                        state, state_dict,
-                                        mapper.version_id_col)
-
-            mapper = table_to_mapper[table]
-            need_version_id = mapper.version_id_col is not None and \
-                table.c.contains_column(mapper.version_id_col)
-
-            def delete_stmt():
-                clause = sql.and_()
-                for col in mapper._pks_by_table[table]:
-                    clause.clauses.append(
-                            col == sql.bindparam(col.key, type_=col.type))
-
-                if need_version_id:
-                    clause.clauses.append(
-                        mapper.version_id_col == 
-                        sql.bindparam(
-                                mapper.version_id_col.key, 
-                                type_=mapper.version_id_col.type
-                        )
-                    )
-
-                return table.delete(clause)
-
-            for connection, del_objects in delete.iteritems():
-                statement = self._memo(('delete', table), delete_stmt)
-                rows = -1
-
-                connection = cached_connections[connection]
-
-                if need_version_id and \
-                        not connection.dialect.supports_sane_multi_rowcount:
-                    # TODO: need test coverage for this [ticket:1761]
-                    if connection.dialect.supports_sane_rowcount:
-                        rows = 0
-                        # execute deletes individually so that versioned
-                        # rows can be verified
-                        for params in del_objects:
-                            c = connection.execute(statement, params)
-                            rows += c.rowcount
-                    else:
-                        util.warn(
-                            "Dialect %s does not support deleted rowcount "
-                            "- versioning cannot be verified." % 
-                            connection.dialect.dialect_description,
-                            stacklevel=12)
-                        connection.execute(statement, del_objects)
-                else:
-                    c = connection.execute(statement, del_objects)
-                    if connection.dialect.supports_sane_multi_rowcount:
-                        rows = c.rowcount
-
-                if rows != -1 and rows != len(del_objects):
-                    raise orm_exc.StaleDataError(
-                        "DELETE statement on table '%s' expected to delete %d row(s); "
-                        "%d were matched." % 
-                        (table.description, len(del_objects), c.rowcount)
-                    )
-
-        for state, state_dict, mapper, has_identity, connection in tups:
-            mapper.dispatch.after_delete(mapper, connection, state)
 
     def _instance_processor(self, context, path, reduced_path, adapter, 
                                 polymorphic_from=None, 
         """Produce a mapper level row processor callable 
            which processes rows into mapped instances."""
 
+        # note that this method, most of which exists in a closure
+        # called _instance(), resists being broken out, as 
+        # attempts to do so tend to add significant function
+        # call overhead.  _instance() is the most
+        # performance-critical section in the whole ORM.
+
         pk_cols = self.primary_key
 
         if polymorphic_from or refresh_state:
                                             state, state.dict, col, val)
 
 
-def _sort_states(states):
-    pending = set(states)
-    persistent = set(s for s in pending if s.key is not None)
-    pending.difference_update(persistent)
-    return sorted(pending, key=operator.attrgetter("insert_order")) + \
-        sorted(persistent, key=lambda q:q.key[1])
-
 class _ColumnMapping(util.py25_dict):
     """Error reporting helper for mapper._columntoproperty."""
 

lib/sqlalchemy/orm/persistence.py

+# orm/persistence.py
+# Copyright (C) 2005-2012 the SQLAlchemy authors and contributors <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used to emit INSERT, UPDATE
+and DELETE statements on behalf of a :class:`.Mapper` and its descending
+mappers.
+
+The functions here are called only by the unit of work functions
+in unitofwork.py.
+
+"""
+
+import operator
+from itertools import groupby
+
+from sqlalchemy import sql, util, exc as sa_exc
+from sqlalchemy.orm import attributes, sync, \
+                        exc as orm_exc
+
+from sqlalchemy.orm.util import _state_mapper, state_str
+
+def save_obj(base_mapper, states, uowtransaction, single=False):
+    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list 
+    of objects.
+
+    This is called within the context of a UOWTransaction during a
+    flush operation, given a list of states to be flushed.  The
+    base mapper in an inheritance hierarchy handles the inserts/
+    updates for all descendant mappers.
+
+    """
+
+    # if batch=false, call _save_obj separately for each object
+    if not single and not base_mapper.batch:
+        for state in _sort_states(states):
+            save_obj(base_mapper, [state], uowtransaction, single=True)
+        return
+
+    states_to_insert, states_to_update = _organize_states_for_save(
+                                                base_mapper, 
+                                                states, 
+                                                uowtransaction)
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    for table, mapper in base_mapper._sorted_tables.iteritems():
+        insert = _collect_insert_commands(base_mapper, uowtransaction, 
+                                table, states_to_insert)
+
+        update = _collect_update_commands(base_mapper, uowtransaction, 
+                                table, states_to_update)
+
+        if update:
+            _emit_update_statements(base_mapper, uowtransaction, 
+                                    cached_connections, 
+                                    mapper, table, update)
+
+        if insert:
+            _emit_insert_statements(base_mapper, uowtransaction, 
+                                    cached_connections, 
+                                    table, insert)
+
+    _finalize_insert_update_commands(base_mapper, uowtransaction, 
+                                    states_to_insert, states_to_update)
+
+def post_update(base_mapper, states, uowtransaction, post_update_cols):
+    """Issue UPDATE statements on behalf of a relationship() which
+    specifies post_update.
+
+    """
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    states_to_update = _organize_states_for_post_update(
+                                    base_mapper, 
+                                    states, uowtransaction)
+
+
+    for table, mapper in base_mapper._sorted_tables.iteritems():
+        update = _collect_post_update_commands(base_mapper, uowtransaction, 
+                                            table, states_to_update, 
+                                            post_update_cols)
+
+        if update:
+            _emit_post_update_statements(base_mapper, uowtransaction, 
+                                    cached_connections, 
+                                    mapper, table, update)
+
+def delete_obj(base_mapper, states, uowtransaction):
+    """Issue ``DELETE`` statements for a list of objects.
+
+    This is called within the context of a UOWTransaction during a
+    flush operation.
+
+    """
+
+    cached_connections = _cached_connection_dict(base_mapper)
+
+    states_to_delete = _organize_states_for_delete(
+                                        base_mapper, 
+                                        states,
+                                        uowtransaction)
+
+    table_to_mapper = base_mapper._sorted_tables
+
+    for table in reversed(table_to_mapper.keys()):
+        delete = _collect_delete_commands(base_mapper, uowtransaction, 
+                                table, states_to_delete)
+
+        mapper = table_to_mapper[table]
+
+        _emit_delete_statements(base_mapper, uowtransaction, 
+                    cached_connections, mapper, table, delete)
+
+    for state, state_dict, mapper, has_identity, connection \
+                        in states_to_delete:
+        mapper.dispatch.after_delete(mapper, connection, state)
+
+def _organize_states_for_save(base_mapper, states, uowtransaction):
+    """Make an initial pass across a set of states for INSERT or
+    UPDATE.
+    
+    This includes splitting out into distinct lists for
+    each, calling before_insert/before_update, obtaining
+    key information for each state including its dictionary,
+    mapper, the connection to use for the execution per state,
+    and the identity flag.
+    
+    """
+
+    states_to_insert = []
+    states_to_update = []
+
+    for state, dict_, mapper, connection in _connections_for_states(
+                                            base_mapper, uowtransaction, 
+                                            states):
+
+        has_identity = bool(state.key)
+        instance_key = state.key or mapper._identity_key_from_state(state)
+
+        row_switch = None
+
+        # call before_XXX extensions
+        if not has_identity:
+            mapper.dispatch.before_insert(mapper, connection, state)
+        else:
+            mapper.dispatch.before_update(mapper, connection, state)
+
+        # detect if we have a "pending" instance (i.e. has 
+        # no instance_key attached to it), and another instance 
+        # with the same identity key already exists as persistent. 
+        # convert to an UPDATE if so.
+        if not has_identity and \
+            instance_key in uowtransaction.session.identity_map:
+            instance = \
+                uowtransaction.session.identity_map[instance_key]
+            existing = attributes.instance_state(instance)
+            if not uowtransaction.is_deleted(existing):
+                raise orm_exc.FlushError(
+                    "New instance %s with identity key %s conflicts "
+                    "with persistent instance %s" % 
+                    (state_str(state), instance_key,
+                     state_str(existing)))
+
+            base_mapper._log_debug(
+                "detected row switch for identity %s.  "
+                "will update %s, remove %s from "
+                "transaction", instance_key, 
+                state_str(state), state_str(existing))
+
+            # remove the "delete" flag from the existing element
+            uowtransaction.remove_state_actions(existing)
+            row_switch = existing
+
+        if not has_identity and not row_switch:
+            states_to_insert.append(
+                (state, dict_, mapper, connection, 
+                has_identity, instance_key, row_switch)
+            )
+        else:
+            states_to_update.append(
+                (state, dict_, mapper, connection, 
+                has_identity, instance_key, row_switch)
+            )
+
+    return states_to_insert, states_to_update
+
+def _organize_states_for_post_update(base_mapper, states, 
+                                                uowtransaction):
+    """Make an initial pass across a set of states for UPDATE
+    corresponding to post_update.
+    
+    This includes obtaining key information for each state 
+    including its dictionary, mapper, the connection to use for 
+    the execution per state.
+    
+    """
+    return list(_connections_for_states(base_mapper, uowtransaction, 
+                                            states))
+
+def _organize_states_for_delete(base_mapper, states, uowtransaction):
+    """Make an initial pass across a set of states for DELETE.
+    
+    This includes calling out before_delete and obtaining
+    key information for each state including its dictionary,
+    mapper, the connection to use for the execution per state.
+    
+    """
+    states_to_delete = []
+
+    for state, dict_, mapper, connection in _connections_for_states(
+                                            base_mapper, uowtransaction, 
+                                            states):
+
+        mapper.dispatch.before_delete(mapper, connection, state)
+
+        states_to_delete.append((state, dict_, mapper, 
+                bool(state.key), connection))
+    return states_to_delete
+
+def _collect_insert_commands(base_mapper, uowtransaction, table, 
+                                                states_to_insert):
+    """Identify sets of values to use in INSERT statements for a
+    list of states.
+    
+    """
+    insert = []
+    for state, state_dict, mapper, connection, has_identity, \
+                    instance_key, row_switch in states_to_insert:
+        if table not in mapper._pks_by_table:
+            continue
+
+        pks = mapper._pks_by_table[table]
+
+        params = {}
+        value_params = {}
+
+        has_all_pks = True
+        for col in mapper._cols_by_table[table]:
+            if col is mapper.version_id_col:
+                params[col.key] = mapper.version_id_generator(None)
+            else:
+                # pull straight from the dict for 
+                # pending objects
+                prop = mapper._columntoproperty[col]
+                value = state_dict.get(prop.key, None)
+
+                if value is None:
+                    if col in pks:
+                        has_all_pks = False
+                    elif col.default is None and \
+                         col.server_default is None:
+                        params[col.key] = value
+
+                elif isinstance(value, sql.ClauseElement):
+                    value_params[col] = value
+                else:
+                    params[col.key] = value
+
+        insert.append((state, state_dict, params, mapper, 
+                        connection, value_params, has_all_pks))
+    return insert
+
+def _collect_update_commands(base_mapper, uowtransaction, 
+                                table, states_to_update):
+    """Identify sets of values to use in UPDATE statements for a
+    list of states.
+    
+    This function works intricately with the history system
+    to determine exactly what values should be updated
+    as well as how the row should be matched within an UPDATE
+    statement.  Includes some tricky scenarios where the primary
+    key of an object might have been changed.
+
+    """
+
+    update = []
+    for state, state_dict, mapper, connection, has_identity, \
+                    instance_key, row_switch in states_to_update:
+        if table not in mapper._pks_by_table:
+            continue
+
+        pks = mapper._pks_by_table[table]
+
+        params = {}
+        value_params = {}
+
+        hasdata = hasnull = False
+        for col in mapper._cols_by_table[table]:
+            if col is mapper.version_id_col:
+                params[col._label] = \
+                    mapper._get_committed_state_attr_by_column(
+                                    row_switch or state, 
+                                    row_switch and row_switch.dict 
+                                                or state_dict,
+                                    col)
+
+                prop = mapper._columntoproperty[col]
+                history = attributes.get_state_history(
+                    state, prop.key, 
+                    attributes.PASSIVE_NO_INITIALIZE
+                )
+                if history.added:
+                    params[col.key] = history.added[0]
+                    hasdata = True
+                else:
+                    params[col.key] = mapper.version_id_generator(
+                                                params[col._label])
+
+                    # HACK: check for history, in case the 
+                    # history is only
+                    # in a different table than the one 
+                    # where the version_id_col is.
+                    for prop in mapper._columntoproperty.itervalues():
+                        history = attributes.get_state_history(
+                                state, prop.key, 
+                                attributes.PASSIVE_NO_INITIALIZE)
+                        if history.added:
+                            hasdata = True
+            else:
+                prop = mapper._columntoproperty[col]
+                history = attributes.get_state_history(
+                                state, prop.key, 
+                                attributes.PASSIVE_NO_INITIALIZE)
+                if history.added:
+                    if isinstance(history.added[0],
+                                    sql.ClauseElement):
+                        value_params[col] = history.added[0]
+                    else:
+                        value = history.added[0]
+                        params[col.key] = value
+
+                    if col in pks:
+                        if history.deleted and \
+                            not row_switch:
+                            # if passive_updates and sync detected
+                            # this was a  pk->pk sync, use the new
+                            # value to locate the row, since the
+                            # DB would already have set this
+                            if ("pk_cascaded", state, col) in \
+                                            uowtransaction.attributes:
+                                value = history.added[0]
+                                params[col._label] = value
+                            else:
+                                # use the old value to 
+                                # locate the row
+                                value = history.deleted[0]
+                                params[col._label] = value
+                            hasdata = True
+                        else:
+                            # row switch logic can reach us here
+                            # remove the pk from the update params
+                            # so the update doesn't
+                            # attempt to include the pk in the
+                            # update statement
+                            del params[col.key]
+                            value = history.added[0]
+                            params[col._label] = value
+                        if value is None:
+                            hasnull = True
+                    else:
+                        hasdata = True
+                elif col in pks:
+                    value = state.manager[prop.key].impl.get(
+                                                    state, state_dict)
+                    if value is None:
+                        hasnull = True
+                    params[col._label] = value
+        if hasdata:
+            if hasnull:
+                raise sa_exc.FlushError(
+                            "Can't update table "
+                            "using NULL for primary "
+                            "key value")
+            update.append((state, state_dict, params, mapper, 
+                            connection, value_params))
+    return update
+
+
+def _collect_post_update_commands(base_mapper, uowtransaction, table, 
+                        states_to_update, post_update_cols):
+    """Identify sets of values to use in UPDATE statements for a
+    list of states within a post_update operation.
+
+    """
+
+    update = []
+    for state, state_dict, mapper, connection in states_to_update:
+        if table not in mapper._pks_by_table:
+            continue
+        pks = mapper._pks_by_table[table]
+        params = {}
+        hasdata = False
+
+        for col in mapper._cols_by_table[table]:
+            if col in pks:
+                params[col._label] = \
+                        mapper._get_state_attr_by_column(
+                                        state,
+                                        state_dict, col)
+            elif col in post_update_cols:
+                prop = mapper._columntoproperty[col]
+                history = attributes.get_state_history(
+                            state, prop.key, 
+                            attributes.PASSIVE_NO_INITIALIZE)
+                if history.added:
+                    value = history.added[0]
+                    params[col.key] = value
+                    hasdata = True
+        if hasdata:
+            update.append((state, state_dict, params, mapper, 
+                            connection))
+    return update
+
+def _collect_delete_commands(base_mapper, uowtransaction, table, 
+                                states_to_delete):
+    """Identify values to use in DELETE statements for a list of 
+    states to be deleted."""
+
+    delete = util.defaultdict(list)
+
+    for state, state_dict, mapper, has_identity, connection \
+                                        in states_to_delete:
+        if not has_identity or table not in mapper._pks_by_table:
+            continue
+
+        params = {}
+        delete[connection].append(params)
+        for col in mapper._pks_by_table[table]:
+            params[col.key] = \
+                    value = \
+                    mapper._get_state_attr_by_column(
+                                    state, state_dict, col)
+            if value is None:
+                raise sa_exc.FlushError(
+                            "Can't delete from table "
+                            "using NULL for primary "
+                            "key value")
+
+        if mapper.version_id_col is not None and \
+                    table.c.contains_column(mapper.version_id_col):
+            params[mapper.version_id_col.key] = \
+                        mapper._get_committed_state_attr_by_column(
+                                state, state_dict,
+                                mapper.version_id_col)
+    return delete
+
+
+def _emit_update_statements(base_mapper, uowtransaction, 
+                        cached_connections, mapper, table, update):
+    """Emit UPDATE statements corresponding to value lists collected
+    by _collect_update_commands()."""
+
+    needs_version_id = mapper.version_id_col is not None and \
+                table.c.contains_column(mapper.version_id_col)
+
+    def update_stmt():
+        clause = sql.and_()
+
+        for col in mapper._pks_by_table[table]:
+            clause.clauses.append(col == sql.bindparam(col._label,
+                                            type_=col.type))
+
+        if needs_version_id:
+            clause.clauses.append(mapper.version_id_col ==\
+                    sql.bindparam(mapper.version_id_col._label,
+                                    type_=col.type))
+
+        return table.update(clause)
+
+    statement = base_mapper._memo(('update', table), update_stmt)
+
+    rows = 0
+    for state, state_dict, params, mapper, \
+                connection, value_params in update:
+
+        if value_params:
+            c = connection.execute(
+                                statement.values(value_params),
+                                params)
+        else:
+            c = cached_connections[connection].\
+                                execute(statement, params)
+
+        _postfetch(
+                mapper,
+                uowtransaction, 
+                table, 
+                state, 
+                state_dict, 
+                c.context.prefetch_cols, 
+                c.context.postfetch_cols,
+                c.context.compiled_parameters[0], 
+                value_params)
+        rows += c.rowcount
+
+    if connection.dialect.supports_sane_rowcount:
+        if rows != len(update):
+            raise orm_exc.StaleDataError(
+                    "UPDATE statement on table '%s' expected to "
+                    "update %d row(s); %d were matched." %
+                    (table.description, len(update), rows))
+
+    elif needs_version_id:
+        util.warn("Dialect %s does not support updated rowcount "
+                "- versioning cannot be verified." % 
+                c.dialect.dialect_description,
+                stacklevel=12)
+
+def _emit_insert_statements(base_mapper, uowtransaction, 
+                        cached_connections, table, insert):
+    """Emit INSERT statements corresponding to value lists collected
+    by _collect_insert_commands()."""
+
+    statement = base_mapper._memo(('insert', table), table.insert)
+
+    for (connection, pkeys, hasvalue, has_all_pks), \
+        records in groupby(insert, 
+                            lambda rec: (rec[4], 
+                                    rec[2].keys(), 
+                                    bool(rec[5]), 
+                                    rec[6])
+    ):
+        if has_all_pks and not hasvalue:
+            records = list(records)
+            multiparams = [rec[2] for rec in records]
+            c = cached_connections[connection].\
+                                execute(statement, multiparams)
+
+            for (state, state_dict, params, mapper, 
+                    conn, value_params, has_all_pks), \
+                    last_inserted_params in \
+                    zip(records, c.context.compiled_parameters):
+                _postfetch(
+                        mapper,
+                        uowtransaction, 
+                        table,
+                        state, 
+                        state_dict,
+                        c.context.prefetch_cols,
+                        c.context.postfetch_cols,
+                        last_inserted_params, 
+                        value_params)
+
+        else:
+            for state, state_dict, params, mapper, \
+                        connection, value_params, \
+                        has_all_pks in records:
+
+                if value_params:
+                    result = connection.execute(
+                                statement.values(value_params),
+                                params)
+                else:
+                    result = cached_connections[connection].\
+                                        execute(statement, params)
+
+                primary_key = result.context.inserted_primary_key
+
+                if primary_key is not None:
+                    # set primary key attributes
+                    for pk, col in zip(primary_key, 
+                                    mapper._pks_by_table[table]):
+                        prop = mapper._columntoproperty[col]
+                        if state_dict.get(prop.key) is None:
+                            # TODO: would rather say:
+                            #state_dict[prop.key] = pk
+                            mapper._set_state_attr_by_column(
+                                        state, 
+                                        state_dict, 
+                                        col, pk)
+
+                _postfetch(
+                        mapper,
+                        uowtransaction, 
+                        table, 
+                        state, 
+                        state_dict,
+                        result.context.prefetch_cols, 
+                        result.context.postfetch_cols,
+                        result.context.compiled_parameters[0], 
+                        value_params)
+
+
+
+def _emit_post_update_statements(base_mapper, uowtransaction, 
+                            cached_connections, mapper, table, update):
+    """Emit UPDATE statements corresponding to value lists collected
+    by _collect_post_update_commands()."""
+
+    def update_stmt():
+        clause = sql.and_()
+
+        for col in mapper._pks_by_table[table]:
+            clause.clauses.append(col == sql.bindparam(col._label,
+                                            type_=col.type))
+
+        return table.update(clause)
+
+    statement = base_mapper._memo(('post_update', table), update_stmt)
+
+    # execute each UPDATE in the order according to the original
+    # list of states to guarantee row access order, but
+    # also group them into common (connection, cols) sets 
+    # to support executemany().
+    for key, grouper in groupby(
+        update, lambda rec: (rec[4], rec[2].keys())
+    ):
+        multiparams = [params for state, state_dict, 
+                                params, mapper, conn in grouper]
+        cached_connections[conn].\
+                            execute(statement, multiparams)
+
+
+def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, 
+                                    mapper, table, delete):
+    """Emit DELETE statements corresponding to value lists collected
+    by _collect_delete_commands()."""
+
+    need_version_id = mapper.version_id_col is not None and \
+        table.c.contains_column(mapper.version_id_col)
+
+    def delete_stmt():
+        clause = sql.and_()
+        for col in mapper._pks_by_table[table]:
+            clause.clauses.append(
+                    col == sql.bindparam(col.key, type_=col.type))
+
+        if need_version_id:
+            clause.clauses.append(
+                mapper.version_id_col == 
+                sql.bindparam(
+                        mapper.version_id_col.key, 
+                        type_=mapper.version_id_col.type
+                )
+            )
+
+        return table.delete(clause)
+
+    for connection, del_objects in delete.iteritems():
+        statement = base_mapper._memo(('delete', table), delete_stmt)
+        rows = -1
+
+        connection = cached_connections[connection]
+
+        if need_version_id and \
+                not connection.dialect.supports_sane_multi_rowcount:
+            # TODO: need test coverage for this [ticket:1761]
+            if connection.dialect.supports_sane_rowcount:
+                rows = 0
+                # execute deletes individually so that versioned
+                # rows can be verified
+                for params in del_objects:
+                    c = connection.execute(statement, params)
+                    rows += c.rowcount
+            else:
+                util.warn(
+                    "Dialect %s does not support deleted rowcount "
+                    "- versioning cannot be verified." % 
+                    connection.dialect.dialect_description,
+                    stacklevel=12)
+                connection.execute(statement, del_objects)
+        else:
+            c = connection.execute(statement, del_objects)
+            if connection.dialect.supports_sane_multi_rowcount:
+                rows = c.rowcount
+
+        if rows != -1 and rows != len(del_objects):
+            raise orm_exc.StaleDataError(
+                "DELETE statement on table '%s' expected to "
+                "delete %d row(s); %d were matched." % 
+                (table.description, len(del_objects), c.rowcount)
+            )
+
+def _finalize_insert_update_commands(base_mapper, uowtransaction, 
+                            states_to_insert, states_to_update):
+    """finalize state on states that have been inserted or updated,
+    including calling after_insert/after_update events.
+    
+    """
+    for state, state_dict, mapper, connection, has_identity, \
+                    instance_key, row_switch in states_to_insert + \
+                                                    states_to_update:
+
+        if mapper._readonly_props:
+            readonly = state.unmodified_intersection(
+                [p.key for p in mapper._readonly_props 
+                    if p.expire_on_flush or p.key not in state.dict]
+            )
+            if readonly:
+                state.expire_attributes(state.dict, readonly)
+
+        # if eager_defaults option is enabled,
+        # refresh whatever has been expired.
+        if base_mapper.eager_defaults and state.unloaded:
+            state.key = base_mapper._identity_key_from_state(state)
+            uowtransaction.session.query(base_mapper)._load_on_ident(
+                state.key, refresh_state=state,
+                only_load_props=state.unloaded)
+
+        # call after_XXX extensions
+        if not has_identity:
+            mapper.dispatch.after_insert(mapper, connection, state)
+        else:
+            mapper.dispatch.after_update(mapper, connection, state)
+
+def _postfetch(mapper, uowtransaction, table, 
+                state, dict_, prefetch_cols, postfetch_cols,
+                            params, value_params):
+    """Expire attributes in need of newly persisted database state,
+    after an INSERT or UPDATE statement has proceeded for that
+    state."""
+
+    if mapper.version_id_col is not None:
+        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+    for c in prefetch_cols:
+        if c.key in params and c in mapper._columntoproperty:
+            mapper._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+    if postfetch_cols:
+        state.expire_attributes(state.dict, 
+                            [mapper._columntoproperty[c].key 
+                            for c in postfetch_cols if c in 
+                            mapper._columntoproperty]
+                        )
+
+    # synchronize newly inserted ids from one table to the next
+    # TODO: this still goes a little too often.  would be nice to
+    # have definitive list of "columns that changed" here
+    for m, equated_pairs in mapper._table_to_equated[table]:
+        sync.populate(state, m, state, m, 
+                                        equated_pairs, 
+                                        uowtransaction,
+                                        mapper.passive_updates)
+
+def _connections_for_states(base_mapper, uowtransaction, states):
+    """Return an iterator of (state, state.dict, mapper, connection).
+    
+    The states are sorted according to _sort_states, then paired
+    with the connection they should be using for the given
+    unit of work transaction.
+    
+    """
+    # if session has a connection callable,
+    # organize individual states with the connection 
+    # to use for update
+    if uowtransaction.session.connection_callable:
+        connection_callable = \
+                uowtransaction.session.connection_callable
+    else:
+        connection = uowtransaction.transaction.connection(
+                                                    base_mapper)
+        connection_callable = None
+
+    for state in _sort_states(states):
+        if connection_callable:
+            connection = connection_callable(base_mapper, state.obj())
+
+        mapper = _state_mapper(state)
+
+        yield state, state.dict, mapper, connection
+
+def _cached_connection_dict(base_mapper):
+    # dictionary of connection->connection_with_cache_options.
+    return util.PopulateDict(
+        lambda conn:conn.execution_options(
+        compiled_cache=base_mapper._compiled_cache
+    ))
+
+def _sort_states(states):
+    pending = set(states)
+    persistent = set(s for s in pending if s.key is not None)
+    pending.difference_update(persistent)
+    return sorted(pending, key=operator.attrgetter("insert_order")) + \
+                sorted(persistent, key=lambda q:q.key[1])
+
+

lib/sqlalchemy/orm/sync.py

 
 """private module containing functions used for copying data 
 between instances based on join conditions.
+
 """
 
 from sqlalchemy.orm import exc, util as mapperutil, attributes

lib/sqlalchemy/orm/unitofwork.py

 
 from sqlalchemy import util, event
 from sqlalchemy.util import topological
-from sqlalchemy.orm import attributes, interfaces
+from sqlalchemy.orm import attributes, interfaces, persistence
 from sqlalchemy.orm import util as mapperutil
 session = util.importlater("sqlalchemy.orm", "session")
 
         states, cols = uow.post_update_states[self.mapper]
         states = [s for s in states if uow.states[s][0] == self.isdelete]
 
-        self.mapper._post_update(states, uow, cols)
+        persistence.post_update(self.mapper, states, uow, cols)
 
 class SaveUpdateAll(PostSortRec):
     def __init__(self, uow, mapper):
         assert mapper is mapper.base_mapper
 
     def execute(self, uow):
-        self.mapper._save_obj(
+        persistence.save_obj(self.mapper, 
             uow.states_for_mapper_hierarchy(self.mapper, False, False),
             uow
         )
         assert mapper is mapper.base_mapper
 
     def execute(self, uow):
-        self.mapper._delete_obj(
+        persistence.delete_obj(self.mapper,
             uow.states_for_mapper_hierarchy(self.mapper, True, False),
             uow
         )
                         if r.__class__ is cls_ and 
                         r.mapper is mapper]
         recs.difference_update(our_recs)
-        mapper._save_obj(
+        persistence.save_obj(mapper,
                         [self.state] + 
                         [r.state for r in our_recs], 
                         uow)
                         r.mapper is mapper]
         recs.difference_update(our_recs)
         states = [self.state] + [r.state for r in our_recs]
-        mapper._delete_obj(
+        persistence.delete_obj(mapper,
                         [s for s in states if uow.states[s][0]], 
                         uow)
 

test/orm/test_mapper.py

     validates, aliased, defer, deferred, synonym, attributes, \
     column_property, composite, dynamic_loader, \
     comparable_property, Session
-from sqlalchemy.orm.mapper import _sort_states
+from sqlalchemy.orm.persistence import _sort_states
 from test.lib.testing import eq_, AssertsCompiledSQL
 from test.lib import fixtures
 from test.orm import _fixtures
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.