Commits

Gustavo Picon committed a05a114

Added get_descendants_group_count helper

Comments (0)

Files changed (2)

treebeard/models.py

 
 
 
+    @classmethod
+    def get_descendants_group_count(cls, parent=None):
+        """
+        Helper function for efficient aggregates.
+        FIXME: Needs better docstring.
+        """
+        if parent:
+            depth = parent.depth + 1
+            sqlparms = cls._get_children_path_interval(parent.path)
+            extrand = 'AND path BETWEEN %s AND %s'
+        else:
+            depth = 1
+            sqlparms = []
+            extrand = ''
+        sql = 'SELECT SUBSTR(path, 0, %d) AS subpath, COUNT(*)-1 AS count ' \
+              ' FROM  %s ' \
+              ' WHERE depth >= %d %s' \
+              ' GROUP BY subpath ORDER BY subpath' % (1+depth*cls.steplen,
+              cls._meta.db_table, depth, extrand)
+
+        cursor = connection.cursor()
+        cursor.execute(sql, sqlparms)
+        ret = cursor.fetchall()
+        paths = [path for path, _ in ret]
+        nodes = dict([(obj.path, obj)
+                      for obj in cls.objects.filter(path__in=paths)])
+        ret = [(nodes[path], count) for path, count in ret]
+        transaction.commit_unless_managed()
+        return ret
+
+
     def get_siblings(self):
         """
         :returns: A queryset of all the node's siblings, including the node

treebeard/tests.py

         self.assertEqual(self.got(TestSortedNodeShortPath), expected_sorted)
 
 
+class TestHelpers(TestTreeBase):
+
+    def setUp(self):
+        TestNode.load_bulk(BASE_DATA)
+        for node in TestNode.get_root_nodes():
+            TestNode.load_bulk(BASE_DATA, node)
+        TestNode.add_root(desc='5')
+
+    def test_descendants_group_count_root(self):
+        got = [(o.path, count)
+               for o, count in TestNode.get_descendants_group_count()]
+        expected = [('001', 10),
+                    ('002', 15),
+                    ('003', 10),
+                    ('004', 11),
+                    ('005', 0)]
+        self.assertEqual(got, expected)
+
+
+    def test_descendants_group_count_node(self):
+        parent = TestNode.objects.get(path='002')
+        got = [(o.path, count)
+               for o, count in TestNode.get_descendants_group_count(parent)]
+        expected = [('002001', 0),
+                    ('002002', 0),
+                    ('002003', 1),
+                    ('002004', 0),
+                    ('002005', 0),
+                    ('002006', 5),
+                    ('002007', 0),
+                    ('002008', 1)]
+        self.assertEqual(got, expected)
+
 
 #~
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.