Source

astoptimizer / astoptimizer / ast_tools.py

import ast
from astoptimizer.compatibility import (
    PYTHON3, COMPLEX_TYPES, BYTES_TYPE, UNICODE_TYPE)
import sys

def copy_lineno(node, new_node):
    """Give new_node (and its unlocated descendants) the location of node.

    The location attributes of ``node`` (lineno, col_offset and, on Python
    versions that have them, end_lineno/end_col_offset) are copied onto
    ``new_node``. Descendants of ``new_node`` that still have no location
    then inherit the location of their parent.

    Returns ``new_node`` so calls can be used directly in ``return``.
    """
    # Copy first, then fill the gaps: fix_missing_locations() propagates a
    # parent's position down to children lacking one. Running it first (as
    # before) stamped such children with the default line 1 / column 0
    # instead of the position of the node being replaced.
    ast.copy_location(new_node, node)
    ast.fix_missing_locations(new_node)
    return new_node

def _new_frozenset_call(node, value):
    """Build a ``frozenset(<arg>)`` call expression recreating ``value``."""
    # frozenset(s) iterates s, so the compact string argument only
    # round-trips when every element is a single character:
    # frozenset(["ab"]) must not become frozenset("ab") == {"a", "b"}.
    if all(isinstance(elt, UNICODE_TYPE) and len(elt) == 1 for elt in value):
        arg = new_constant(node, UNICODE_TYPE().join(sorted(value)))
    elif (not PYTHON3
          and all(isinstance(elt, BYTES_TYPE) and len(elt) == 1
                  for elt in value)):
        # Iterating a Python 3 bytes object yields ints rather than 1-byte
        # bytes, so this compact form is only faithful on Python 2.
        arg = new_constant(node, BYTES_TYPE().join(sorted(value)))
    else:
        elts = [new_constant(node, elt) for elt in value]
        arg = ast.Tuple(elts=elts, ctx=ast.Load())
        copy_lineno(node, arg)
    func = ast.Name(id='frozenset', ctx=ast.Load())
    copy_lineno(node, func)
    if 'starargs' in ast.Call._fields:
        # Python < 3.5: Call nodes also carry explicit starargs/kwargs.
        return ast.Call(func=func, args=[arg], keywords=[],
                        starargs=None, kwargs=None)
    # Python >= 3.5 Call only has func/args/keywords; the old positional
    # 5-argument constructor raised TypeError there.
    return ast.Call(func=func, args=[arg], keywords=[])


def new_constant(node, value):
    """Build an AST expression that evaluates to the constant ``value``.

    Handled values: bool, None, COMPLEX_TYPES numbers, UNICODE_TYPE text,
    BYTES_TYPE data, and tuples / frozensets whose items are themselves
    handled (recursively). ``node`` is the node being replaced; its location
    is copied onto the nodes created here.

    Raises NotImplementedError for any other kind of value.
    """
    # bool is tested first; presumably COMPLEX_TYPES includes int, of which
    # bool is a subclass.
    if isinstance(value, bool):
        name = "True" if value else "False"
        new_node = ast.Name(id=name, ctx=ast.Load())
    elif isinstance(value, COMPLEX_TYPES):
        new_node = ast.Num(n=value)
    elif isinstance(value, UNICODE_TYPE):
        # Text literals are ast.Str on both Python 2 and Python 3.
        new_node = ast.Str(s=value)
    elif isinstance(value, BYTES_TYPE):
        if PYTHON3:
            new_node = ast.Bytes(s=value)
        else:
            new_node = ast.Str(s=value)
    elif value is None:
        new_node = ast.Name(id="None", ctx=ast.Load())
    elif isinstance(value, tuple):
        elts = [new_constant(node, elt) for elt in value]
        new_node = ast.Tuple(elts=elts, ctx=ast.Load())
    elif isinstance(value, frozenset):
        new_node = _new_frozenset_call(node, value)
    else:
        raise NotImplementedError("unable to create an AST object for constant: %r" % (value,))
    return copy_lineno(node, new_node)

def new_literal(node, value):
    """Build an AST literal expression for ``value``.

    A list becomes an ast.List (Load context) whose items are each built
    with new_constant(); any other value is delegated to new_constant().
    """
    if not isinstance(value, list):
        return new_constant(node, value)
    items = [new_constant(node, item) for item in value]
    return copy_lineno(node, ast.List(elts=items, ctx=ast.Load()))

def new_list(node, elts=None):
    """Return an ast.List (Load context) of ``elts``, located at ``node``.

    When ``elts`` is omitted, a fresh empty list is used, so no list object
    is shared between calls. A supplied list is used as-is (not copied).
    """
    items = [] if elts is None else elts
    return copy_lineno(node, ast.List(elts=items, ctx=ast.Load()))

if sys.version_info >= (2, 7):
    # ast.Set (set displays) only exists from Python 2.7 on, hence the guard.
    def new_set_elts(node, elts=None):
        """Return an ast.Set of the AST nodes ``elts``, located at ``node``.

        When ``elts`` is omitted, a fresh empty list is used.
        """
        members = [] if elts is None else elts
        return copy_lineno(node, ast.Set(elts=members))

    def new_set(node, iterable=()):
        """Return an ast.Set whose items are the constants in ``iterable``."""
        return new_set_elts(node, [new_constant(node, item) for item in iterable])

def iter_all_ast(node):
    """Lazily yield ``node`` and every AST node below it, depth-first.

    Nodes come out in pre-order: a node first, then each of its children's
    subtrees in field order (list-valued fields item by item). Non-AST field
    values are skipped.
    """
    yield node
    for child in ast.iter_child_nodes(node):
        for descendant in iter_all_ast(child):
            yield descendant

def ast_contains(tree, obj_type):
    """Return True if any node in ``tree`` is an instance of ``obj_type``.

    ``tree`` is either one AST node, which is searched together with all of
    its descendants, or a (possibly nested) list of such trees. ``obj_type``
    is anything isinstance() accepts. The search stops at the first match.
    """
    if not isinstance(tree, list):
        for candidate in iter_all_ast(tree):
            if isinstance(candidate, obj_type):
                return True
        return False
    for subtree in tree:
        if ast_contains(subtree, obj_type):
            return True
    return False

def new_call(node, name, *args):
    """Build the call expression ``name(*args)``, located at ``node``.

    :param node: node whose location is copied onto the new nodes.
    :param name: (str) name of the callable, loaded as a plain ast.Name.
    :param args: AST expressions passed as positional arguments.
    :return: an ast.Call without keyword arguments.
    """
    func = copy_lineno(node, ast.Name(id=name, ctx=ast.Load()))
    fields = {'func': func, 'args': list(args), 'keywords': []}
    if 'starargs' in ast.Call._fields:
        # Python < 3.5: Call nodes carry explicit (empty) starargs/kwargs.
        fields['starargs'] = None
        fields['kwargs'] = None
    # Newer Pythons dropped those fields; passing them unconditionally left
    # stray attributes on the node and is deprecated since Python 3.13.
    return copy_lineno(node, ast.Call(**fields))

def check_func_args(node, min_narg=None, max_narg=None):
    """Tell whether the call ``node`` only uses plain positional arguments.

    :param node: an ast.Call node.
    :param min_narg: minimum number of positional arguments (None: no bound).
    :param max_narg: maximum number of positional arguments (None: no bound).
    :return: True when the call has no keyword, ``*args`` or ``**kw``
        arguments and its positional argument count is within the bounds.
    """
    # Don't support keywords, *args, **kw yet.
    # On Python >= 3.5, ``**kw`` is a keyword with arg=None, so it is
    # caught here too.
    if node.keywords:
        return False
    # Python < 3.5 stores *args / **kw in dedicated Call fields; newer
    # versions removed them, so read them with a None default instead of
    # failing with AttributeError.
    if getattr(node, 'starargs', None) or getattr(node, 'kwargs', None):
        return False
    # Python >= 3.5 represents ``f(*a)`` as an ast.Starred inside node.args
    # (ast.Starred does not exist on Python 2).
    starred_type = getattr(ast, 'Starred', None)
    if starred_type is not None and any(
            isinstance(arg, starred_type) for arg in node.args):
        return False
    nargs = len(node.args)
    if min_narg is not None and nargs < min_narg:
        return False
    if max_narg is not None and nargs > max_narg:
        return False
    return True

def new_pass(node):
    """Return a ``pass`` statement that takes over the location of ``node``."""
    return copy_lineno(node, ast.Pass())