astoptimizer / astoptimizer / ast_tools.py

import ast
from astoptimizer.compatibility import (
    PYTHON3, COMPLEX_TYPES, BYTES_TYPE, UNICODE_TYPE)
import sys
import copy

def copy_lineno(node, new_node):
    """Give *new_node* (and its descendants) the source location of *node*.

    The location attributes of *node* (lineno, col_offset and, on Pythons
    that have them, the end positions) are copied onto *new_node*.  Then
    every node below *new_node* that has no location yet inherits the
    location of its parent.  Returns *new_node*.
    """
    # Copy first, then fill in.  fix_missing_locations() propagates each
    # parent's position to children lacking one.  Running it first (as
    # before) gave such children the default line 1, column 0 instead of
    # *node*'s location.
    ast.copy_location(new_node, node)
    ast.fix_missing_locations(new_node)
    return new_node

def new_constant(node, value):
    """Build an AST expression node that evaluates to the constant *value*.

    Handles bool, None, values of COMPLEX_TYPES (numbers), UNICODE_TYPE
    (text), BYTES_TYPE (bytes), and tuples/frozensets built from those.
    The source location of *node* is copied onto the new node(s).

    Raises NotImplementedError for any other type of value.
    """
    if isinstance(value, bool):
        # Tested before COMPLEX_TYPES, which presumably includes int
        # (bool is a subclass of int).
        if sys.version_info >= (3, 4):
            # NameConstant was added to Python 3.4, see
            # http://bugs.python.org/issue16619
            new_node = ast.NameConstant(value=value)
        else:
            new_node = ast.Name(id="True" if value else "False", ctx=ast.Load())
    elif isinstance(value, COMPLEX_TYPES):
        new_node = ast.Num(n=value)
    elif isinstance(value, UNICODE_TYPE):
        new_node = ast.Str(s=value)
    elif isinstance(value, BYTES_TYPE):
        if PYTHON3:
            new_node = ast.Bytes(s=value)
        else:
            new_node = ast.Str(s=value)
    elif value is None:
        if sys.version_info >= (3, 4):
            # NameConstant was added to Python 3.4, see
            # http://bugs.python.org/issue16619
            new_node = ast.NameConstant(value=None)
        else:
            new_node = ast.Name(id="None", ctx=ast.Load())
    elif isinstance(value, tuple):
        elts = [new_constant(node, elt) for elt in value]
        new_node = ast.Tuple(elts=elts, ctx=ast.Load())
    elif isinstance(value, frozenset):
        # Compact spelling: frozenset(u"abc") == frozenset((u"a", u"b", u"c")).
        # This is only valid when every element is a 1-character string,
        # because frozenset(s) splits s into its characters.
        if all(isinstance(elt, UNICODE_TYPE) and len(elt) == 1
               for elt in value):
            arg = new_constant(node, UNICODE_TYPE().join(sorted(value)))
        elif (not PYTHON3
              and all(isinstance(elt, BYTES_TYPE) and len(elt) == 1
                      for elt in value)):
            # Only on Python 2 does iterating a byte string yield 1-byte
            # strings; on Python 3 it yields ints, so this spelling would
            # build a frozenset of ints.
            arg = new_constant(node, BYTES_TYPE().join(sorted(value)))
        else:
            # Sorted (when possible) for a deterministic output.
            elts = [new_constant(node, elt) for elt in sort_set_elts(value)]
            arg = ast.Tuple(elts=elts, ctx=ast.Load())
            copy_lineno(node, arg)
        func = ast.Name(id='frozenset', ctx=ast.Load())
        if 'starargs' in ast.Call._fields:
            # Python < 3.5: Call has dedicated starargs/kwargs fields.
            new_node = ast.Call(func=func, args=[arg], keywords=[],
                                starargs=None, kwargs=None)
        else:
            # Python 3.5+: Call only has func, args and keywords; passing
            # five positional arguments raises TypeError there.
            new_node = ast.Call(func=func, args=[arg], keywords=[])
    else:
        raise NotImplementedError("unable to create an AST object for constant: %r" % (value,))
    return copy_lineno(node, new_node)

def _new_constant_list(node, elts):
    """Return a list with one AST constant node (located at *node*) per item of *elts*."""
    nodes = []
    for item in elts:
        nodes.append(new_constant(node, item))
    return nodes

def new_tuple_elts(node, elts=None):
    """Return a Load-context ast.Tuple of the AST nodes *elts*, located at *node*.

    *elts* defaults to a fresh empty list (an empty tuple).
    """
    tuple_node = ast.Tuple(elts=[] if elts is None else elts, ctx=ast.Load())
    return copy_lineno(node, tuple_node)

def new_tuple(node, iterable=()):
    """Return an ast.Tuple whose items are AST constants built from *iterable*."""
    return new_tuple_elts(node, _new_constant_list(node, iterable))

def new_list_elts(node, elts=None):
    """Return a Load-context ast.List of the AST nodes *elts*, located at *node*.

    *elts* defaults to a fresh empty list (an empty list display).
    """
    list_node = ast.List(elts=[] if elts is None else elts, ctx=ast.Load())
    return copy_lineno(node, list_node)

def new_list(node, iterable=()):
    """Return an ast.List whose items are AST constants built from *iterable*."""
    return new_list_elts(node, _new_constant_list(node, iterable))

def sort_set_elts(elts):
    """Return the items of *elts* as a new list, sorted when they are comparable.

    The sort gives a deterministic order (relied on by the astoptimizer
    unit tests).  If the items cannot be compared with each other, the
    TypeError raised by the sort is ignored and the list is returned
    without further sorting.
    """
    ordered = list(elts)
    try:
        ordered.sort()
    except TypeError:
        # Incomparable items (e.g. int and str on Python 3): give up on
        # sorting and keep the list as it is.
        return ordered
    return ordered

def new_dict_elts(node, keys=None, values=None):
    """Return an ast.Dict built from the parallel lists *keys* and *values*.

    Both lists hold AST nodes and default to fresh empty lists.  The new
    node gets the source location of *node*.
    """
    dict_node = ast.Dict(
        keys=[] if keys is None else keys,
        values=[] if values is None else values,
    )
    return copy_lineno(node, dict_node)

if sys.version_info >= (2, 7):
    # Set displays ({1, 2}) and ast.Set only exist since Python 2.7, so
    # these helpers are defined only there.

    def new_set_elts(node, elts=None):
        """Return an ast.Set of the AST nodes *elts* (default: none), located at *node*."""
        set_node = ast.Set(elts=[] if elts is None else elts)
        return copy_lineno(node, set_node)

    def new_set(node, iterable=()):
        """Return an ast.Set of AST constants built from *iterable*.

        Items are sorted first when possible so the output is deterministic.
        """
        constants = _new_constant_list(node, sort_set_elts(iterable))
        return new_set_elts(node, constants)

def new_literal(node, value):
    """Return an AST node that evaluates to *value*.

    A list becomes a list display and a set (on Python 2.7+) a set display.
    Everything else is delegated to new_constant().
    """
    if isinstance(value, list):
        return new_list(node, value)
    if isinstance(value, set) and sys.version_info >= (2, 7):
        return new_set(node, value)
    return new_constant(node, value)

def iter_all_ast(node):
    """Yield *node* and every AST node below it, in depth-first pre-order.

    Children are visited in field order (as given by ast.iter_fields) and,
    within a list field, in list order.  Non-AST items of list fields are
    skipped.  This is the same order as a recursive pre-order walk.

    An explicit stack is used instead of recursive generators.  A chain of
    nested generators costs one frame per tree level for every yielded
    node (O(n * depth) overall) and raises RecursionError on deeply nested
    trees, such as long chains of binary operators.

    Note: the children of a node are read when iteration resumes after
    that node has been yielded.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        yield current
        children = []
        for _field, value in ast.iter_fields(current):
            if isinstance(value, list):
                children.extend(item for item in value
                                if isinstance(item, ast.AST))
            elif isinstance(value, ast.AST):
                children.append(value)
        # Push in reverse so that the first child is popped (visited) first.
        pending.extend(reversed(children))

def ast_contains(tree, obj_type):
    """Return True if *tree* contains a node that is an instance of *obj_type*.

    *tree* is an AST node or a (possibly nested) list of AST nodes.  The
    search stops at the first match.
    """
    if not isinstance(tree, list):
        for candidate in iter_all_ast(tree):
            if isinstance(candidate, obj_type):
                return True
        return False
    for subtree in tree:
        if ast_contains(subtree, obj_type):
            return True
    return False

def new_call(node, name, *args):
    """Build the AST of the call ``name(*args)``, located at *node*.

    name: str, the called function, loaded as a plain name.
    args: AST expression nodes passed as positional arguments.
    """
    func = ast.Name(id=name, ctx=ast.Load())
    copy_lineno(node, func)
    if 'starargs' in ast.Call._fields:
        # Python < 3.5: *args and **kwargs are dedicated Call fields.
        new_node = ast.Call(
            func=func,
            args=list(args),
            keywords=[],
            starargs=None,
            kwargs=None)
    else:
        # Python 3.5+: Call only has func/args/keywords.  Unknown keyword
        # arguments are deprecated for AST constructors since 3.13.
        new_node = ast.Call(func=func, args=list(args), keywords=[])
    return copy_lineno(node, new_node)

def check_func_args(node, min_narg=None, max_narg=None):
    """Check that the ast.Call *node* only uses plain positional arguments.

    Return True if *node* has no keyword arguments, no ``*args`` and no
    ``**kwargs``, and its number of positional arguments is at least
    *min_narg* and at most *max_narg*.  A bound of None means unbounded.
    """
    # Don't support keywords, *args, **kw yet.
    if node.keywords:
        # On Python 3.5+, **kw is a keyword whose arg is None, so it is
        # caught here as well.
        return False
    # Python < 3.5 stores *args/**kw in dedicated Call fields, which no
    # longer exist on 3.5+ (plain attribute access raised AttributeError).
    if getattr(node, 'starargs', None) or getattr(node, 'kwargs', None):
        return False
    # Python 3.5+ stores *args as an ast.Starred inside node.args.
    starred_type = getattr(ast, 'Starred', None)
    if starred_type is not None and any(
            isinstance(arg, starred_type) for arg in node.args):
        return False
    nargs = len(node.args)
    if min_narg is not None and nargs < min_narg:
        return False
    if max_narg is not None and nargs > max_narg:
        return False
    return True

def new_pass(node):
    """Return a new ``pass`` statement carrying the source location of *node*."""
    return copy_lineno(node, ast.Pass())

def clone_node_list(node_list):
    """Return an independent deep copy of *node_list* (e.g. a list of AST nodes)."""
    # FIXME: copy.deepcopy() is generic; a specialized AST copier may be faster.
    duplicate = copy.deepcopy(node_list)
    return duplicate

def is_empty_body(node_list):
    """Return True if the statement list *node_list* does nothing.

    That is the case when it is empty, or consists of a single ``pass``
    statement.
    """
    count = len(node_list)
    if count > 1:
        return False
    return count == 0 or isinstance(node_list[0], ast.Pass)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.