Commits

Victor Stinner committed dd3a4a9

Fix optimizations on GeneratorExp

Also improve optimizations of list comprehensions, generator expressions, and
calls to builtins accepting an iterable (tuple, list, set, frozenset, dict).

Comments (0)

Files changed (5)

 
  - "i=0; while i < 10: print(i); i = i + 1": don't replace print(i) with print('0')
  - "for x in (): try: pass finally: continue" must raise a SyntaxError
- - "gen=(i for i in range(5)); next(gen); gen.send(1)"
  - "type(iter([]))"
 
 major optimizations:

astoptimizer/ast_tools.py

         raise NotImplementedError("unable to create an AST object for constant: %r" % (value,))
     return copy_lineno(node, new_node)
 
+def new_literal(node, value):
+    if isinstance(value, list):
+        elts = [new_constant(node, elt) for elt in value]
+        new_node = ast.List(elts=elts, ctx=ast.Load())
+        return copy_lineno(node, new_node)
+    else:
+        return new_constant(node, value)
+
 if sys.version_info >= (2, 7):
     def new_set(node, iterable=()):
         elts = [new_constant(node, elt) for elt in iterable]

astoptimizer/compatibility.py

 
 IMMUTABLE_TYPES = COMPLEX_TYPES + STR_TYPES + (bool, NONE_TYPE)
 IMMUTABLE_ITERABLE_TYPES = STR_TYPES + (tuple, frozenset)
+ITERABLE_TYPES = IMMUTABLE_ITERABLE_TYPES + (set, list)
 
 def is_singleton(config, value):
     return id(value) in _SINGLETONS

astoptimizer/optimizer.py

 from astoptimizer import UNSET
 from astoptimizer.ast_tools import (
     copy_lineno,
-    new_constant, new_call, new_pass,
+    new_constant, new_literal, new_call, new_pass,
     ast_contains, check_func_args)
 if sys.version_info >= (2, 7):
     from astoptimizer.ast_tools import new_set
     is_bytes_ascii, is_singleton,
     INT_TYPES, FLOAT_TYPES, COMPLEX_TYPES,
     BYTES_TYPE, UNICODE_TYPE, STR_TYPES,
-    IMMUTABLE_ITERABLE_TYPES)
+    ITERABLE_TYPES)
 
 DROP_NODE = object()
 
             return tuple(constants)
         if self.check_builtin_func(node, 'frozenset', 0, 1):
             if len(node.args) == 1:
-                arg = self.get_constant(node.args[0])
+                arg = self.get_literal(node.args[0])
                 if arg is UNSET:
                     return UNSET
-                if not isinstance(arg, IMMUTABLE_ITERABLE_TYPES):
+                if not isinstance(arg, ITERABLE_TYPES):
                     return UNSET
                 if len(arg) > self.config.max_tuple_length:
                     return UNSET
                 return frozenset()
         return UNSET
 
+    def get_literal(self, node):
+        if isinstance(node, ast.List):
+            return [self.get_literal(elt) for elt in node.elts]
+        return self.get_constant(node)
+
     def get_constant_list(self, nodes):
         constants = []
         for node in nodes:
         if isinstance(node, (ast.List, ast.Tuple)):
             return len(node.elts) == 0
 
-        constant = self.get_constant(node)
+        constant = self.get_literal(node)
         if constant is not UNSET:
-            if isinstance(constant, IMMUTABLE_ITERABLE_TYPES):
+            if isinstance(constant, ITERABLE_TYPES):
                 return len(constant) == 0
             else:
                 return False
         qualname = self.namespace.get_qualname(node.func.id)
         if len(node.args) == 1:
             arg = node.args[0]
-            new_arg = self.optimize_iter(arg)
+            new_arg = self.optimize_iter(arg, want_list=(qualname == 'list'))
         else:
             arg = UNSET
             new_arg = DROP_NODE
                 del node.args[0]
             return
 
+        if new_arg is not None:
+            arg = new_arg
+
         if qualname in ('frozenset', 'set', 'tuple'):
-            if new_arg is not None:
-                constant = self.get_constant(new_arg)
-            else:
-                constant = self.get_constant(arg)
+            constant = self.get_literal(arg)
             if (constant is not UNSET
-            and isinstance(constant, IMMUTABLE_ITERABLE_TYPES)):
+            and isinstance(constant, ITERABLE_TYPES)):
                 if qualname != 'tuple':
                     elts = frozenset(constant)
                     try:
                         and sys.version_info >= (2, 7)):
                             return new_set(node, elts)
                         else:
-                            new_arg = new_constant(arg, tuple(elts))
+                            arg = new_constant(arg, tuple(elts))
                 else:
                     # qualname == 'tuple'
                     if len(constant) <= self.config.max_tuple_length:
                         return new_constant(node, tuple(constant))
         elif qualname == 'list':
-            if new_arg is not None:
-                arg = new_arg
-            if isinstance(arg, ast.Tuple):
-                new_node = ast.List(elts=arg.elts, ctx=ast.Load())
-                return copy_lineno(node, new_node)
+            if isinstance(arg, ast.List):
+                return arg
 
-        if new_arg is None:
-            return
-        node.args[0] = new_arg
+        node.args[0] = arg
 
     def call_name(self, node):
         name = self.namespace.get_qualname(node.func.id)
         else:
             self.disable_vars(node)
 
-    def optimize_range(self, node):
+    def optimize_range(self, node, want_list=False):
         args = self.get_constant_list(node.args)
         if args is UNSET:
             return
                 return
 
         if PYTHON3:
-            range_tuple = range(*args)
+            numbers = range(*args)
         else:
-            range_tuple = xrange(*args)
+            numbers = xrange(*args)
         try:
-            range_len = len(range_tuple)
+            range_len = len(numbers)
         except OverflowError:
             # OverflowError: Python int too large to convert to C ssize_t
             pass
         else:
             if range_len <= self.config.max_tuple_length:
                 # range(3) => (0, 1, 2)
-                return new_constant(node, tuple(range_tuple))
+                if want_list:
+                    constant = list(numbers)
+                else:
+                    constant = tuple(numbers)
+                return new_literal(node, constant)
 
         if PYTHON3:
             return
         node.func.id = 'xrange'
         return node
 
-    def optimize_iter(self, node):
+    def optimize_iter(self, node, is_generator=False, want_list=False):
         if (self.config.remove_dead_code
         and self.is_empty_iterable(node)):
             # set("") => set()
             return DROP_NODE
 
         if isinstance(node, ast.List):
-            # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
-            return self.list_to_tuple(node)
+            if not want_list:
+                # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
+                return self.list_to_tuple(node)
+            else:
+                return
 
         if PYTHON2:
             is_range = (self.check_builtin_func(node, 'range', 1, 3)
         else:
             is_range = self.check_builtin_func(node, 'range', 1, 3)
         if is_range:
-            return self.optimize_range(node)
+            return self.optimize_range(node, want_list)
 
         if (self.config.remove_dead_code
         and self.check_builtin_func(node, 'iter', 1, 1)):
             # optimized by optimize_iter()
             if (isinstance(iter_arg, ast.Tuple)
             and len(iter_arg.elts) == 0):
-                # set(iter([])) => set()
+                # set(iter(())) => set()
                 return DROP_NODE
             # set(iter(iterable)) => set(iterable)
             return iter_arg
 
-        if (self.config.remove_dead_code
+        if (not is_generator
+        and self.config.remove_dead_code
         and self.check_builtin_func(node, ('list', 'tuple'),  1, 1)):
             # set(list(iterable)) => set(iterable)
             return node.args[0]
 
+        if isinstance(node, ast.GeneratorExp):
+            # (x for x in "abc") => "abc"
+            if is_generator:
+                node_is_generator = True
+            else:
+                node_is_generator = not want_list
+            new_iter = self.optimize_comprehension(node, node_is_generator)
+            if new_iter is not None:
+                return new_iter
+
+        if want_list and isinstance(node, ast.Tuple):
+            new_node = ast.List(elts=node.elts, ctx=ast.Load())
+            return copy_lineno(node, new_node)
+        constant = self.get_literal(node)
+        if (constant is not UNSET
+        and isinstance(constant, ITERABLE_TYPES)):
+            if len(constant) <= self.config.max_tuple_length:
+                if want_list:
+                    constant = list(constant)
+                else:
+                    constant = tuple(constant)
+                return new_literal(node, constant)
+
     def list_to_tuple(self, node):
         if len(node.elts) > self.config.max_tuple_length:
             return node
     def fullvisit_For(self, node):
         node.iter = self.visit(node.iter)
 
-        new_iter = self.optimize_iter(node.iter)
+        new_iter = self.optimize_iter(node.iter, True)
         if new_iter is DROP_NODE:
             return new_pass(node)
         elif new_iter is not None:
             new_ifs = [false_cst]
         node.ifs = new_ifs
 
-    def optimize_comprehension(self, node):
+    def optimize_comprehension(self, node, is_generator):
         # FIXME: support more than 1 generator
         if len(node.generators) != 1:
             return
             name = generator.target.id
             if (not isinstance(node.elt, ast.Name)
             or node.elt.id != name):
-                # x*2 for x in data
-                # y for x in data
-                new_iter = self.optimize_iter(generator.iter)
+                if is_generator:
+                    new_iter = self.optimize_iter(generator.iter, True)
+                else:
+                    new_iter = self.optimize_iter(generator.iter)
                 if new_iter is DROP_NODE:
+                    # y for x in ()
                     return DROP_NODE
                 if new_iter is not None:
+                    # y for x in [1, 2] => y for x in (1, 2)
                     generator.iter = new_iter
                 return
             iter_expr = generator.iter
 
-        new_iter = self.optimize_iter(iter_expr)
+        if is_generator:
+            new_iter = self.optimize_iter(iter_expr, True)
+        else:
+            new_iter = self.optimize_iter(iter_expr, want_list=True)
         if new_iter is not None:
             return new_iter
         else:
 
     def visit_GeneratorExp(self, node):
         # use iter() builtin function
-        iter_expr = self.optimize_comprehension(node)
+        iter_expr = self.optimize_comprehension(node, True)
         if iter_expr is None:
             return
 
-        if self.can_use_builtin('iter'):
-            if iter_expr is DROP_NODE:
-                # (x*2 for x in "abc" if False) => iter(())
-                iter_expr = new_constant(node, ())
-            # (x for x in "abc") => iter("abc")
-            return new_call(node, 'iter', iter_expr)
-        else:
-            if iter_expr is not DROP_NODE:
-                return
+        generator = node.generators[0]
+        if iter_expr is DROP_NODE:
             # (x*2 for x in "abc" if False) => (None for x in ())
             # (x*2 for x in []) => (None for x in ())
-            generator = node.generators[0]
             node.elt = new_constant(node, None)
             empty_tuple = new_constant(node, ())
             generator.iter = empty_tuple
             del generator.ifs[:]
+        else:
+            generator.iter = iter_expr
 
     def fullvisit_ListComp(self, node):
         for generator in node.generators:
         node.elt = self.visit(node.elt)
 
         # use list() builtin function
-        iter_expr = self.optimize_comprehension(node)
+        iter_expr = self.optimize_comprehension(node, False)
         if iter_expr is None:
             return
         if iter_expr is not DROP_NODE:
-            if not self.can_use_builtin('list'):
-                return
-            return new_call(node, 'list', iter_expr)
+            if isinstance(iter_expr, ast.List):
+                # [x for x in "abc"] => ["a", "b", "c"]
+                return iter_expr
+            elif self.can_use_builtin('list'):
+                # [x for x in range(1000)] => list(xrange(1000))
+                list_iter = new_call(node, 'list', iter_expr)
+                return list_iter
         else:
+            # [x*2 for x in "abc" if False] => []
+            # [x*2 for x in []] => []
             new_node = ast.List(elts=[], ctx=ast.Load())
             return copy_lineno(node, new_node)
 

astoptimizer/tests.py

 
     def setUp(self):
         del self.warnings[:]
+        self._default_config = self.create_default_config()
 
     def tearDown(self):
         if self.warnings:
 
     def _optimize_ast(self, tree, config=None, catch_syntaxerror=False):
         if config is None:
-            config = self.create_default_config()
+            config = self._default_config
         optimizer = Optimizer(config)
         tree = optimizer.optimize(tree)
         # Ensure that the tree is compilable to bytecode
         return tree, optimizer
 
     def _check_warnings(self, message):
+        emitted = self.warnings[:]
+        del self.warnings[:]
         if message:
             if isinstance(message, (list, tuple)):
                 expected = message
             else:
                 expected = [message]
-            self.assertEqual(len(self.warnings), len(expected), self.warnings)
-            emitted = self.warnings[:]
-            del self.warnings[:]
+            self.assertEqual(len(emitted), len(expected), emitted)
             for emitted, message in zip(emitted, expected):
                 if isinstance(message, str):
                     self.assertEqual(emitted, message)
                 else:
                     # regex
                     self.assertTrue(message.match(emitted), emitted)
-        elif self.warnings:
-            raise Exception("WARNINGS: %s" % self.warnings)
+        elif emitted:
+            raise Exception("WARNINGS: %s" % emitted)
 
     def check(self, code, expected, config=None, warning=None):
         tree = parse_ast(code)
 
         # optimized
         self.check('[x for x in range(3)]',
-                   self.text_ast('list((0, 1, 2))'),
+                   self.text_ast('[0, 1, 2]'),
                    config)
         self.check('[x*2 for x in range(0, 2 ** 20, 2 ** 18)]',
                    self.text_ast('[x*2 for x in (0, %s, %s, %s)]' % (2 ** 18, 2 * 2 ** 18, 3 * 2 ** 18)),
         self.check('for x in range(3): pass',
                    self.text_ast('for x in (0, 1, 2): pass'),
                    config)
+        self.check('for x in list("abc"): pass',
+                   self.text_ast('for x in ("a", "b", "c"): pass'),
+                   config)
+        self.check_not_optimized('for x in tuple(iterable): pass', config)
+        self.check_not_optimized('for x in list(iterable): pass', config)
         if PYTHON2:
             self.check('for x in range(1000): pass',
                        self.text_ast('for x in xrange(1000): pass'),
         self.check('(x*2 for x in (1, 2, 3) if 0)',
                    self.text_ast('(None for x in ())'))
         self.check('(x*2 for x in (1, 2, 3) if 0)',
-                   self.text_ast('iter(())'),
-                   config)
+                   self.text_ast('(None for x in ())'))
         self.check('(x*2 for x in (1, 2, 3) if True if 0 if 1)',
-                   self.text_ast('iter(())'),
-                   config)
+                   self.text_ast('(None for x in ())'))
         self.check_not_optimized('(x*2 for x in (1, 2, 3) if x % 2)')
 
-        # replace with iter()
+        # generators
         self.check('(x * 2 for x in "")',
-                   self.text_ast('iter(())'),
-                   config)
+                   self.text_ast('(None for x in ())'))
         self.check('(x for x in "abc")',
-                   self.text_ast('iter("abc")'),
-                   config)
-        self.check('(x for x in "abc" if True)',
-                   self.text_ast('iter("abc")'),
-                   config)
-        self.check('(x*2 for x in (1, 2, 3) if 0)',
-                   self.text_ast('iter(())'),
-                   config)
-        self.check('(x for x in iterable)',
-                   self.text_ast('iter(iterable)'),
-                   config)
+                   self.text_ast('(x for x in ("a", "b", "c"))'))
+        self.check_not_optimized('(x for x in iterable)')
 
         # tuple(generator)
         self.check('tuple(x for x in "abc")',
                    self.text_ast('tuple(iterable)'),
                    config)
 
+        # list(generator)
+        self.check('list(x for x in range(3))',
+                   self.text_ast('[0, 1, 2]'),
+                   config)
+
     def test_ListComp(self):
         config = self.create_config('builtin_funcs')
 
         # list comprehension
+        self.check('[x for x in "abc"]',
+                   self.text_ast('["a", "b", "c"]'))
+        self.check('[x for x in (0, 1, 2)]',
+                   self.text_ast('[0, 1, 2]'))
         self.check('[x*2 for x in ""]',
                    self.text_ast('[]'))
         if PYTHON2:
 
         # replace with list()
         self.check('[x * 2 for x in ""]',
-                   self.text_ast('[]'),
-                   config)
+                   self.text_ast('[]'))
         self.check('[x for x in "abc"]',
-                   self.text_ast('list("abc")'),
-                   config)
+                   self.text_ast('["a", "b", "c"]'))
         self.check('[x for x in "abc" if True]',
-                   self.text_ast('list("abc")'),
-                   config)
+                   self.text_ast('["a", "b", "c"]'))
         self.check('[x for x in "abc" if False]',
-                   self.text_ast('[]'),
-                   config)
+                   self.text_ast('[]'))
         self.check('[x for x in iterable]',
                    self.text_ast('list(iterable)'),
                    config)