Commits

Anonymous committed db89155

Add delete and update methods to query

  • Participants
  • Parent commits 15d3a6b

Comments (0)

Files changed (2)

lib/sqlalchemy/orm/query.py

 
 from itertools import chain
 
-from sqlalchemy import sql, util, log
+from sqlalchemy import sql, util, log, schema
 from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import expression, visitors, operators
-from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper
+from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper, evaluator
 from sqlalchemy.orm.util import _state_mapper, _is_mapped_class, \
      _is_aliased_class, _entity_descriptor, _entity_info, _class_to_mapper, \
      _orm_columns, AliasedClass, _orm_selectable, join as orm_join, ORMAdapter
         if self._autoflush and not self._populate_existing:
             self.session._autoflush()
         return self.session.scalar(s, params=self._params, mapper=self._mapper_zero())
+    
+    def delete(self, synchronize_session='evaluate'):
+        """EXPERIMENTAL"""
+        #TODO: lots of duplication and ifs - probably needs to be refactored to strategies
+        context = self._compile_context()
+        if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table):
+            raise sa_exc.ArgumentError("Only deletion via a single table query is currently supported")
+        primary_table = context.statement.froms[0]
+        
+        session = self.session
+        
+        if synchronize_session == 'evaluate':
+            try:
+                evaluator_compiler = evaluator.EvaluatorCompiler()
+                eval_condition = evaluator_compiler.process(self.whereclause)
+            except evaluator.UnevaluatableError:
+                synchronize_session = 'fetch'
+        
+        delete_stmt = sql.delete(primary_table, context.whereclause)
+        
+        if synchronize_session == 'fetch':
+            #TODO: use RETURNING when available
+            select_stmt = context.statement.with_only_columns(primary_table.primary_key)
+            matched_rows = session.execute(select_stmt).fetchall()
+        
+        session.execute(delete_stmt)
+        
+        if synchronize_session == 'evaluate':
+            target_cls = self._mapper_zero().class_
+            
+            #TODO: detect when the where clause is a trivial primary key match
+            objs_to_expunge = [obj for (cls, pk, entity_name),obj in session.identity_map.iteritems()
+                if issubclass(cls, target_cls) and eval_condition(obj)]
+            for obj in objs_to_expunge:
+                session.expunge(obj)
+        elif synchronize_session == 'fetch':
+            target_mapper = self._mapper_zero()
+            for primary_key in matched_rows:
+                identity_key = target_mapper.identity_key_from_primary_key(list(primary_key))
+                if identity_key in session.identity_map:
+                    session.expunge(session.identity_map[identity_key])
 
+    def update(self, values, synchronize_session='evaluate'):
+        """EXPERIMENTAL"""
+        
+        #TODO: value keys need to be mapped to corresponding sql cols and instr.attr.s to string keys
+        #TODO: updates of manytoone relations need to be converted to fk assignments
+        
+        context = self._compile_context()
+        if len(context.statement.froms) != 1 or not isinstance(context.statement.froms[0], schema.Table):
+            raise sa_exc.ArgumentError("Only update via a single table query is currently supported")
+        primary_table = context.statement.froms[0]
+        
+        session = self.session
+        
+        if synchronize_session == 'evaluate':
+            try:
+                evaluator_compiler = evaluator.EvaluatorCompiler()
+                eval_condition = evaluator_compiler.process(self.whereclause)
+                
+                value_evaluators = {}
+                for key,value in values.items():
+                    value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value))
+            except evaluator.UnevaluatableError:
+                synchronize_session = 'expire'
+        
+        update_stmt = sql.update(primary_table, context.whereclause, values)
+        
+        if synchronize_session == 'expire':
+            select_stmt = context.statement.with_only_columns(primary_table.primary_key)
+            matched_rows = session.execute(select_stmt).fetchall()
+        
+        session.execute(update_stmt)
+        
+        if synchronize_session == 'evaluate':
+            target_cls = self._mapper_zero().class_
+            
+            for (cls, pk, entity_name),obj in session.identity_map.iteritems():
+                if issubclass(cls, target_cls) and eval_condition(obj):
+                    for key,eval_value in value_evaluators.items():
+                        obj.__dict__[key] = eval_value(obj)
+        
+        elif synchronize_session == 'expire':
+            target_mapper = self._mapper_zero()
+            
+            for primary_key in matched_rows:
+                identity_key = target_mapper.identity_key_from_primary_key(list(primary_key))
+                if identity_key in session.identity_map:
+                    session.expire(session.identity_map[identity_key], values.keys())
+       
+    
     def _compile_context(self, labels=True):
         context = QueryContext(self)
 

test/orm/query.py

                 b1
         )
         
+class UpdateTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        Table('users', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('name', String),
+              Column('age', Integer))
+    
+    def setup_classes(self):
+        class User(_base.ComparableEntity):
+            pass
+    
+    @testing.resolve_artifact_names
+    def insert_data(self):
+        users.insert().execute([
+            dict(id=1, name='john', age=25),
+            dict(id=2, name='jack', age=47),
+            dict(id=3, name='jill', age=29),
+            dict(id=4, name='jane', age=37),
+        ])
+    
+    @testing.resolve_artifact_names
+    def setup_mappers(self):
+        mapper(User, users)
+    
+    @testing.resolve_artifact_names
+    def test_delete(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+        
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete()
+        
+        assert john not in sess and jill not in sess
+        
+        eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
+        
+    @testing.resolve_artifact_names
+    def test_delete_without_session_sync(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+        
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session=False)
+        
+        assert john in sess and jill in sess
+        
+        eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
+    
+    @testing.resolve_artifact_names
+    def test_delete_with_fetch_strategy(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+        
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch')
+        
+        assert john not in sess and jill not in sess
+        
+        eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
+    
+    @testing.resolve_artifact_names
+    def test_delete_fallback(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+        
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(User.name == select([func.max(User.name)])).delete()
+        
+        assert john not in sess
+        
+        eq_(sess.query(User).order_by(User.id).all(), [jack,jill,jane])
+    
+    @testing.resolve_artifact_names
+    def test_update(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+        
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
+        
+        eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
+        eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
+
+    @testing.resolve_artifact_names
+    def test_update_with_expire_strategy(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+        
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(User.age > 29).update({'age': User.age - 10}, synchronize_session='expire')
+        
+        eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
+        eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
 
 if __name__ == '__main__':
     testenv.main()