Commits

Dimitri Tcaciuc committed ce080ea

If package names are passed in, they are recursively searched for modules.

Comments (0)

Files changed (1)

importgraph/__init__.py

         return find_module(components[1], module_filename)
 
 
+def find_package_modules(pkg_name, pkg_filename):
+    """Recursively finds all modules/sub-packages in a package."""
+    from pkgutil import iter_modules
+    modules = []
+    for (importer, name, ispkg) in iter_modules([pkg_filename]):
+        q_name = "{0}.{1}".format(pkg_name, name)
+        if ispkg:
+            modules.extend(find_package_modules(q_name, os.path.join(pkg_filename, name)))
+        else:
+            modules.append(q_name)
+    return modules
+
+
 class ImportVisitor(ast.NodeVisitor):
 
     def __init__(self, graph, root, include=None, visited=None, level=None):
 def main():
     argparser = argparse.ArgumentParser()
     argparser.add_argument("pymodules", nargs="+", default=[], help="Modules to analyze")
-#   argparser.add_argument("--include", nargs="+", default=[], metavar="MODULES",
-#                          help="List of modules names to ignore, spearated by a space")
 
     args = argparser.parse_args(sys.argv[1:])
 
     graph = networkx.DiGraph()
 
     for module_name in args.pymodules:
+        # print "iterating module_name"
         # Load the ast from the file being analyzed
-        (module_file, module_filename,
-         (module_ext, module_mode, module_type)) = find_module(module_name)
-
-        if module_type != imp.PKG_DIRECTORY:
-            # TODO list all modules under there.
-            pass
-        elif module_type != imp.PY_SOURCE:
+        try:
+            (module_file, module_filename,
+             (module_ext, module_mode, module_type)) = find_module(module_name)
+        except ImportError:
             continue
 
-        try:
-            tree = ast.parse(module_file.read())
-        finally:
-            module_file.close()
+        if module_type == imp.PKG_DIRECTORY:
+            args.pymodules.extend(find_package_modules(module_name, module_filename))
 
-        graph.add_node(module_name)
+        elif module_type == imp.PY_SOURCE:
+            try:
+                tree = ast.parse(module_file.read())
+            finally:
+                module_file.close()
 
-        # Anaylze the AST
-        include = set()
-        visitor = ImportVisitor(graph, module_name, include=include)
-        visitor.visit(tree)
+            graph.add_node(module_name)
+
+            # Anaylze the AST
+            include = set()
+            visitor = ImportVisitor(graph, module_name, include=include)
+            visitor.visit(tree)
 
     # Plot output the results
-    networkx.draw_graphviz(graph, prog='dot')
+    networkx.draw_graphviz(graph)#, prog='dot')
     plt.show()
 
 if __name__ == "__main__":