Commits

Victor Stinner  committed 3534a77

Unroll list comprehension

  • Participants
  • Parent commits c986d51

Comments (0)

Files changed (4)

    - ``iter(set())`` => ``iter(())``
    - ``frozenset("")`` => ``frozenset()``
    - ``(x for x in "abc" if False)`` => ``(None for x in ())``
+   - ``[x*10 for x in range(1, 4)]`` => ``[10, 20, 30]``
    - ``(x*2 for x in "abc" if True)`` => ``(x*2 for x in ("a", "b", "c"))``
    - ``list(x for x in iterable)`` => ``list(iterable)``
    - ``tuple(x for x in "abc")`` => ``("a", "b", "c")``
  * Add Config.enable_all_optimizations() method
  * Add a more aggressive option to remove dead code
    (config.remove_almost_dead_code), disabled by default
- * Unroll loops (no support for break/continue yet)
+ * Unroll loops (no support for break/continue yet) and list comprehension.
+   Example: ``[x*10 for x in range(1, 4)]`` => ``[10, 20, 30]``.
  * Remove useless instructions. Example:
    "x=1; 'abc'; print(x)" => "x=1; print(x)"
 
 
  * remove useless code: "try: pass except: pass"
  * replace '(a and b) and c' (2 op) with 'a and b and c' (1 op), same for "or" operator
- * unroll:
-
-   - support break/continue
-   - unroll list comprehension: "[x*2 for x in range(3)]" => "[0, 2, 4]"
-   - drop x if possible
+ * unroll: support break/continue
 
 
 Major Optimizations
    * "x=1" => "pass" (drop x)
    * "return x" => "return x" (keep x)
 
- - drop unused variables with a warning:
+ - drop unused local variables with a warning:
 
    * "def f(): x=1; return 2"
 

File astoptimizer/optimizer.py

             return False
         return self.check_func(node, name, min_narg, max_narg)
 
-    def get_constant(self, node, to_type=None):
+    def get_constant(self, node, to_type=None, max_length=None):
         if node is None:
             constant = None
         elif isinstance(node, ast.Num):
             and not issubclass(tuple, to_type)):
                 return UNSET
             elts = node.elts
-            if len(elts) > self.config.max_tuple_length:
+            if max_length is None:
+                max_length = self.config.max_tuple_length
+            if len(elts) > max_length:
                 return UNSET
             constants = []
             for elt in elts:
         if (to_type == set
         and self.config.remove_dead_code
         and self.check_builtin_func(node, ('frozenset', 'set'),  1, 1)):
-            # set(iterable) => iterable
+            # set(set(iterable)) => set(iterable)
             return node.args[0]
 
         if (self.config.remove_dead_code
         if self.config.unroll_limit <= 0:
             return UNSET
 
-        itercst = self.get_constant(node.iter)
+        itercst = self.get_constant(node.iter, tuple,
+                                    max_length=self.config.unroll_limit)
         if itercst is UNSET:
             return UNSET
-        if (not isinstance(itercst, tuple)
-        or len(itercst) > self.config.unroll_limit):
-            return UNSET
 
         target = node.target
         if not isinstance(target, ast.Name):
         unroll = self.visit_list(unroll)
         return self.if_block(node, unroll)
 
+    def replace_var(self, node, name, value):
+        replace = ReplaceVariable(self.config, name, value)
+        return replace.visit(node)
+
+    def try_unroll_listcomp(self, node):
+        if self.config.unroll_limit <= 0:
+            return
+
+        # FIXME: support more than 1 generator
+        if len(node.generators) != 1:
+            return
+        generator = node.generators[0]
+
+        # FIXME: support more than 1 generator
+        if len(node.generators) != 1:
+            return
+        generator = node.generators[0]
+
+        if generator.ifs:
+            return
+
+        itercst = self.get_constant(generator.iter, tuple,
+                                    max_length=self.config.unroll_limit)
+        if itercst is UNSET:
+            return
+
+        target = generator.target
+        if not isinstance(target, ast.Name):
+            return
+        target_id = target.id
+
+        items = []
+        for cst in itercst:
+            item = []
+            value = new_constant(node, cst)
+            elt = clone_node_list(node.elt)
+            elt = self.replace_var(elt, target_id, value)
+            items.append(elt)
+
+        return new_list_elts(node, items)
+
     def fullvisit_For(self, node):
         node.iter = self.visit(node.iter)
 
         node.elt = self.visit(node.elt)
 
         iter_expr = self.optimize_comprehension(node, to_type=list)
-        if iter_expr is None:
-            return
-        if iter_expr is DROP_NODE:
-            # [x*2 for x in "abc" if False] => []
-            # [x*2 for x in []] => []
-            return new_list_elts(node)
+        if iter_expr is not None:
+            if iter_expr is DROP_NODE:
+                # [x*2 for x in "abc" if False] => []
+                # [x*2 for x in []] => []
+                return new_list_elts(node)
 
-        if isinstance(iter_expr, ast.List):
-            # [x for x in "abc"] => ["a", "b", "c"]
-            return iter_expr
-        elif self.can_use_builtin('list'):
-            # [x for x in range(1000)] => list(xrange(1000))
-            return new_call(node, 'list', iter_expr)
+            if isinstance(iter_expr, ast.List):
+                # [x for x in "abc"] => ["a", "b", "c"]
+                return iter_expr
+
+            if self.can_use_builtin('list'):
+                # [x for x in range(1000)] => list(xrange(1000))
+                return new_call(node, 'list', iter_expr)
+
+            node.generators[0].iter = iter_expr
+
+        return self.try_unroll_listcomp(node)
 
     def fullvisit_SetComp(self, node):
         for generator in node.generators:
         self.seen_yield = False
         return Optimizer._optimize(self, tree, namespace)
 
+class ReplaceVariable(Optimizer):
+    def __init__(self, config, name, value):
+        Optimizer.__init__(self, config)
+        self.name = name
+        self.value = value
+
+    def visit_Name(self, node):
+        if node.id == self.name:
+            return self.value
+
+        return Optimizer.visit_Name(node)
+

File astoptimizer/tests.py

                    self.text_ast('[0, 1, 2]'),
                    config)
         self.check('[x*2 for x in range(0, 2 ** 20, 2 ** 18)]',
-                   self.text_ast('[x*2 for x in (0, %s, %s, %s)]' % (2 ** 18, 2 * 2 ** 18, 3 * 2 ** 18)),
+                   self.text_ast('[0, 524288, 1048576, 1572864]'),
                    config)
 
         # not optimized
         else:
             self.check_not_optimized('[x*2 for x in range(1000)]', config)
         self.check('[x*2 for x in range(3)]',
-                   self.text_ast('[x*2 for x in (0, 1, 2)]'),
+                   self.text_ast('[0, 2, 4]'),
                    config)
 
         # if
-        self.check('[x*2 for x in (1, 2, 3) if True]',
-                   self.text_ast('[x*2 for x in (1, 2, 3)]'))
-        self.check('[x*2 for x in (1, 2, 3) if True if 1]',
-                   self.text_ast('[x*2 for x in (1, 2, 3)]'))
+        self.check('[x*3 for x in (1, 2, 3) if True]',
+                   self.text_ast('[3, 6, 9]'))
+        self.check('[x*5 for x in (1, 2, 3) if True if 1]',
+                   self.text_ast('[5, 10, 15]'))
         self.check('[x*2 for x in (1, 2, 3) if False]',
                    self.text_ast('[]'))
         self.check('[x*2 for x in (1, 2, 3) if 0 if True]',