Commits

Victor Stinner  committed 89fab93

clean optimizer_iter()

  • Participants
  • Parent commits 6bf8c42

Comments (0)

Files changed (4)

File astoptimizer/ast_tools.py

 def _new_constant_list(node, elts):
     return [new_constant(node, elt) for elt in elts]
 
+def new_tuple_elts(node, elts=None):
+    if elts is None:
+        elts = []
+    new_node = ast.Tuple(elts=elts, ctx=ast.Load())
+    return copy_lineno(node, new_node)
+
+def new_tuple(node, iterable=()):
+    elts = _new_constant_list(node, iterable)
+    return new_tuple_elts(node, elts)
+
 def new_list_elts(node, elts=None):
     if elts is None:
         elts = []

File astoptimizer/compatibility.py

 import sys
 
 PYTHON2 = sys.version_info < (3,)
+PYTHON27 = sys.version_info >= (2, 7)
 PYTHON3 = sys.version_info >= (3,)
 
 try:
 else:
     INT_TYPES = (int, long)
 FLOAT_TYPES = INT_TYPES + (float,)
-COMPLEX_TYPES = INT_TYPES + (float, complex)
+COMPLEX_TYPES = FLOAT_TYPES + (complex,)
 
 IMMUTABLE_TYPES = COMPLEX_TYPES + STR_TYPES + (bool, NONE_TYPE)
 IMMUTABLE_ITERABLE_TYPES = STR_TYPES + (tuple, frozenset)

File astoptimizer/optimizer.py

 
 from astoptimizer import UNSET
 from astoptimizer.ast_tools import (
-    copy_lineno, sort_set_elts,
-    new_constant, new_literal, new_call, new_pass, new_list_elts,
+    copy_lineno,
+    new_constant, new_literal, new_call, new_pass,
+    sort_set_elts, new_tuple, new_tuple_elts, new_list, new_list_elts,
     ast_contains, check_func_args)
-if sys.version_info >= (2, 7):
-    from astoptimizer.ast_tools import new_set, new_set_elts
 from astoptimizer.config import optimize_unicode
 from astoptimizer.compatibility import (
     u,
-    PYTHON2, PYTHON3,
+    PYTHON2, PYTHON27, PYTHON3,
     is_bytes_ascii, is_singleton,
     INT_TYPES, FLOAT_TYPES, COMPLEX_TYPES, NONE_TYPE,
     BYTES_TYPE, UNICODE_TYPE, STR_TYPES,
     IMMUTABLE_ITERABLE_TYPES, ITERABLE_TYPES)
+if PYTHON27:
+    from astoptimizer.ast_tools import new_set, new_set_elts
 
 DROP_NODE = object()
 
             return False
         return self.check_func(node, name, min_narg, max_narg)
 
-    def get_constant(self, node, want_type=None):
+    def get_constant(self, node, to_type=None):
         if node is None:
             constant = None
         elif isinstance(node, ast.Num):
         elif isinstance(node, ast.Name):
             constant = self.load_name(node.id)
         elif isinstance(node, ast.Tuple):
-            if (want_type is not None
-            and not issubclass(tuple, want_type)):
+            if (to_type is not None
+            and not issubclass(tuple, to_type)):
                 return UNSET
             elts = node.elts
             if len(elts) > self.config.max_tuple_length:
                 constants.append(constant)
             return tuple(constants)
         elif self.check_builtin_func(node, 'frozenset', 0, 1):
-            if (want_type is not None
-            and not issubclass(frozenset, want_type)):
+            if (to_type is not None
+            and not issubclass(frozenset, to_type)):
                 return UNSET
             if len(node.args) == 1:
                 arg = self.get_literal(node.args[0], ITERABLE_TYPES)
         else:
             return UNSET
 
-        if (want_type is not None
-        and not isinstance(constant, want_type)):
+        if (to_type is not None
+        and not isinstance(constant, to_type)):
             return UNSET
         return constant
 
-    def get_literal(self, node, want_type=None):
+    def get_literal(self, node, to_type=None, check_length=True):
         if isinstance(node, ast.List):
-            if (want_type is not None
-            and not issubclass(list, want_type)):
+            if (to_type is not None
+            and not issubclass(list, to_type)):
+                return UNSET
+            if (check_length
+            and len(node.elts) > self.config.max_tuple_length):
                 return UNSET
             result = []
             for elt in node.elts:
                     return UNSET
                 result.append(literal)
             return result
-        elif sys.version_info >= (2, 7) and isinstance(node, ast.Set):
-            if (want_type is not None
-            and not issubclass(set, want_type)):
+        elif PYTHON27 and isinstance(node, ast.Set):
+            if (to_type is not None
+            and not issubclass(set, to_type)):
+                return UNSET
+            if (check_length
+            and len(node.elts) > self.config.max_tuple_length):
                 return UNSET
             result = set()
             for elt in node.elts:
                 result.add(literal)
             return result
         else:
-            return self.get_constant(node, want_type)
+            return self.get_constant(node, to_type)
 
-    def get_constant_list(self, nodes, want_type=None):
+    def get_constant_list(self, nodes, to_type=None):
         constants = []
         for node in nodes:
-            constant = self.get_constant(node, want_type)
+            constant = self.get_constant(node, to_type)
             if constant is UNSET:
                 return UNSET
             constants.append(constant)
     def compare_in(self, data):
         if isinstance(data, ast.List):
             # x in [1, 2] => x in (1, 2)
-            return self.list_to_tuple(data)
+            return self.node_to_type(data, tuple)
 
         if sys.version_info >= (3, 2):
             constant = self.get_constant(data, frozenset)
         if isinstance(node, (ast.List, ast.Tuple)):
             return len(node.elts) == 0
 
-        constant = self.get_literal(node)
+        constant = self.get_literal(node, ITERABLE_TYPES)
         if constant is not UNSET:
-            if isinstance(constant, ITERABLE_TYPES):
-                return len(constant) == 0
-            else:
-                return False
+            return len(constant) == 0
 
         if self.check_builtin_func(node, BUILTIN_ACCEPTING_ITERABLE, 0, 0):
             return True
         and not node.values):
             return True
 
-        if (sys.version_info >= (2, 7)
+        if (PYTHON27
         and isinstance(node, ast.Set)
         and not node.elts):
             return True
         qualname = self.namespace.get_qualname(node.func.id)
         if len(node.args) == 1:
             arg = node.args[0]
-            if qualname == 'list':
-                want_type = list
+            if qualname == 'tuple':
+                to_type = tuple
+            elif qualname == 'list':
+                to_type = list
             elif qualname == 'set':
-                want_type = set
+                to_type = set
+            elif qualname == 'frozenset':
+                to_type = tuple
             else:
-                want_type = None
-            new_arg = self.optimize_iter(arg, want_type=want_type)
+                to_type = None
+            new_arg = self.optimize_iter(arg, to_type)
         else:
             arg = UNSET
             new_arg = DROP_NODE
             if qualname == 'dict':
                 new_node = ast.Dict(keys=[], values=[])
                 return copy_lineno(node, new_node)
-            if (sys.version_info >= (2, 7)
+            if (PYTHON27
             and qualname == 'set'):
                 return new_set(node)
             if arg is not UNSET:
                 # list([1, 2, 3]) => [1, 2, 3]
                 return arg
         elif qualname in ('frozenset', 'set'):
-            constant = self.get_literal(arg, ITERABLE_TYPES)
+            if (qualname == 'set'
+            and PYTHON27
+            and isinstance(node, ast.Set)):
+                # set({1, 2, 3}) => {1, 2, 3}
+                return arg
+
+            constant = self.get_literal(arg, ITERABLE_TYPES, check_length=False)
             if constant is not UNSET:
                 elts = frozenset(constant)
-                use_literal_set = (qualname == 'set' and sys.version_info >= (2, 7))
+                use_literal_set = (qualname == 'set' and PYTHON27)
                 if (len(elts) != len(constant)) or use_literal_set:
                     elts = sort_set_elts(elts)
                     if len(elts) <= self.config.max_tuple_length:
             self.namespace.assign(name, value)
         return True
 
-    def _get_assign_name(self, main_node, node):
+    def _get_assign_name(self, node):
         if isinstance(node, ast.Name):
             # var = value
             return (node.id, True)
         elif isinstance(node, ast.Attribute):
             # var.attr = value, var1.attr1.attr2 = value
-            result = self._get_assign_name(main_node, node.value)
+            result = self._get_assign_name(node.value)
             if result is None:
                 return None
             name, supported = result
             return (name, False)
         elif isinstance(node, ast.Subscript):
             # var[index] = value, var[a:b] = value
-            result = self._get_assign_name(main_node, node.value)
+            result = self._get_assign_name(node.value)
             if result is None:
                 return None
             name, supported = result
         self.namespace.assign(name, value)
 
     def assign(self, node, target, value, assign_supported):
-        result = self._get_assign_name(node, target)
+        result = self._get_assign_name(target)
         if result is not None:
             # x = value
             name, supported = result
         names = []
         supported = assign_supported
         for elt in target.elts:
-            result = self._get_assign_name(node, elt)
+            result = self._get_assign_name(elt)
             if result is None:
                 self.disable_vars(node)
                 return
         self.namespace.disable_vars()
 
     def unassign(self, node):
-        if isinstance(node, ast.Name):
-            self.namespace.unassign(node.id)
-        elif isinstance(node, ast.Tuple):
+        if isinstance(node, ast.Tuple):
             for elt in node.elts:
                 self.unassign(elt)
-        else:
+            return
+
+        result = self._get_assign_name(node)
+        if result is None:
             self.disable_vars(node)
+            return
+        name, supported = result
+        if not supported:
+            self.disable_vars(node)
+            return
+        self.namespace.unassign(name)
 
-    def optimize_range(self, node, want_list=False):
+    def optimize_range(self, node, to_type):
         args = self.get_constant_list(node.args)
         if args is UNSET:
             return
             # OverflowError: Python int too large to convert to C ssize_t
             pass
         else:
-            if range_len <= self.config.max_tuple_length:
+            if (self.config.remove_dead_code
+            and range_len <= self.config.max_tuple_length):
                 # range(3) => (0, 1, 2)
-                if want_list:
+                if to_type == list:
                     constant = list(numbers)
+                elif to_type == set:
+                    constant = set(numbers)
                 else:
                     constant = tuple(numbers)
                 return new_literal(node, constant)
         if not self.can_use_builtin('xrange'):
             return
 
-        # range(int, [int[, int]]) => xrange(...)
+        # range(...) => xrange(...)
         node.func.id = 'xrange'
         return node
 
-    def optimize_iter(self, node, is_generator=False, want_type=None):
+    def node_to_set(self, node):
+        if not isinstance(node, (ast.Tuple, ast.List)):
+            return
+
+        # [1, 2, 3] => {1, 2, 3}
+        literal = self.get_literal(node)
+        if literal is UNSET:
+            return
+        literal = sort_set_elts(literal)
+        if PYTHON27:
+            return new_set(node, literal)
+        else:
+            return new_tuple(node, literal)
+
+    def node_to_type(self, node, to_type):
+        if PYTHON27:
+            ast_types = (ast.Tuple, ast.List, ast.Set)
+        else:
+            ast_types = (ast.Tuple, ast.List)
+        if not isinstance(node, ast_types):
+            return
+        if len(node.elts) > self.config.max_tuple_length:
+            return
+
+        if isinstance(node, ast.Tuple):
+            if to_type == tuple:
+                # (1, 2, 3)
+                return node
+            if to_type == list:
+                # [1, 2, 3] => (1, 2, 3)
+                return new_list_elts(node, node.elts)
+            if to_type == set:
+                return self.node_to_set(node)
+        elif isinstance(node, ast.List):
+            if to_type == list:
+                return node
+            if to_type == tuple:
+                # [1, 2, 3] => (1, 2, 3)
+                return new_tuple_elts(node, node.elts)
+            if to_type == set:
+                return self.node_to_set(node)
+        elif isinstance(node, ast.Set):
+            if to_type == set:
+                return node
+            if to_type in (tuple, list):
+                literal = self.get_literal(node)
+                if literal is UNSET:
+                    return
+                literal = sort_set_elts(literal)
+                if to_type == tuple:
+                    # {3, 1, 2} => (1, 2, 3)
+                    return new_tuple(node, literal)
+                else:
+                    # {3, 1, 2} => [1, 2, 3]
+                    return new_list(node, literal)
+
+    def literal_to_type(self, node, literal, to_type):
+        if not isinstance(literal, ITERABLE_TYPES):
+            return UNSET
+
+        if to_type == set:
+            # "abc" => {"a", "b", "c"}
+            literal = set(literal)
+            if len(literal) > self.config.max_tuple_length:
+                return UNSET
+            if PYTHON27:
+                return new_literal(node, literal)
+        else:
+            if len(literal) > self.config.max_tuple_length:
+                return UNSET
+
+        if isinstance(literal, (frozenset, set)):
+            literal = sort_set_elts(literal)
+        if to_type == list:
+            # "abc" => ["a", "b", "c"]
+            literal = list(literal)
+        elif to_type in (tuple, set):
+            # "abc" => ("a", "b", "c")
+            literal = tuple(literal)
+        else:
+            return UNSET
+
+        return new_literal(node, literal)
+
+    def _optimize_iter(self, node, is_generator, to_type):
         if (self.config.remove_dead_code
         and self.is_empty_iterable(node)):
             # set("") => set()
             return DROP_NODE
 
-        if isinstance(node, ast.List):
-            if want_type != list:
-                # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
-                return self.list_to_tuple(node)
-            else:
+        if PYTHON27:
+            ast_types = (ast.Tuple, ast.List, ast.Set)
+        else:
+            ast_types = (ast.Tuple, ast.List)
+        if isinstance(node, ast_types):
+            return self.node_to_type(node, to_type)
+
+        literal = self.get_literal(node, ITERABLE_TYPES)
+        if literal is not UNSET:
+            new_literal = self.literal_to_type(node, literal, to_type)
+            if new_literal is UNSET:
                 return
+            return new_literal
 
-        if PYTHON2:
-            is_range = (self.check_builtin_func(node, 'range', 1, 3)
-                        or self.check_builtin_func(node, 'xrange', 1, 3))
-        else:
-            is_range = self.check_builtin_func(node, 'range', 1, 3)
-        if is_range:
-            return self.optimize_range(node, want_type == list)
+        if isinstance(node, ast.GeneratorExp):
+            # (x for x in "abc") => "abc"
+            new_iter = self.optimize_comprehension(node,
+                                                   is_generator=False,
+                                                   to_type=to_type)
+            if new_iter is not None:
+                return new_iter
+
+        if (not is_generator
+        and self.config.remove_dead_code
+        and self.check_builtin_func(node, ('list', 'tuple'),  1, 1)):
+            # list(iterable) => iterable
+            return node.args[0]
+
+        if (to_type == set
+        and self.config.remove_dead_code
+        and self.check_builtin_func(node, ('frozenset', 'set'),  1, 1)):
+            # set(iterable) => iterable
+            return node.args[0]
 
         if (self.config.remove_dead_code
         and self.check_builtin_func(node, 'iter', 1, 1)):
             # set(iter(iterable)) => set(iterable)
             return iter_arg
 
-        if (not is_generator
-        and self.config.remove_dead_code
-        and self.check_builtin_func(node, ('list', 'tuple'),  1, 1)):
-            # list(iterable) => iterable
-            return node.args[0]
+        if PYTHON2:
+            range_names = ('range', 'xrange')
+        else:
+            range_names = ('range',)
+        if self.check_builtin_func(node, range_names, 1, 3):
+            return self.optimize_range(node, to_type)
 
-        if (want_type == set
-        and self.config.remove_dead_code
-        and self.check_builtin_func(node, ('frozenset', 'set'),  1, 1)):
-            # set(iterable) => iterable
-            return node.args[0]
+    def optimize_generator(self, node):
+        return self._optimize_iter(node, True, tuple)
 
-        if isinstance(node, ast.GeneratorExp):
-            # (x for x in "abc") => "abc"
-            if is_generator:
-                node_is_generator = True
-                node_want_type = None
-            else:
-                node_is_generator = (want_type != list)
-                node_want_type = list
-            new_iter = self.optimize_comprehension(node,
-                                                   is_generator=node_is_generator,
-                                                   want_type=node_want_type)
-            if new_iter is not None:
-                return new_iter
-
-        if want_type == list and isinstance(node, ast.Tuple):
-            # [1, 2, 3] => (1, 2, 3)
-            return new_list_elts(node, node.elts)
-
-        constant = self.get_literal(node)
-        if (constant is not UNSET
-        and isinstance(constant, ITERABLE_TYPES)):
-            if len(constant) <= self.config.max_tuple_length:
-                if want_type == set and sys.version_info >= (2, 7):
-                    # "abc" => {"a", "b", "c"}
-                    constant = set(constant)
-                else:
-                    if isinstance(constant, (frozenset, set)):
-                        constant = sort_set_elts(constant)
-                    if want_type == list:
-                        # "abc" => ["a", "b", "c"]
-                        constant = list(constant)
-                    else:
-                        # want_type == tuple or is_generator
-                        # "abc" => ("a", "b", "c")
-                        constant = tuple(constant)
-                return new_literal(node, constant)
-
-    def list_to_tuple(self, node):
-        if len(node.elts) > self.config.max_tuple_length:
-            return node
-        new_node = ast.Tuple(elts=node.elts, ctx=node.ctx)
-        return copy_lineno(node, new_node)
+    def optimize_iter(self, node, to_type):
+        return self._optimize_iter(node, False, to_type)
 
     def fullvisit_For(self, node):
         node.iter = self.visit(node.iter)
 
-        new_iter = self.optimize_iter(node.iter, True)
+        new_iter = self.optimize_generator(node.iter)
         if new_iter is DROP_NODE:
             return new_pass(node)
         elif new_iter is not None:
             new_ifs = [false_cst]
         node.ifs = new_ifs
 
-    def optimize_comprehension(self, node, is_generator=False, want_type=None):
+    def optimize_comprehension(self, node, is_generator=False, to_type=None):
         # FIXME: support more than 1 generator
         if len(node.generators) != 1:
             return
             if (not isinstance(node.elt, ast.Name)
             or node.elt.id != name):
                 if is_generator:
-                    new_iter = self.optimize_iter(generator.iter, is_generator=True, want_type=tuple)
+                    new_iter = self.optimize_generator(generator.iter)
                 else:
-                    new_iter = self.optimize_iter(generator.iter, want_type=tuple)
+                    new_iter = self.optimize_iter(generator.iter, tuple)
                 if new_iter is DROP_NODE:
                     # y for x in ()
                     return DROP_NODE
             iter_expr = generator.iter
 
         if is_generator:
-            new_iter = self.optimize_iter(iter_expr, is_generator=True, want_type=tuple)
+            new_iter = self.optimize_generator(iter_expr)
         else:
-            new_iter = self.optimize_iter(iter_expr, want_type=want_type)
+            new_iter = self.optimize_iter(iter_expr, to_type)
         if new_iter is not None:
             return new_iter
         else:
         node.generators = self.visit_list(node.generators)
         node.elt = self.visit(node.elt)
 
-        iter_expr = self.optimize_comprehension(node, want_type=list)
+        iter_expr = self.optimize_comprehension(node, to_type=list)
         if iter_expr is None:
             return
         if iter_expr is DROP_NODE:
         node.generators = self.visit_list(node.generators)
         node.elt = self.visit(node.elt)
 
-        iter_expr = self.optimize_comprehension(node, want_type=set)
+        iter_expr = self.optimize_comprehension(node, to_type=set)
         if iter_expr is None:
             return
         if iter_expr is DROP_NODE:

File astoptimizer/tests.py

 
 from astoptimizer import Config, parse_ast, compile_ast
 from astoptimizer.compatibility import (
-    PYTHON2, PYTHON3,
+    PYTHON2, PYTHON27, PYTHON3,
     u, b,
     BYTES_TYPE, UNICODE_TYPE)
 from astoptimizer.optimizer import Namespace, Optimizer
             % (func, ', '.join(args)))
 
     def _text_set(self, *elts):
-        if sys.version_info >= (2, 7):
+        if PYTHON27:
             elts = (self._text_item(elt) for elt in elts)
             elts = '[%s]' % ', '.join(elts)
             return 'Set(elts=%s)' % elts
         else:
             self.check_not_optimized("x in frozenset((1, 2, 3))", config)
 
-        if sys.version_info >= (2, 7):
+        if PYTHON27:
             self.check("x in set((1, 2, 3))",
                        self.text_ast("x in {1, 2, 3}"),
                        config)
                    self.text_ast('frozenset((1, 2, 3))'))
         self.check('list([1, 2, 3])',
                    self.text_ast('[1, 2, 3]'))
-        if sys.version_info >= (2, 7):
+        if PYTHON27:
             self.check('set([1, 2, 3])',
                        self.text_ast('{1, 2, 3}'))
         else:
         self.check_not_optimized('(2,0,2,2).index(9)')
 
     def test_int_methods(self):
-        if sys.version_info >= (2, 7):
+        if PYTHON27:
             self.check('(12345).bit_length()', self.text_num(14))
 
     def test_float_methods(self):
         self.check('sys.hexversion', self.text_num(sys.hexversion), config)
         self.check('sys.version[:3]', self.text_native_str(sys.version[:3]), config)
 
-        if sys.version_info >= (2, 7):
+        if PYTHON27:
             self.check_not_optimized('sys.version_info', config)
             self.check('sys.version_info.minor', self.text_num(sys.version_info.minor), config)
         else: