Commits

Anonymous committed f77e64a

``set_base()`` function and unit test.

  • Participants
  • Parent commits 3433b0b

Comments (0)

Files changed (2)

File sqlahelper.py

     If the name is "default" or omitted, this will be the application's default
     engine. The contextual session will be bound to it, the declarative base's
     metadata will be bound to it, and calling ``get_engine()`` without an
+    argument will return it.
     """
     _engines[name] = engine
     if name == "default":
 
     If no argument, look for an engine named "default".
 
-    Raise ``RuntimeError`` if no engine by the specified was configured.
+    Raise ``RuntimeError`` if no engine under that name has been configured.
     """
     try:
         return _engines[name]
     """Return the central SQLAlchemy declarative base.
     """
     return _base
+
+def set_base(base):
+    """Set the central SQLAlchemy declarative base.
+
+    Subsequent calls to ``get_base()`` will return this base instead of the
+    default one. This is useful if you need to override the default base, for
+    instance to make it inherit from your own superclass.
+
+    You'll have to make sure that no part of your application's code or any
+    third-party library calls ``get_base()`` before you call ``set_base()``,
+    otherwise they'll get the old base. You can ensure this by calling
+    ``set_base()`` early in the application's execution, before importing the
+    third-party libraries.
+    """
+    global _base
+    _base = base
 
 import sqlalchemy as sa
 from sqlalchemy.engine.base import Engine
+import sqlalchemy.ext.declarative as declarative
 
 import sqlahelper
 
         self.file = os.path.join(dir, filename)
         self.url = "sqlite:///" + self.file
 
-class PyramidSQLATestCase(unittest.TestCase):
+class SQLAHelperTestCase(unittest.TestCase):
     def setUp(self):
         self.dir = tempfile.mkdtemp()
         self.db1 = DBInfo(self.dir, "db1.sqlite")
                 raise AssertionError("%r is not %r" % (a, b))
 
 
-class TestAddEngine(PyramidSQLATestCase):
+class TestAddEngine(SQLAHelperTestCase):
     def test_one_engine(self):
         e = sa.create_engine(self.db1.url)
         sqlahelper.add_engine(e)
         self.assertRaises(RuntimeError, sqlahelper.get_engine)
 
 
-class TestDeclarativeBase(PyramidSQLATestCase):
+class TestDeclarativeBase(SQLAHelperTestCase):
     def test1(self):
         import transaction
         Base = sqlahelper.get_base()
         result = [x.first_name for x in q]
         control = [u"Wilma", u"Fred", u"Betty", u"Barney"]
         self.assertEqual(result, control)
+
+
+class TestSetBase(SQLAHelperTestCase):
+    def test1(self):
+        base = sqlahelper.get_base()
+        my_base = declarative.declarative_base()
+        sqlahelper.set_base(my_base)
+        base2 = sqlahelper.get_base()
+        try:
+            self.assertIsNot(base2, base)
+            self.assertIs(base2, my_base)
+        except AttributeError:  # Python < 2.7
+            self.assertNotEqual(base2, base)
+            self.assertEqual(base2, my_base)
+
+
+if __name__ == "__main__":
+    unittest.main()