Benoit Boissinot committed 613636a

refactor bfs

Comments (0)

Files changed (2)

src/discovery_tonfa.py

         if len(nodes) <= MAX_SAMPLE:
             return set(nodes)
 
+
         dag = self.dag
 
         self.writer.indent()
         heads = dag.headsofconnectedset(nodes)
         self.writer.done()
 
-        sample = set()
 
-        dist = {}
-        order = []
-        visit = list(heads)
-        seen = set(stop)
-        cands = []
+        # Breadth-first walk starting from `start`, following the edges
+        # returned by `parents(node)` and ignoring neighbours outside `nodes`.
+        # Returns (sample, roots): `sample` holds the visited nodes whose
+        # distance equals the doubling `factor` below; `roots` holds the
+        # visited nodes that have no neighbour inside `nodes`.
+        # `parents` is any neighbour callable: the callers pass dag.parents
+        # for the heads -> roots pass and dag.children for roots -> heads.
+        def bfs(start, parents):
+            sample = set()
+            dist = {}
+            order = []
+            visit = list(start)
+            seen = set()
+
+            roots = set()
+
+            while visit:
+                # FIFO pop keeps the traversal breadth-first.
+                # NOTE(review): list.pop(0) shifts the whole list on each
+                # call; a collections.deque would avoid that cost.
+                curr = visit.pop(0)
+                if curr in seen:
+                    continue
+                # Start nodes get distance 1; other nodes keep the distance
+                # recorded when they were first queued (setdefault below).
+                d = dist.setdefault(curr, 1)
+                order.append(curr)
+                seen.add(curr)
+
+                isroot = True
+                for p in parents(curr):
+                    if p not in nodes:
+                        continue
+                    # First discovery wins: later paths do not overwrite it.
+                    dist.setdefault(p, d+1)
+                    visit.append(p)
+                    isroot = False
+                # No neighbour inside `nodes`: curr ends this direction.
+                if isroot:
+                    roots.add(curr)
+
+            self.writer.step("sample")
+            # Walk nodes in visit order; `factor` starts at 1 and doubles
+            # whenever a node's distance exceeds it, and a node is sampled
+            # when its distance equals the current factor.
+            factor = 1
+            for n in order:
+                if dist[n] > factor:
+                    factor *= 2
+                if dist[n] == factor:
+                    sample.add(n)
+            return sample, roots
 
         self.writer.step("heads -> roots")
-        roots = set()
-
-        while visit:
-            curr = visit.pop(0)
-            if curr in seen:
-                continue
-            d = dist.setdefault(curr, 1)
-            order.append(curr)
-            seen.add(curr)
-            cands.append(curr)
-
-            if not len(list(p for p in dag.parents(curr) if p not in stop)):
-                roots.add(curr)
-
-            for p in dag.parents(curr):
-                dist.setdefault(p, d+1)
-                visit.append(p)
-
-        self.writer.step("sample")
-        factor = 1
-        for n in order:
-            if dist[n] > factor:
-                factor *= 2
-            if dist[n] == factor:
-                sample.add(n)
+        downsample, roots = bfs(heads, dag.parents)
 
         self.writer.step("roots -> heads")
-        visit = list(roots)
-        order = []
-        dist = {}
-        seen = set()
-        while visit:
-            curr = visit.pop(0)
-            if curr in seen:
-                continue
-            d = dist.setdefault(curr, 1)
-            order.append(curr)
-            seen.add(curr)
-            for c in dag.children(curr):
-                if c not in nodes:
-                    continue
-                dist.setdefault(c, d+1)
-                visit.append(c)
-
-        self.writer.step("sample")
-        factor = 1
-        for n in order:
-            if dist[n] > factor:
-                factor *= 2
-            if dist[n] == factor:
-                sample.add(n)
+        upsample, roots = bfs(roots, dag.children)
 
         self.writer.step("finalize sample")
+        sample = upsample | downsample
         assert sample
         sample.difference_update(heads)
         if len(sample)+len(heads) > MAX_SAMPLE:
         elif len(sample)+len(heads) < 200:
             more = MAX_SAMPLE - len(sample) - len(heads)
             self.writer.step("filling with %d random samples" % more)
-            sample.update(random.sample(list(set(cands) - sample - heads), more))
+            sample.update(random.sample(nodes - sample - heads, more))
         sample.update(heads)
 
         self.writer.unindent()
         assertnodes(list(self.expected), list(actual))
 
 if __name__ == "__main__":
-    random.seed(0)
+    random.seed(42)
     Tests().testall()
             self.testdag(name, desc)
 
     def testall(self):
-#        self.testpredefined()
+        #self.testpredefined()
+        #self.testfile("../data/linux.dag", pairs=[([160371], [162753])])
         self.testfile("../data/linux.dag")
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.