Commits

Victor Stinner committed 5ed0baa

Optimize list comprehension

If code removal is disabled, replace:

* "if True" with "if 1"
* "while False" with "while 0"

  • Participants
  • Parent commits ebb0cfd

Comments (0)

Files changed (4)

    - ``not(x in y)`` => ``x not in y``
    - ``4 and 5 and x and 6`` => ``x and 6``
 
- * Optimize iterators and generators. Examples:
+ * Optimize loops. Examples:
 
    - ``while True: ...`` => ``while 1: ...``
    - ``for x in range(3): ...`` => ``for x in xrange(3): ...`` (Python 2)
-   - ``[x for x in range(3)]`` => ``[x for x in xrange(3)]``
-   - ``tuple(x for x in "abc")`` => ``tuple(iter("abc"))``
-   - ``(x for x in "abc" if False)`` => ``iter(())``
+
+ * Optimize list comprehension and generators. Examples:
+
    - ``iter(set())`` => ``iter(())``
    - ``frozenset("")`` => ``frozenset()``
+   - ``(x for x in "abc" if False)`` => ``iter(())``
+   - ``[x for x in ""]`` => ``[]``
+   - ``[x for x in iterable]`` => ``list(iterable)``
+   - ``set([x for x in "abc"])`` => ``set("abc")``
+   - ``tuple(x for x in "abc")`` => ``tuple("abc")``
 
  * Replace list with tuple. Examples:
 
 Changes:
 
  * Optimize print() on Python 2 with "from __future__ import print_function"
- * Optimize iterators and generators
+ * Optimize iterators, list comprehension and generators
  * Replace list with tuple
 
 Version 0.3.1 (2012-09-12)

File astoptimizer/optimizer.py

         else:
             return node_list
 
+    def constant_to_test(self, node, constant):
+        # replace True with 1: avoid a lookup
+        # replace "long string" with 1: smaller constant
+        # replace [] with 0: smaller constant
+        return new_constant(node, int(bool(constant)))
+
     def fullvisit_If(self, node):
         if (not hasattr(node, 'test')
         and 'cpython_tests' in self.config.features):
         node.test = self.visit(node.test)
 
         drop = False
+        constant = self.get_constant(node.test)
         if self.config.remove_dead_code:
-            constant = self.get_constant(node.test)
             if constant is not UNSET:
                 if constant:
                     check_node = node.orelse
                 else:
                     check_node = node.body
                 drop = self.can_drop(check_node)
+        else:
+            if constant is not UNSET:
+                node.test = self.constant_to_test(node.test, constant)
 
         if drop:
             if constant:
         node.orelse = self.visit_list(node.orelse, conditional=True)
         node.test = self.visit(node.test)
 
-        if not self.config.remove_dead_code:
-            return
-
         constant = self.get_constant(node.test)
         if constant is UNSET:
             return
+        if not self.config.remove_dead_code:
+            node.test = self.constant_to_test(node.test, constant)
+            return
         if constant:
             always_true = new_constant(node.test, 1)
             node.test = always_true
 
         if self.check_builtin_func(node, BUILTIN_ACCEPTING_ITERABLE,  1, 1):
             arg = node.args[0]
-            new_arg = self.optimize_iter(arg, True)
+            new_arg = self.optimize_iter(arg)
             if new_arg is DROP_NODE:
+                qualname = self.namespace.get_qualname(node.func.id)
+                if qualname == 'list':
+                    new_node = ast.List(elts=[], ctx=ast.Load())
+                    return copy_lineno(node, new_node)
+                if qualname == 'dict':
+                    new_node = ast.Dict(keys=[], values=[])
+                    return copy_lineno(node, new_node)
                 del node.args[0]
             elif new_arg is not None:
                 node.args[0] = new_arg
 
-        elif (self.check_func(node, 'iter', 1, 1)
+        elif (self.check_builtin_func(node, 'iter', 1, 1)
         and self.is_empty_iterable(node.args[0])):
             # iter(set()) => iter(())
             node.args[0] = new_constant(node.args[0], ())
         else:
             self.disable_vars(node)
 
-    def optimize_iter(self, node, may_drop):
-        if not self.config.remove_dead_code:
-            may_drop = False
-
-        if may_drop and self.is_empty_iterable(node):
+    def optimize_iter(self, node):
+        if (self.config.remove_dead_code
+        and self.is_empty_iterable(node)):
             # set("") => set()
             return DROP_NODE
 
             node.func.id = 'xrange'
             return node
 
-        if may_drop and self.check_func(node, 'iter', 1, 1):
+        if (self.config.remove_dead_code
+        and self.check_builtin_func(node, 'iter', 1, 1)):
             iter_arg = node.args[0]
+            # no need to call is_empty_iterable(), iter(iterable) was already
+            # optimized by optimize_iter()
             if (isinstance(iter_arg, ast.Tuple)
             and len(iter_arg.elts) == 0):
                 # set(iter([])) => set()
                 return DROP_NODE
+            # set(iter(iterable)) => set(iterable)
+            return iter_arg
+
+        if (self.config.remove_dead_code
+        and self.check_builtin_func(node, 'list', 1, 1)):
+            # set(list(arg)) => set(arg)
+            iterable = node.args[0]
+            if self.is_empty_iterable(iterable):
+                # set("") => set()
+                return DROP_NODE
+            return iterable
 
     def list_to_tuple(self, node):
         if len(node.elts) > self.config.max_tuple_length:
     def fullvisit_For(self, node):
         node.iter = self.visit(node.iter)
 
-        new_iter = self.optimize_iter(node.iter, True)
+        new_iter = self.optimize_iter(node.iter)
         if new_iter is DROP_NODE:
             return new_pass(node)
         elif new_iter is not None:
             return optimizer._optimize(node, class_namespace)
 
     def fullvisit_comprehension(self, node):
-        self.unassign(node.target)
+        # don't visit node.target
         node.iter = self.visit(node.iter)
-        new_iter = self.optimize_iter(node.iter, False)
-        if new_iter is not None:
-            node.iter = new_iter
 
         new_ifs = []
         empty_gen = False
             constant = self.get_constant(ifexp)
             if constant is UNSET:
                 new_ifs.append(ifexp)
-            elif not constant:
-                empty_gen = True
-                break
+            else:
+                if self.config.remove_dead_code:
+                    if not constant:
+                        empty_gen = True
+                        break
+                else:
+                    cst = self.constant_to_test(ifexp, constant)
+                    new_ifs.append(cst)
         if empty_gen:
-            false_cst = new_constant(ifexp, False)
+            false_cst = new_constant(ifexp, 0)
             new_ifs = [false_cst]
         node.ifs = new_ifs
 
-    def visit_GeneratorExp(self, node):
-        # use iter() builtin function
-        if 'builtin_funcs' not in self.config.features:
-            return
+    def optimize_comprehension(self, node):
         # FIXME: support more than 1 generator
         if len(node.generators) != 1:
             return
         if not isinstance(generator.target, ast.Name):
             return
         if generator.ifs:
+            if not self.config.remove_dead_code:
+                return
+
+            # fullvisit_comprehension() already optimized ifs
             if len(generator.ifs) != 1:
                 return
             test_expr = generator.ifs[0]
             if constant:
                 return
             # (x for x in data if False) => iter(())
-            iter_expr = new_constant(test_expr, ())
+            return DROP_NODE
         else:
             name = generator.target.id
             if (not isinstance(node.elt, ast.Name)
             or node.elt.id != name):
                 # x*2 for x in data
                 # y for x in data
+                new_iter = self.optimize_iter(generator.iter)
+                if new_iter is DROP_NODE:
+                    return DROP_NODE
+                if new_iter is not None:
+                    generator.iter = new_iter
                 return
             iter_expr = generator.iter
-        return new_call(node, 'iter', iter_expr)
+
+        new_iter = self.optimize_iter(iter_expr)
+        if new_iter is not None:
+            return new_iter
+        else:
+            return iter_expr
+
+    def visit_GeneratorExp(self, node):
+        # use iter() builtin function
+        iter_expr = self.optimize_comprehension(node)
+        if iter_expr is None:
+            return
+
+        if 'builtin_funcs' in self.config.features:
+            if iter_expr is DROP_NODE:
+                # (x*2 for x in "abc" if False) => iter(())
+                iter_expr = new_constant(node, ())
+            # (x for x in "abc") => iter("abc")
+            return new_call(node, 'iter', iter_expr)
+        else:
+            if iter_expr is not DROP_NODE:
+                return
+            # (x*2 for x in "abc" if False) => (None for x in ())
+            # (x*2 for x in []) => (None for x in ())
+            generator = node.generators[0]
+            node.elt = new_constant(node, None)
+            empty_tuple = new_constant(node, ())
+            generator.iter = empty_tuple
+            del generator.ifs[:]
+
+    def fullvisit_ListComp(self, node):
+        for generator in node.generators:
+            self.unassign(generator.target)
+
+        node.generators = self.visit_list(node.generators)
+        node.elt = self.visit(node.elt)
+
+        # use list() builtin function
+        iter_expr = self.optimize_comprehension(node)
+        if iter_expr is None:
+            return
+        if iter_expr is not DROP_NODE:
+            if 'builtin_funcs' not in self.config.features:
+                return
+            return new_call(node, 'list', iter_expr)
+        else:
+            new_node = ast.List(elts=[], ctx=ast.Load())
+            return copy_lineno(node, new_node)
 
 
 class FunctionOptimizer(Optimizer):

File astoptimizer/tests.py

         config = Config()
         config.remove_dead_code = False
         self.check_not_optimized('def f():\n return 1\n return 2', config)
-        self.check_not_optimized('if 0: print("log")', config)
-        self.check_not_optimized('while 0: print("log")', config)
+        self.check('if False: print("log")',
+                   self.text_ast('if 0: print("log")'),
+                   config)
+        self.check('while False: print("log")',
+                   self.text_ast('while 0: print("log")'),
+                   config)
         self.check_not_optimized('for x in (): pass', config)
         self.check_not_optimized('for x in (1, 2, 3): pass', config)
+        self.check_not_optimized('[x for x in () if 1]', config)
+        self.check('(x for x in "" if True)',
+                   self.text_ast('(x for x in "" if 1)'),
+                   config)
+        self.check('[x for x in "abc" if False]',
+                   self.text_ast('[x for x in "abc" if 0]'),
+                   config)
+        self.check_not_optimized('(x for x in "abc" if 0)', config)
 
     def test_IfExp(self):
         self.check('4 if "abc" else 5', self.text_num(4))
 
             # optimized
             self.check('[x for x in range(3)]',
-                       self.text_ast('[x for x in xrange(3)]'),
+                       self.text_ast('list(xrange(3))'),
                        config)
-            self.check('[x for x in range(1, 10)]',
-                       self.text_ast('[x for x in xrange(1, 10)]'),
+            self.check('[x*2 for x in range(1, 10)]',
+                       self.text_ast('[x*2 for x in xrange(1, 10)]'),
                        config)
-            self.check('[x for x in range(0, 10, 4)]',
-                       self.text_ast('[x for x in xrange(0, 10, 4)]'),
+            self.check('[x*2 for x in range(0, 10, 4)]',
+                       self.text_ast('[x*2 for x in xrange(0, 10, 4)]'),
                        config)
 
             # not optimized
-            self.check('[x for x in range(0, 2 ** 100, 2 ** 99)]',
-                       self.text_ast('[x for x in range(0, %s, %s)]' % (2 ** 100, 2 ** 99)),
+            self.check('[x*2 for x in range(0, 2 ** 100, 2 ** 99)]',
+                       self.text_ast('[x*2 for x in range(0, %s, %s)]' % (2 ** 100, 2 ** 99)),
                        config)
-            self.check_not_optimized('[x for x in range(1, 2, 3, 4)]',
+            self.check_not_optimized('[x*2 for x in range(1, 2, 3, 4)]',
                        config)
-            self.check_not_optimized('[x for x in range(1, 2, step=3)]',
+            self.check_not_optimized('[x*2 for x in range(1, 2, step=3)]',
                        config)
-            self.check_not_optimized('[x for x in range(0, 5, 0.1)]',
+            self.check_not_optimized('[x*2 for x in range(0, 5, 0.1)]',
                        config)
 
     def test_For(self):
 
     def test_iter_empty_iterable(self):
         config = self.create_config('builtin_funcs')
-        self.check('dict(iter(()))', self.text_ast('dict()'), config)
+        self.check('dict(iter(()))', self.text_ast('{}'), config)
         self.check('frozenset(iter(()))', self.text_ast('frozenset()'), config)
-        self.check('list(iter(()))', self.text_ast('list()'), config)
+        self.check('list(iter(()))', self.text_ast('[]'), config)
         self.check('set(iter(()))', self.text_ast('set()'), config)
         self.check('tuple(iter(()))', self.text_ast('tuple()'), config)
 
 
     def test_drop_iter_empty_iterable(self):
         config = self.create_config('builtin_funcs')
+        self.check('tuple(iter(""))', self.text_ast('tuple()'), config)
         self.check('tuple(iter(()))', self.text_ast('tuple()'), config)
         self.check('tuple(iter([]))', self.text_ast('tuple()'), config)
         self.check('tuple(iter(tuple()))', self.text_ast('tuple()'), config)
         self.check('tuple(iter(frozenset()))', self.text_ast('tuple()'), config)
 
     def test_GeneratorExp(self):
-        self.check('tuple(x*2 for x in (1, 2, 3) if True)',
-                   self.text_ast('tuple(x*2 for x in (1, 2, 3))'))
-        self.check('tuple(x*2 for x in (1, 2, 3) if 0)',
-                   self.text_ast('tuple(x*2 for x in (1, 2, 3) if False)'))
-        self.check_not_optimized('tuple(x*2 for x in (1, 2, 3) if x % 2)')
+        config = self.create_config('builtin_funcs')
+
+        # generators
+        self.check('(x*2 for x in "")',
+                   self.text_ast('(None for x in ())'))
+        self.check('(x*2 for x in range(3))',
+                   self.text_ast('(x*2 for x in xrange(3))'),
+                   config)
+
+        # if
+        self.check('(x*2 for x in (1, 2, 3) if True)',
+                   self.text_ast('(x*2 for x in (1, 2, 3))'))
+        self.check('(x*2 for x in (1, 2, 3) if 0)',
+                   self.text_ast('(None for x in ())'))
+        self.check('(x*2 for x in (1, 2, 3) if 0)',
+                   self.text_ast('iter(())'),
+                   config)
+        self.check('(x*2 for x in (1, 2, 3) if True if 0 if 1)',
+                   self.text_ast('iter(())'),
+                   config)
+        self.check_not_optimized('(x*2 for x in (1, 2, 3) if x % 2)')
 
         # replace with iter()
-        config = self.create_config('builtin_funcs')
-        self.check('tuple(x for x in "abc")',
-                   self.text_ast('tuple(iter("abc"))'),
+        self.check('(x * 2 for x in "")',
+                   self.text_ast('iter(())'),
                    config)
-        self.check('tuple(x for x in "abc" if True)',
-                   self.text_ast('tuple(iter("abc"))'),
+        self.check('(x for x in "abc")',
+                   self.text_ast('iter("abc")'),
+                   config)
+        self.check('(x for x in "abc" if True)',
+                   self.text_ast('iter("abc")'),
                    config)
         self.check('(x*2 for x in (1, 2, 3) if 0)',
                    self.text_ast('iter(())'),
                    config)
-        self.check('tuple(x*2 for x in (1, 2, 3) if 0)',
+        self.check('(x for x in iterable)',
+                   self.text_ast('iter(iterable)'),
+                   config)
+
+        # tuple(generator)
+        self.check('tuple(x for x in "abc")',
+                   self.text_ast('tuple("abc")'),
+                   config)
+        self.check('tuple(x*2 for x in "abc" if 0)',
                    self.text_ast('tuple()'),
                    config)
-        self.check_not_optimized('tuple(x for x in "abc" if x)', config)
-        self.check_not_optimized('tuple([x for x in "abc"])', config)
+        self.check('tuple(x for x in iterable)',
+                   self.text_ast('tuple(iterable)'),
+                   config)
+
+    def test_ListComp(self):
+        config = self.create_config('builtin_funcs')
+
+        # list comprehension
+        self.check('[x*2 for x in ""]',
+                   self.text_ast('[]'))
+        self.check('[x*2 for x in range(3)]',
+                   self.text_ast('[x*2 for x in xrange(3)]'),
+                   config)
+
+        # if
+        self.check('[x*2 for x in (1, 2, 3) if True]',
+                   self.text_ast('[x*2 for x in (1, 2, 3)]'))
+        self.check('[x*2 for x in (1, 2, 3) if True if 1]',
+                   self.text_ast('[x*2 for x in (1, 2, 3)]'))
+        self.check('[x*2 for x in (1, 2, 3) if False]',
+                   self.text_ast('[]'))
+        self.check('[x*2 for x in (1, 2, 3) if 0 if True]',
+                   self.text_ast('[]'))
+        self.check_not_optimized('[x*2 for x in (1, 2, 3) if x % 2]')
+
+        # replace with list()
+        self.check('[x * 2 for x in ""]',
+                   self.text_ast('[]'),
+                   config)
+        self.check('[x for x in "abc"]',
+                   self.text_ast('list("abc")'),
+                   config)
+        self.check('[x for x in "abc" if True]',
+                   self.text_ast('list("abc")'),
+                   config)
+        self.check('[x for x in "abc" if False]',
+                   self.text_ast('[]'),
+                   config)
+        self.check('[x for x in iterable]',
+                   self.text_ast('list(iterable)'),
+                   config)
+
+        # tuple(list comprehension)
+        self.check('tuple([x for x in "abc"])',
+                   self.text_ast('tuple("abc")'),
+                   config)
+        self.check('tuple([x for x in "abc" if False])',
+                   self.text_ast('tuple()'),
+                   config)
+        self.check('tuple([x for x in iterable])',
+                   self.text_ast('tuple(iterable)'),
+                   config)
 
 
 class TestFrozenset(BaseTestCase):
                           {}, ['x', 'y'])
         self.check_values('x=1\nfor x in range(3): print(x)',
                           {}, ['x'])
-        self.check('x = 5\nfor x in "": x = 9\nprint(x)',
-                   self.text_ast('x = 5\nprint("5")'))
+        self.check('def f():\n x = 5\n for x in "": x = 9\n return x',
+                   self.text_ast('def f():\n x = 5\n return 5'))
 
     def test_listcomp(self):
         self.check_values('[x for x in range(3)]',
 
     config = Config('builtin_funcs', 'pythonenv')
     config.use_experimental_vars = True
+    config.remove_dead_code = False
     print("Config features: %s" % ', '.join(sorted(config.features)))
     print("")