Commits

Gustavo Picon  committed 46fcaf3

Cursors are multi-db friendlier (thanks Natalia Stepina)

  • Participants
  • Parent commits d8ae5bc

Comments (0)

Files changed (4)

File treebeard/al_tree.py

         if self.parent_id:
             newobj.parent_id = self.parent_id
 
-        cursor = connection.cursor()
+        cursor = self._get_database_cursor('write')
         for sql, vals in stmts:
             cursor.execute(sql, vals)
 
                 self.parent = target.parent
 
         if stmts:
-            cursor = connection.cursor()
+            cursor = self._get_database_cursor('write')
             for sql, vals in stmts:
                 cursor.execute(sql, vals)
 

File treebeard/models.py

 class Node(models.Model):
     """Node class"""
 
-    _db_vendor = None
+    _db_connection = None
 
     @classmethod
     def add_root(cls, **kwargs):  # pragma: no cover
         return current_class
 
     @classmethod
+    def _get_database_connection(cls, action):
+        if cls._db_connection is None:
+            cls._db_connection = {
+                'read': connections[router.db_for_read(cls)],
+                'write': connections[router.db_for_write(cls)]
+            }
+        return cls._db_connection[action]
+
+    @classmethod
     def get_database_vendor(cls, action):
         """
-        Returns the supported database vendor used by a treebeard model when
+        returns the supported database vendor used by a treebeard model when
         performing read (select) or write (update, insert, delete) operations.
 
         :param action:
 
         :returns: postgresql, mysql or sqlite
         """
-        if cls._db_vendor is None:
-            cls._db_vendor = {
-                'read': connections[router.db_for_read(cls)].vendor,
-                'write': connections[router.db_for_write(cls)].vendor
-            }
-        return cls._db_vendor[action]
+        return cls._get_database_connection(action).vendor
+
+    @classmethod
+    def _get_database_cursor(cls, action):
+        return cls._get_database_connection(action).cursor()
 
     class Meta:
         """Abstract model."""

File treebeard/mp_tree.py

             cls.objects.all().delete()
             cls.load_bulk(dump, None, True)
         else:
-            cursor = connection.cursor()
+            cursor = cls._get_database_cursor('write')
 
             # fix the depth field
             # we need the WHERE to speed up postgres
                   'subpathlen': depth * cls.steplen,
                   'depth': depth,
                   'extrand': extrand}
-        cursor = connection.cursor()
+        cursor = cls._get_database_cursor('write')
         cursor.execute(sql, params)
 
         ret = []
               "WHERE path=%%s" % {
                   'table': connection.ops.quote_name(
                       self.__class__._meta.db_table)}
-        cursor = connection.cursor()
+        cursor = self._get_database_cursor('write')
         cursor.execute(sql, [self.path])
         transaction.commit_unless_managed()
 
         if parentpath:
             stmts.append(self._get_sql_update_numchild(parentpath, 'inc'))
 
-        cursor = connection.cursor()
+        cursor = self._get_database_cursor('write')
         for sql, vals in stmts:
             cursor.execute(sql, vals)
 
         # updates needed for mysql and children count in parents
         self._updates_after_move(oldpath, newpath, stmts)
 
-        cursor = connection.cursor()
+        cursor = self._get_database_cursor('write')
         for sql, vals in stmts:
             cursor.execute(sql, vals)
         transaction.commit_unless_managed()

File treebeard/ns_tree.py

             # delete method and let it handle the removal of the user's
             # foreign keys...
             super(NS_NodeQuerySet, self).delete()
-            cursor = connection.cursor()
+            cursor = self.model._get_database_cursor('write')
 
             # Now closing the gap (Celko's trees book, page 62)
             # We do this for every gap that was left in the tree when the nodes
 
         newobj._cached_parent_obj = self
 
-        cursor = connection.cursor()
+        cursor = self._get_database_cursor('write')
         cursor.execute(sql, params)
 
         # saving the instance before returning it
 
         # saving the instance before returning it
         if sql:
-            cursor = connection.cursor()
+            cursor = self._get_database_cursor('write')
             cursor.execute(sql, params)
         newobj.save()
 
                 target = siblings[0]
 
         # ok let's move this
-        cursor = connection.cursor()
+        cursor = self._get_database_cursor('write')
         move_right = cls._move_right
         gap = self.rgt - self.lft + 1
         sql = None