Commits

ol...@ollycope.com  committed 200a0a1

Hoist most of the StormFactory implementation into a generic _DataMapperFactory

Most of the storm implementation will work for SQLAlchemy with just minor
changes (eg 'remove' will need to become 'delete' in sqlalchemy land).

  • Participants
  • Parent commits c1aac95

Comments (0)

Files changed (1)

         ob.delete()
 
 
-class StormFactory(Factory):
+class _DataMapperFactory(Factory):
     """\
-    Typically you would configure this in at the start of your test code, so::
+    A factory that requires a data mapper object (eg a storm store or
+    sqlalchemy session) in order to bind objects to the database.
+    """
+
+    #: Callable that returns the data mapper object
+    _getmapper = lambda: None
+    mapperkey = '_datamapper'
+
+    default_flush = True
+    default_commit = False
+
+    @classmethod
+    def _mapper_flush(cls, context):
+        mapper = cls._getmapper_cached(context)
+        if mapper is not None:
+            mapper.flush()
+
+    @classmethod
+    def _mapper_commit(cls, context):
+        mapper = cls._getmapper_cached(context)
+        if mapper is not None:
+            mapper.commit()
+
+    def _mapper_add(self, context, ob):
+        """
+        Ask the datamapper to add ``ob`` to the data store
+        """
+        mapper = self._getmapper_cached(context)
+        if mapper is not None:
+            mapper.add(ob)
+
+    def _mapper_remove(self, context, ob):
+        """
+        Ask the datamapper to remove ``ob`` from the data store
+        """
+        mapper = self._getmapper_cached(context)
+        if mapper is not None:
+            mapper.remove(ob)
+
+    @classmethod
+    def configure(cls, getmapper, *args, **kwargs):
+        base = super(_DataMapperFactory, cls).configure(*args, **kwargs)
+        return type(cls.__name__, (base,),
+                    {'_getmapper': staticmethod(getmapper)})
+
+    @classmethod
+    def _getmapper_cached(cls, context):
+        try:
+            return context.factoryoptions[cls.mapperkey]
+        except KeyError:
+            return context.factoryoptions.setdefault(cls.mapperkey,
+                                                     cls._getmapper())
+
+    @classmethod
+    def setup_complete(cls, context, created):
+        if context.factoryoptions.get('flush', cls.default_flush):
+            cls._mapper_flush(context)
+        if context.factoryoptions.get('commit', cls.default_commit):
+            cls._mapper_commit(context)
+
+    @classmethod
+    def teardown_complete(cls, context, created):
+        if context.factoryoptions.get('flush', cls.default_flush):
+            cls._mapper_flush(context)
+        if context.factoryoptions.get('commit', cls.default_commit):
+            cls._mapper_commit(context)
+
+    def _create_object(self, context, args, kwargs):
+        ob = self.what.__new__(self.what)
+        for item, value in kwargs.items():
+            setattr(ob, item, value)
+        self._mapper_add(context, ob)
+        return ob
+
+    def _destroy_object(self, context, ob):
+        self._mapper_remove(context, ob)
+
+
+class StormFactory(_DataMapperFactory):
+    """\
+    Typically you will need to configure this at the start of your test code,
+    like so::
 
         getstore = lamdba: getUtility(IZStorm).get('main'))
         Factory = StormFactory.configure(getstore)
 
     """
 
-    #: Callable that can retrieve Storm's store object.
-    getstore = lambda: None
-
-    storekey = '_StormFactory_store'
-
-    default_flush = True
-    default_commit = False
-
-    @classmethod
-    def configure(cls, getstore, *args, **kwargs):
-        base = super(StormFactory, cls).configure(*args, **kwargs)
-        return type(cls.__name__, (base,),
-                    {'getstore': staticmethod(getstore)})
-
-    @classmethod
-    def _getstore_cached(cls, context):
-        try:
-            return context.factoryoptions[cls.storekey]
-        except KeyError:
-            return context.factoryoptions.setdefault(cls.storekey,
-                                                     cls.getstore())
-
-    def _create_object(self, context, args, kwargs):
-        store = self._getstore_cached(context)
-        ob = self.what.__new__(self.what)
-        for item, value in kwargs.items():
-            setattr(ob, item, value)
-        if store is not None:
-            store.add(ob)
-        return ob
-
-    def _destroy_object(self, context, ob):
-        store = self._getstore_cached(context)
-        if store is not None:
-            store.remove(ob)
-
-    @classmethod
-    def setup_complete(cls, context, created):
-        store = cls._getstore_cached(context)
-        if store:
-            if context.factoryoptions.get('flush', cls.default_flush):
-                store.flush()
-            if context.factoryoptions.get('commit', cls.default_commit):
-                store.commit()
-
-    @classmethod
-    def teardown_complete(cls, context, created):
-        store = cls._getstore_cached(context)
-        if store:
-            if context.factoryoptions.get('flush', cls.default_flush):
-                store.flush()
-            if context.factoryoptions.get('commit', cls.default_commit):
-                store.commit()
-
 
 class ArgumentGenerator(object):
     """