Source

astoptimizer / astoptimizer / config_builtin_funcs.py

Full commit
"""
Enable optimizations on builtin functions, for example: len().
"""
from astoptimizer.compatibility import (
    is_bytes_ascii, is_unicode_ascii,
    PYTHON2, PYTHON3,
    INT_TYPES, FLOAT_TYPES, COMPLEX_TYPES,
    BYTES_TYPE, UNICODE_TYPE, STR_TYPES, NATIVE_STR,
    IMMUTABLE_ITERABLE_TYPES, IMMUTABLE_TYPES)
from astoptimizer.config import Function
import sys

def check_ord_args(config, args):
    """Tell whether the first call argument has a length of exactly one.

    config is unused.  args is the sequence of call arguments; only its
    first item is inspected.
    """
    return len(args[0]) == 1

def check_max_args(config, args):
    """Tell whether min() or max() can safely be called on args[0].

    config is unused.  args is the sequence of call arguments; only the
    first one, the iterable handed to min()/max(), is inspected.

    Return False when the call would raise:

    - empty input: min(()) and max("") raise a ValueError;
    - an item which is not an instance of COMPLEX_TYPES;
    - a complex number among several items: complex numbers cannot be
      ordered, so max((1, 2j)) raises a TypeError.

    A non-empty string is always accepted.
    """
    arg = args[0]
    if not arg:
        # Nothing to compare: min() and max() reject an empty iterable.
        return False
    if isinstance(arg, STR_TYPES):
        return True
    if not all(isinstance(item, COMPLEX_TYPES) for item in arg):
        return False
    # A single item is returned without any comparison, so a lone complex
    # number is still fine; with several items, any complex makes the
    # comparison fail.
    return len(arg) == 1 or not any(isinstance(item, complex) for item in arg)

def check_sum_args(config, args):
    """Tell whether every operand of a sum() call is an accepted number.

    config is unused.  args[0] is the iterable to sum and, when two
    arguments are given, args[1] is the second argument of sum().
    Return True only if every item of args[0], and args[1] when present,
    is an instance of COMPLEX_TYPES.
    """
    for item in args[0]:
        if not isinstance(item, COMPLEX_TYPES):
            return False
    # The second argument is optional: only validate it when it was given.
    return len(args) != 2 or isinstance(args[1], COMPLEX_TYPES)

def check_len(config, result):
    """Tell whether a computed length stays within config.max_size.

    result is the value to validate; it is accepted when it is lower
    than or equal to config.max_size.
    """
    limit = config.max_size
    return result <= limit

if PYTHON2:
    # These two checks are only defined when PYTHON2 is true.

    def check_str_args(config, args):
        """Validate the argument of a str() call.

        config is unused.  A UNICODE_TYPE argument is accepted only if
        is_unicode_ascii() accepts it; any other argument is accepted.
        """
        value = args[0]
        return is_unicode_ascii(value) if isinstance(value, UNICODE_TYPE) else True

    def check_unicode_args(config, args):
        """Validate the argument of a unicode() call.

        config is unused.  A BYTES_TYPE argument is accepted only if
        is_bytes_ascii() accepts it; any other argument is accepted.
        """
        value = args[0]
        return is_bytes_ascii(value) if isinstance(value, BYTES_TYPE) else True

def check_pow(config, args):
    """Tell whether pow(*args) can be evaluated without raising.

    config is unused.  args holds the two or three call arguments:
    the base, the exponent and the optional modulus.

    Return False for the calls known to fail:

    - with a modulus, any argument which is not an integer
      (pow(2, 3, 5.0) raises a TypeError);
    - with a modulus, a zero modulus or a negative exponent;
    - a negative base with an exponent lower than 1 (and not 0);
    - a zero base with a negative exponent.
    """
    num = args[0]
    exp = args[1]

    if len(args) >= 3:
        # Three-argument pow() requires *all* arguments to be integers,
        # the modulus included.
        if not all(isinstance(arg, INT_TYPES) for arg in args):
            return False
        mod = args[2]
        if mod == 0:
            # pow(2, 1024, 0) raises a ValueError('pow() 3rd argument cannot be 0')
            return False
        if exp < 0:
            # pow(3, -1, 7) raises an error before Python 3.8, and
            # pow(2, -1, 4) still raises a ValueError on newer versions
            # (base is not invertible): be conservative.
            return False

    if exp < 1.0 and exp != 0.0 and num < 0:
        # pow(-25, 0.5) raises a ValueError
        return False
    if num == 0 and exp < 0:
        # pow(0, -1) and pow(0.0, -2.5) raise a ZeroDivisionError
        return False
    return True

def check_chr(config, args):
    """Tell whether the first call argument lies in the range 0..0xff.

    config is unused; both bounds are inclusive.
    """
    return 0 <= args[0] <= 0xff

def check_unichr(config, args):
    """Tell whether the first call argument is an accepted code point.

    config is unused.  The inclusive upper bound is 0x10ffff when PYTHON3
    is true, and sys.maxunicode otherwise; the lower bound is 0.
    """
    highest = 0x10ffff if PYTHON3 else sys.maxunicode
    return 0 <= args[0] <= highest

def setup_config(config):
    """Register the builtin functions on config through config.add_func().

    Each entry maps a function name to a Function object built from the
    callable, what appears to be its accepted number of arguments (an int
    or a (min, max) tuple), the accepted types of each argument, and
    optional check_args / check_result / catch settings.
    """
    register = config.add_func
    number_or_str = FLOAT_TYPES + STR_TYPES
    str_or_collection = STR_TYPES + (tuple, frozenset)

    # pure builtin functions
    register('abs', Function(abs, 1, COMPLEX_TYPES))
    register('bin', Function(bin, 1, INT_TYPES))
    register('bool', Function(bool, 1, number_or_str))
    # chr() accepts a wider range of code points when PYTHON3 is true
    chr_check = check_unichr if PYTHON3 else check_chr
    register('chr', Function(chr, 1, INT_TYPES, check_args=chr_check))
    # NOTE(review): nothing here guards against a zero divisor -- confirm
    # how Function copes with the ZeroDivisionError of divmod(1, 0).
    register('divmod', Function(divmod, 2, FLOAT_TYPES, FLOAT_TYPES))
    for name, convert in (('float', float), ('int', int)):
        register(name, Function(convert, 1, number_or_str, catch=ValueError))
    register('len', Function(len, 1, IMMUTABLE_ITERABLE_TYPES, check_result=check_len))
    register('oct', Function(oct, 1, INT_TYPES))
    register('ord', Function(ord, 1, STR_TYPES, check_args=check_ord_args))
    for name, select in (('min', min), ('max', max)):
        register(name, Function(select, 1, str_or_collection, check_args=check_max_args))
    register('pow', Function(pow, (2, 3), FLOAT_TYPES, FLOAT_TYPES, FLOAT_TYPES, check_args=check_pow))
    register('repr', Function(repr, 1, STR_TYPES + COMPLEX_TYPES))
    register('round', Function(round, (1, 2), FLOAT_TYPES, INT_TYPES))
    if PYTHON3:
        str_function = Function(str, 1, COMPLEX_TYPES + (UNICODE_TYPE,))
    else:
        # check_str_args is only defined when PYTHON2 is true
        str_function = Function(str, 1, COMPLEX_TYPES + STR_TYPES, check_args=check_str_args)
    register('str', str_function)
    register('sum', Function(sum, (1, 2), (tuple, frozenset), COMPLEX_TYPES, check_args=check_sum_args))

    if PYTHON2:
        # builtins which are only registered when PYTHON2 is true
        register('long', Function(long, 1, FLOAT_TYPES))
        register('unichr', Function(unichr, 1, INT_TYPES, check_args=check_unichr))
        register('unicode', Function(unicode, 1, STR_TYPES + COMPLEX_TYPES, check_args=check_unicode_args))

    # float
    register('float.fromhex', Function(float.fromhex, 1, NATIVE_STR))