Commits

Gustavo Picon  committed 1990667

Renamed Node.get_database_engine() to Node.get_database_vendor()

  • Participants
  • Parent commits 7f5daea

Comments (0)

Files changed (3)

   and/or signals.
 * Improved translation files, including javascript.
 * Fixed Django 1.4 support in the admin.
+* Renamed Node.get_database_engine() to Node.get_database_vendor(). As the name
+  implies, it returns the database vendor instead of the engine used. Treebeard
+  will get the value from Django, but you can subclass the method if needed.
 
 Release 1.61 (Jul 24, 2010)
 ---------------------------

File treebeard/models.py

     from functools import reduce
 
 from django.db.models import Q
-from django.db import models, transaction
+from django.db import models, transaction, router, connections
 from django.conf import settings
 
 from treebeard.exceptions import InvalidPosition, MissingNodeOrderBy
 class Node(models.Model):
     "Node class"
 
+    _db_vendor = None
+
     @classmethod
     def add_root(cls, **kwargs):  # pragma: no cover
         """
         return cls
 
     @classmethod
-    def get_database_engine(cls):
+    def get_database_vendor(cls, action):
         """
-        Returns the supported database engine used by a treebeard model.
+        Returns the supported database vendor used by a treebeard model when
+        performing read (select) or write (update, insert, delete) operations.
 
-        This will return the default database engine depending on the version
-        of Django. If you use something different, like a non-default database,
-        you need to override this method and return the correct engine.
+        :param action:
 
-        :returns: postgresql, postgresql_psycopg2, mysql or sqlite3
+            `read` or `write`
+
+        :returns: postgresql, mysql or sqlite
         """
-        engine = None
-        try:
-            engine = settings.DATABASES['default']['ENGINE']
-        except (AttributeError, KeyError):
-            engine = None
-            # the old style settings still work in Django 1.2+ if there is no
-        # DATABASES setting
-        if engine is None:
-            engine = settings.DATABASE_ENGINE
-        return engine.split('.')[-1]
+        if cls._db_vendor is None:
+            cls._db_vendor = {
+                'read': connections[router.db_for_read(cls)].vendor,
+                'write': connections[router.db_for_write(cls)].vendor
+            }
+        return cls._db_vendor[action]
 
     class Meta:
         "Abstract model."

File treebeard/mp_tree.py

             # fix the numchild field
             vals = ['_' * cls.steplen]
             # the cake and sql portability are a lie
-            if cls.get_database_engine() == 'mysql':
+            if cls.get_database_vendor('read') == 'mysql':
                 sql = "SELECT tbn1.path, tbn1.numchild, ("\
                       "SELECT COUNT(1) "\
                       "FROM %(table)s AS tbn2 "\
         2. update the number of children of parent nodes
         """
         if (
-                cls.get_database_engine() == 'mysql' and
+                cls.get_database_vendor('write') == 'mysql' and
                 len(oldpath) != len(newpath)
         ):
             # no words can describe how dumb mysql is
 
         """
 
+        vendor = cls.get_database_vendor('write')
         sql1 = "UPDATE %s SET" % (
             connection.ops.quote_name(cls._meta.db_table), )
 
         # <3 "standard" sql
-        if cls.get_database_engine() == 'sqlite3':
+        if vendor == 'sqlite':
             # I know that the third argument in SUBSTR (LENGTH(path)) is
             # awful, but sqlite fails without it:
             # OperationalError: wrong number of arguments to function substr()
             # even when the documentation says that 2 arguments are valid:
             # http://www.sqlite.org/lang_corefunc.html
             sqlpath = "%s||SUBSTR(path, %s, LENGTH(path))"
-        elif cls.get_database_engine() == 'mysql':
+        elif vendor == 'mysql':
             # hooray for mysql ignoring standards in their default
             # configuration!
             # to make || work as it should, enable ansi mode
 
         sql2 = ["path=%s" % (sqlpath, )]
         vals = [newpath, len(oldpath) + 1]
-        if (
-                len(oldpath) != len(newpath) and
-                cls.get_database_engine() != 'mysql'
-        ):
+        if len(oldpath) != len(newpath) and vendor != 'mysql':
             # when using mysql, this won't update the depth and it has to be
             # done in another query
             # doesn't even work with sql_mode='ANSI,TRADITIONAL'