Source

astoptimizer / astoptimizer / ast_tools.py

Full commit
Victor Stinner 6d0e38b 


Victor Stinner 82f1110 
Victor Stinner 014eaa2 
Victor Stinner 6d0e38b 
Victor Stinner 5a926a3 







Victor Stinner 010d52c 





Victor Stinner 5a926a3 












Victor Stinner 010d52c 





Victor Stinner 5a926a3 

















Victor Stinner 0de3010 

Victor Stinner dd3a4a9 
Victor Stinner 89fab93 









Victor Stinner 0de3010 
Victor Stinner 05e1f59 




Victor Stinner 0de3010 













Victor Stinner e96e0fd 







Victor Stinner 82f1110 
Victor Stinner 05e1f59 





Victor Stinner 82f1110 
Victor Stinner 0de3010 

Victor Stinner 05e1f59 
Victor Stinner 82f1110 
Victor Stinner 0de3010 







Victor Stinner 5a926a3 

















Victor Stinner a277a3c 





























Victor Stinner 014eaa2 


Victor Stinner 66191d2 







import ast
from astoptimizer.compatibility import (
    PYTHON3, COMPLEX_TYPES, BYTES_TYPE, UNICODE_TYPE)
import sys
import copy

def copy_lineno(node, new_node):
    """Give *new_node* the source location of *node* and fill in its subtree.

    The location attributes of *node* are copied onto *new_node* first, and
    only then are missing locations filled in recursively.  As a result,
    descendants of *new_node* that have no location inherit *node*'s
    position instead of defaulting to line 1.  Descendants that already
    carry a location keep it.

    Returns *new_node* (modified in place) so calls can be chained.
    """
    # Order matters: ast.fix_missing_locations() copies each parent's
    # location down to children lacking one, falling back to line 1 /
    # column 0 only at the root.  Running it before copy_location() would
    # stamp every new child with line 1.
    ast.copy_location(new_node, node)
    ast.fix_missing_locations(new_node)
    return new_node

def _new_frozenset_call(node, value):
    """Build a ``frozenset(arg)`` call whose argument yields the items of *value*."""
    if all(isinstance(elt, UNICODE_TYPE) and len(elt) == 1 for elt in value):
        # frozenset("abc") == frozenset({"a", "b", "c"}): a compact spelling
        # that is only valid when every item is a single character.  (An
        # empty frozenset also lands here and becomes frozenset("").)
        arg = new_constant(node, UNICODE_TYPE().join(sorted(value)))
    elif (not PYTHON3
          and all(isinstance(elt, BYTES_TYPE) and len(elt) == 1
                  for elt in value)):
        # Same trick for Python 2 byte strings.  On Python 3, iterating
        # bytes yields ints, so it would build the wrong set there.
        arg = new_constant(node, BYTES_TYPE().join(sorted(value)))
    else:
        # Sort when possible so the generated code is deterministic.
        elts = [new_constant(node, elt) for elt in sort_set_elts(value)]
        arg = ast.Tuple(elts=elts, ctx=ast.Load())
        copy_lineno(node, arg)
    func = copy_lineno(node, ast.Name(id='frozenset', ctx=ast.Load()))
    fields = dict(func=func, args=[arg], keywords=[])
    if sys.version_info < (3, 5):
        # ast.Call lost its starargs/kwargs fields in Python 3.5.
        fields.update(starargs=None, kwargs=None)
    return ast.Call(**fields)

def new_constant(node, value):
    """Build an AST expression that evaluates to the constant *value*.

    Supported values:
    - bool and None;
    - numbers (COMPLEX_TYPES);
    - text strings (UNICODE_TYPE) and byte strings (BYTES_TYPE);
    - tuples of supported values, emitted as an ``ast.Tuple``;
    - frozensets of supported values, emitted as a ``frozenset(...)`` call.

    The returned node gets the source location of *node* via copy_lineno().

    Raises NotImplementedError for any other type.
    """
    # ast.Constant replaced Num/Str/Bytes/NameConstant in Python 3.8; the
    # old classes are deprecated aliases there and were removed in 3.14.
    use_constant = sys.version_info >= (3, 8)
    if value is None or isinstance(value, bool):
        # Tested before COMPLEX_TYPES, which presumably includes int
        # (bool is an int subclass).
        if use_constant:
            new_node = ast.Constant(value=value)
        elif sys.version_info >= (3, 4):
            # NameConstant was added to Python 3.4, see
            # http://bugs.python.org/issue16619
            new_node = ast.NameConstant(value=value)
        else:
            # repr() gives the builtin name: "True", "False" or "None".
            new_node = ast.Name(id=repr(value), ctx=ast.Load())
    elif isinstance(value, COMPLEX_TYPES):
        if use_constant:
            new_node = ast.Constant(value=value)
        else:
            new_node = ast.Num(n=value)
    elif isinstance(value, UNICODE_TYPE):
        if use_constant:
            new_node = ast.Constant(value=value)
        else:
            new_node = ast.Str(s=value)
    elif isinstance(value, BYTES_TYPE):
        if use_constant:
            new_node = ast.Constant(value=value)
        elif PYTHON3:
            new_node = ast.Bytes(s=value)
        else:
            # On Python 2, byte strings are plain str literals.
            new_node = ast.Str(s=value)
    elif isinstance(value, tuple):
        elts = [new_constant(node, elt) for elt in value]
        new_node = ast.Tuple(elts=elts, ctx=ast.Load())
    elif isinstance(value, frozenset):
        new_node = _new_frozenset_call(node, value)
    else:
        raise NotImplementedError("unable to create an AST object for constant: %r" % (value,))
    return copy_lineno(node, new_node)

def _new_constant_list(node, elts):
    """Convert every constant in *elts* into an AST node located at *node*.

    Returns a new list, in the iteration order of *elts*.
    """
    nodes = []
    for constant in elts:
        nodes.append(new_constant(node, constant))
    return nodes

def new_tuple_elts(node, elts=None):
    """Return a load-context ``ast.Tuple`` of *elts*, located at *node*.

    elts: list of AST expression nodes, used as-is (not copied); when
    omitted, a fresh empty list is used.
    """
    tuple_node = ast.Tuple(elts=[] if elts is None else elts, ctx=ast.Load())
    return copy_lineno(node, tuple_node)

def new_tuple(node, iterable=()):
    """Return an ``ast.Tuple`` whose items are the constants of *iterable*."""
    return new_tuple_elts(node, _new_constant_list(node, iterable))

def new_list_elts(node, elts=None):
    """Return a load-context ``ast.List`` of *elts*, located at *node*.

    elts: list of AST expression nodes, used as-is (not copied); when
    omitted, a fresh empty list is used.
    """
    list_node = ast.List(elts=[] if elts is None else elts, ctx=ast.Load())
    return copy_lineno(node, list_node)

def new_list(node, iterable=()):
    """Return an ``ast.List`` whose items are the constants of *iterable*."""
    return new_list_elts(node, _new_constant_list(node, iterable))

def sort_set_elts(elts):
    """Return the items of *elts* as a new list, sorted when they are orderable.

    Sorting only makes the output deterministic for the astoptimizer unit
    tests.  If the items cannot be compared with each other, the TypeError
    raised by list.sort() is swallowed, and the list is returned in
    whatever order the failed sort left it.
    """
    items = list(elts)
    try:
        items.sort()
    except TypeError:
        # Unorderable mix of types, e.g. int and str on Python 3.
        return items
    return items

def new_dict_elts(node, keys=None, values=None):
    """Return an ``ast.Dict`` built from parallel *keys* / *values* node lists.

    Both lists are used as-is (not copied); an omitted list becomes a fresh
    empty list.  The new node is located at *node*.
    """
    dict_node = ast.Dict(
        keys=[] if keys is None else keys,
        values=[] if values is None else values,
    )
    return copy_lineno(node, dict_node)

if sys.version_info >= (2, 7):
    # These helpers build ast.Set nodes, so they are only defined on
    # Python 2.7+ (the version that introduced set displays).
    def new_set_elts(node, elts=None):
        """Return an ``ast.Set`` of *elts* (AST nodes), located at *node*.

        elts is used as-is (not copied); when omitted, a fresh empty list
        is used.
        """
        set_node = ast.Set(elts=[] if elts is None else elts)
        return copy_lineno(node, set_node)

    def new_set(node, iterable=()):
        """Return an ``ast.Set`` whose items are the constants of *iterable*.

        Items are sorted first when possible, so the output is deterministic.
        """
        ordered = sort_set_elts(iterable)
        return new_set_elts(node, _new_constant_list(node, ordered))

def new_literal(node, value):
    """Return an AST node that rebuilds *value*, located at *node*.

    Lists become list displays and, on Python 2.7+, sets become set
    displays.  Everything else is delegated to new_constant().
    """
    if isinstance(value, list):
        builder = new_list
    elif sys.version_info >= (2, 7) and isinstance(value, set):
        builder = new_set
    else:
        builder = new_constant
    return builder(node, value)

def iter_all_ast(node):
    """Yield *node* and every AST node below it, in depth-first pre-order.

    Children are visited in ast.iter_fields() order.  List-valued fields
    contribute their AST items in list order; non-AST values are skipped.

    An explicit stack is used instead of recursive generators.  Nested
    generators cost O(depth) per yielded node and overflow the recursion
    limit on deeply nested trees, such as long ``a + b + c + ...`` chains.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        yield current
        children = []
        for _field, value in ast.iter_fields(current):
            if isinstance(value, list):
                children.extend(item for item in value
                                if isinstance(item, ast.AST))
            elif isinstance(value, ast.AST):
                children.append(value)
        # Push in reverse so the first child is popped (visited) next.
        stack.extend(reversed(children))

def ast_contains(tree, obj_type):
    """Return True if any node in *tree* is an instance of *obj_type*.

    tree: an AST node, or a (possibly nested) list of AST nodes.
    obj_type: a class or tuple of classes, as accepted by isinstance().
    The search stops at the first match.
    """
    if not isinstance(tree, list):
        for candidate in iter_all_ast(tree):
            if isinstance(candidate, obj_type):
                return True
        return False
    for subtree in tree:
        if ast_contains(subtree, obj_type):
            return True
    return False

def new_call(node, name, *args):
    """Build the call expression ``name(*args)``, located at *node*.

    name: str, the name of the called function (loaded as a variable).
    args: AST expression nodes, passed as positional arguments.

    Returns the new ``ast.Call``.  Both the call and its ``ast.Name``
    callee get *node*'s source location via copy_lineno().
    """
    func = copy_lineno(node, ast.Name(id=name, ctx=ast.Load()))
    fields = dict(func=func, args=list(args), keywords=[])
    if sys.version_info < (3, 5):
        # starargs/kwargs were removed from ast.Call in Python 3.5.  Passing
        # them anyway is deprecated since 3.13 (unknown keyword arguments).
        fields.update(starargs=None, kwargs=None)
    new_node = ast.Call(**fields)
    return copy_lineno(node, new_node)

def check_func_args(node, min_narg=None, max_narg=None):
    """Return True if the call *node* uses only plain positional arguments.

    node: an ``ast.Call``.
    min_narg / max_narg: optional inclusive bounds on ``len(node.args)``.

    Returns False for calls with keyword arguments, ``*args`` or
    ``**kwargs`` (not supported yet), and for calls whose positional
    argument count is outside the bounds.
    """
    # Python < 3.5 stores *args / **kwargs in dedicated Call fields, which no
    # longer exist on 3.5+, so read them defensively.
    if (node.keywords
            or getattr(node, 'starargs', None)
            or getattr(node, 'kwargs', None)):
        return False
    # Python 3.5+ puts *args in node.args as ast.Starred, and **kwargs in
    # node.keywords (a keyword with arg=None, already rejected above).
    starred_type = getattr(ast, 'Starred', None)
    if starred_type is not None and any(
            isinstance(arg, starred_type) for arg in node.args):
        return False
    nargs = len(node.args)
    if min_narg is not None and nargs < min_narg:
        return False
    if max_narg is not None and nargs > max_narg:
        return False
    return True

def new_pass(node):
    """Return a new ``pass`` statement carrying *node*'s source location."""
    return copy_lineno(node, ast.Pass())

def clone_node_list(node_list):
    """Return a deep copy of *node_list*, e.g. a list of AST statements.

    The copy shares no nodes with the original, so it can be mutated
    independently.
    """
    # FIXME: copy.deepcopy is fully generic and may be slow; a faster or
    # AST-specialized copier could replace it.
    duplicate = copy.deepcopy(node_list)
    return duplicate

def is_empty_body(node_list):
    """Return True if the statement list *node_list* does nothing.

    A body counts as empty when it has no statements, or when it consists
    of exactly one ``pass`` statement.
    """
    size = len(node_list)
    if size == 0:
        return True
    return size == 1 and isinstance(node_list[0], ast.Pass)