Source

importgraph / importgraph / __init__.py

#!/usr/bin/env python
import ast
import argparse
import imp
import networkx
import matplotlib.pyplot as plt
import os
import sys
from collections import namedtuple

# Maximum import depth followed; ImportVisitor stops processing imports
# once its level exceeds this value.
MAX_DEPTH = 5

# (name, path) record type for a module. Not referenced elsewhere in this
# section of the file.
ModuleInfo = namedtuple("ModuleInfo", "name,path")

# TODO: Generate this in some better way. This is only an example.
# Hand-picked sample of standard-library module names that --ignore-stdlib
# adds to the ignore set; it is NOT an exhaustive stdlib list.
STDLIB_MODULES = [
    'os',
    'subprocess',
    'sys',
    'inspect',
    'functools',
    'collections',
]

def find_module(module, search_path=None):
    """Find the file containing a module.

    Each directory of ``search_path`` (``sys.path`` by default) is tried in
    order. Within a directory, ``<name>.py``, ``<name>/__init__.py`` and
    ``<name>.so`` are checked, in that order.

    Args:
        module: dotted module name, e.g. ``"os.path"``. ``None`` or an empty
            string (as produced by ``from . import x``) yields ``None``.
        search_path: optional sequence of directories to search instead of
            ``sys.path``.

    Returns:
        The path of the first candidate file that exists, or ``None`` when no
        candidate exists anywhere on the search path.
    """
    if not module:
        return None

    # "a.b.c" -> "a/b/c" using the platform's path separator.
    relative = os.path.join(*module.split("."))
    # Looking for module file; can be either a .py file, a python package,
    # or a compiled extension.
    candidates = (
        relative + ".py",
        os.path.join(relative, "__init__.py"),
        relative + ".so",
    )

    directories = sys.path if search_path is None else search_path
    for directory in directories:
        # TODO: probably better to check if the directory exists first,
        # then look for specific types of files.
        for candidate in candidates:
            module_filename = os.path.join(directory, candidate)
            if os.access(module_filename, os.F_OK):
                return module_filename

    # Nothing matched: return None rather than a path that does not exist.
    return None

class ImportVisitor(ast.NodeVisitor):
    """AST visitor that records import relationships into a graph.

    For every module imported by the visited tree, an edge
    ``(top_level_package, root)`` is added to ``graph``. When the imported
    module resolves (via ``find_module``) to a ``.py`` file and is not in
    ``ignore``, that file is parsed and visited recursively, with the
    module's top-level package name as the new root, until the depth limit
    is exceeded.
    """

    def __init__(self, graph, root, ignore=None, visited=None, level=None,
                 max_depth=None):
        """Create a visitor.

        Args:
            graph: object with an ``add_edge(src, dst)`` method (``main``
                passes a ``networkx.DiGraph``).
            root: graph node name of the module/file being visited.
            ignore: collection of module names not to recurse into. Their
                edges are still recorded.
            visited: list of module names already processed. It is shared
                with child visitors so each module is processed once per run.
            level: current recursion depth; defaults to 1.
            max_depth: imports are no longer processed once ``level``
                exceeds this; defaults to the module-level ``MAX_DEPTH``.
        """
        self.graph = graph
        self.root = root
        # ``is None`` rather than ``or``: a caller-supplied empty container
        # (or level 0) must be kept, not silently replaced by a new object.
        self.ignore = set() if ignore is None else ignore
        self.visited = [] if visited is None else visited
        self.level = 1 if level is None else level
        self.max_depth = MAX_DEPTH if max_depth is None else max_depth

    def visit_Import(self, node):
        """Handle ``import a, b.c`` by processing each imported name."""
        for alias in node.names:
            self._process_import(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Handle ``from x import y`` by processing module ``x``."""
        # ``from . import x`` has no module name (node.module is None);
        # there is nothing to look up, so skip it instead of crashing.
        # NOTE: relative imports that do name a module (``from .x import y``)
        # are still resolved against sys.path as if absolute.
        if node.module is not None:
            self._process_import(node.module)
        self.generic_visit(node)

    def _process_import(self, module):
        """Record that ``self.root`` imports ``module``, then recurse into it."""
        if module in self.visited or self.level > self.max_depth:
            return

        self.visited.append(module)

        # TODO add filtering
        # Only the top-level package name is used as the graph node.
        display_name = module.split(".", 1)[0]
        self.graph.add_edge(display_name, self.root)

        # Ignored modules keep their edge; only the recursion is skipped.
        if module in self.ignore:
            return

        module_filename = find_module(module)
        # Only Python sources can be parsed; skip misses and extensions.
        if not module_filename or not module_filename.endswith(".py"):
            return

        print("processing %s %s" % (module, module_filename))
        try:
            # Bytes let the parser honour the file's own encoding declaration.
            with open(module_filename, "rb") as code:
                tree = ast.parse(code.read(), module_filename)
        except (IOError, SyntaxError, ValueError) as err:
            # Best effort: report and skip unreadable or unparsable modules
            # (e.g. sources written for another Python version).
            sys.stderr.write("skipping %s: %s\n" % (module_filename, err))
            return

        visitor = ImportVisitor(self.graph,
                                display_name,
                                ignore=self.ignore,
                                visited=self.visited,
                                level=self.level + 1,
                                max_depth=self.max_depth)
        visitor.visit(tree)

def main():
    """Command-line entry point: build and plot the import graph of a file.

    Parses ``pyfile``, walks its imports with ``ImportVisitor`` (recursing
    into resolvable ``.py`` modules that are not ignored), and shows the
    resulting directed graph with matplotlib.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("pyfile", help="Source file to analyze")
    argparser.add_argument("--ignore-stdlib", action='store_true',
                           help="Don't inspect modules in Python's standard library")
    argparser.add_argument("--ignore", nargs="+", default=[], metavar="MODULE",
                           help="List of module names to ignore, separated by a space")
    args = argparser.parse_args(sys.argv[1:])

    # Load the AST from the file being analyzed. Reading bytes lets the
    # parser honour the file's own encoding declaration.
    with open(args.pyfile, "rb") as code:
        tree = ast.parse(code.read(), args.pyfile)

    # Create the initial graph, containing the analyzed file as a node.
    graph = networkx.DiGraph()
    graph.add_node(args.pyfile)

    # Analyze the AST. --ignore-stdlib only adds the sample names listed in
    # STDLIB_MODULES, not the whole standard library.
    ignore = set(args.ignore)
    if args.ignore_stdlib:
        ignore.update(STDLIB_MODULES)

    visitor = ImportVisitor(graph, args.pyfile, ignore=ignore)
    visitor.visit(tree)

    # Plot the results. networkx.draw_graphviz no longer exists in later
    # networkx releases; fall back to the generic drawer (with node labels,
    # which newer networkx.draw omits by default) when it is missing.
    draw_graphviz = getattr(networkx, "draw_graphviz", None)
    if draw_graphviz is not None:
        draw_graphviz(graph)
    else:
        networkx.draw(graph, with_labels=True)
    plt.show()

if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()