# -*- coding: utf-8 -*-
# Fredrik Håård
# blaag.haard.se
#
# blaag.haard.se / code / constants / inlining.py

import ast, imp, sys, re, _ast, os.path, types

# Full-match pattern for constant names: UPPER_CASE, with digits allowed
# after the first character. The trailing '$' matters because re.match
# only anchors at the start; without it, 'Foo' or '_priv' would match.
RE_CONSTANT = re.compile(r'[A-Z_][A-Z0-9_]*$')

class ConstantMaker(ast.NodeTransformer):
    """NodeTransformer that will inline any Number and String
    constants defined on the module level wherever they are read
    throughout the module.

    A module-level assignment counts as a constant definition when its
    value is a number or string literal and its target is a plain name
    that fully matches RE_CONSTANT. Only reads (Load context) of such
    names are replaced. The defining assignment itself is left intact,
    so the name is still available as a module attribute.

    Limitations: scoping is not analysed. A local variable or parameter
    that shadows a constant's name still has its reads replaced. Later
    rebinding of the name to a non-literal is ignored. If the same name
    is assigned several literals on the module level, the last one wins.
    """

    def __init__(self):
        # Maps constant name -> literal AST node substituted for its reads.
        self._constants = {}
        super(ConstantMaker, self).__init__()

    @staticmethod
    def _is_literal(node):
        """Return True if *node* is a number or (text) string literal."""
        if sys.version_info >= (3, 8):
            # Python 3.8+ parses every literal as ast.Constant.
            if not isinstance(node, ast.Constant):
                return False
            value = node.value
            # True/False were never ast.Num nodes (bool subclasses int),
            # so keep excluding them.
            return (isinstance(value, (int, float, complex, str))
                    and not isinstance(value, bool))
        return isinstance(node, (ast.Num, ast.Str))

    def visit_Module(self, node):
        """Find eligible variables to be inlined and store
        the Name->value mapping in self._constants, then transform
        the rest of the module."""
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                continue
            if not self._is_literal(stmt.value):
                continue
            for target in stmt.targets:
                # Attribute/Subscript/Tuple targets have no .id and are
                # never constants.
                if isinstance(target, ast.Name) and RE_CONSTANT.match(target.id):
                    self._constants[target.id] = stmt.value

        return self.generic_visit(node)

    def visit_Name(self, node):
        """If node is a read of a name in self._constants, replace it
        with the literal value. Stores, deletes and (Python 2)
        parameters are left untouched; replacing them would yield an
        invalid AST, such as assigning to a literal."""
        if isinstance(node.ctx, ast.Load):
            return self._constants.get(node.id, node)
        return node

def transform(src):
    """Parse *src* into an AST, inline its module-level constants with
    ConstantMaker, and return the resulting tree."""
    return ConstantMaker().visit(ast.parse(src))

class InliningImporter(object):
    """PEP 302 importer to be put on sys.meta_path. It instruments any
    subsequent import by running the module's source through transform()
    (constant inlining) before executing it."""

    def __init__(self):
        # Maps full module name -> the (file, pathname, description) triple
        # from imp.find_module. find_module fills it; load_module consumes it.
        self._cache = {}

    def find_module(self, name, path=None):
        """Locate module *name* with imp.find_module.

        *path* is the search path handed over by the import machinery:
        the parent package's __path__ for submodules, None for
        top-level modules. Returns self if the module was found, or
        None so that the next importer can try.
        """
        suffix = name.split('.')[-1]
        try:
            found = imp.find_module(suffix, path)
        except ImportError:
            return None
        self._cache[name] = found
        return self

    def load_module(self, name):
        """Load a module and instrument it. Will fallback to default
        behaviour if there is no source available, or if the module to
        be imported is not in (PY_SOURCE, PY_COMPILED, PKG_DIRECTORY).

        Raises ImportError if *name* was not located by find_module first.
        """
        try:
            fd, pathname, description = self._cache.pop(name)
        except KeyError:
            raise ImportError('No module named %s' % name)
        type_ = description[2]

        try:
            if type_ == imp.PY_SOURCE:
                filename = pathname
            elif type_ == imp.PY_COMPILED:
                filename = pathname[:-1]  # foo.pyc -> foo.py
            elif type_ == imp.PKG_DIRECTORY:
                filename = os.path.join(pathname, '__init__.py')
            else:
                # Extensions, built-ins, etc.: nothing to instrument.
                return imp.load_module(name, fd, pathname, description)

            if filename == pathname:
                src = fd.read()
            else:
                try:
                    with open(filename, 'U') as realfile:
                        src = realfile.read()
                except IOError:  # no source available: fallback
                    return imp.load_module(name, fd, pathname, description)
        finally:
            # imp.find_module returns None instead of a file for packages
            # and built-ins, so only close real file objects. This runs
            # after any fallback imp.load_module call has used fd.
            if fd is not None:
                fd.close()

        is_package = type_ == imp.PKG_DIRECTORY
        return self._exec_inlined(name, filename, pathname, src, is_package)

    def _exec_inlined(self, name, filename, pathname, src, is_package):
        """Compile *src* with constants inlined and execute it as module *name*."""
        code = compile(transform(src), filename, 'exec')

        # PEP 302: reuse a module already in sys.modules (reload).
        existing = sys.modules.get(name)
        module = existing if existing is not None else types.ModuleType(name)
        module.__file__ = filename
        module.__loader__ = self
        if is_package:
            module.__path__ = [pathname]

        # Register before executing so circular imports find the module.
        sys.modules[name] = module
        try:
            exec(code, module.__dict__)
        except BaseException:
            # PEP 302: on failure, remove only a module we inserted ourselves.
            if existing is None:
                sys.modules.pop(name, None)
            raise
        return module

def install_hook():
    """Install the import hook: prepend a fresh InliningImporter to
    sys.meta_path so it gets the first chance at handling later imports."""
    sys.meta_path.insert(0, InliningImporter())