Commits

Victor Stinner committed 014eaa2

Unroll loops

Comments (0)

Files changed (6)

  * Optimize loops (range => xrange needs "builtin_funcs" features). Examples:
 
    - ``while True: pass`` => ``while 1: pass``
-   - ``for x in range(3): pass`` => ``for x in (0, 1, 2): pass``
+   - ``for x in range(3): print(x)`` => ``x = 0; print(x); x = 1; print(x); x = 2; print(x)``
    - ``for x in range(1000): pass`` => ``for x in xrange(1000): pass`` (Python 2)
 
  * Optimize iterators, list, set and dict comprehension, and generators (need
 
  * Replace list with tuple (need "builtin_funcs" feature). Examples:
 
-   - ``for x in [1, 2, 3]: pass`` => ``for x in (1, 2, 3): pass``
+   - ``for x in [x, y, z]: pass`` => ``for x in (x, y, z): pass``
    - ``x in [1, 2, 3]`` => ``x in (1, 2, 3)``
    - ``list([x, y, z])`` => ``[x, y, z]``
    - ``set([1, 2, 3])`` => ``{1, 2, 3}`` (Python 2.7+)
 
 Changes:
 
+ * Unroll loops (no support for break/continue yet)
  * Remove useless instructions. Example:
    "x=1; 'abc'; print(x)" => "x=1; print(x)"
 
  - "type(iter([]))"
 
 
+Misc
+====
+
+ * unroll:
+
+   - support break/continue
+   - unroll list comprehension: "[x*2 for x in range(3)]" => "[0, 2, 4]"
+   - drop x if possible
+
+
 Major Optimizations
 ===================
 
    * "x=[0]; for ...: x.append(...)"
      => "x=[0]; x_append=x.append; for ...: x_append(...)"
 
- - unroll short loops: "for x in range(4): print(x)"
-
-   * duplicate the body and evaluate the body with x=0, x=1, ...
-   * handle continue/break
-   * drop x if possible
-   * "[x*2 for x in range(3)]" => "[0, 2, 4]"
-
  - convert naive loop to list comprehension: "x=[]; for item in data: x.append(item.upper())"
    => "x=[item.upper() for item in data]". Same for x=set() and x={}.
 

astoptimizer/ast_tools.py

 from astoptimizer.compatibility import (
     PYTHON3, COMPLEX_TYPES, BYTES_TYPE, UNICODE_TYPE)
 import sys
+import copy
 
 def copy_lineno(node, new_node):
     ast.fix_missing_locations(new_node)
     new_node = ast.Pass()
     return copy_lineno(node, new_node)
 
+def clone_node_list(node_list):
+    # FIXME: use something faster? or more specialized?
+    return copy.deepcopy(node_list)

astoptimizer/config.py

         # documentation
         self.max_size = 2**31 - 1
 
+        # Limit of loop unroll
+        # Use 0 to disable loop unroll.
+        self.unroll_limit = 100
+
         # Experimental support of assignment (support of variables)
         self.use_experimental_vars = False
 

astoptimizer/optimizer.py

     new_constant, new_literal, new_call, new_pass,
     sort_set_elts, new_tuple, new_tuple_elts, new_list, new_list_elts,
     new_dict_elts,
-    ast_contains, check_func_args)
+    ast_contains, check_func_args,
+    clone_node_list)
 from astoptimizer.config import optimize_unicode
 from astoptimizer.compatibility import (
     u,
 
     def visit_list(self, node_list, conditional=False):
         new_node_list = []
-        for  node in node_list:
+        for node in node_list:
             new_node = self.visit(node, conditional=conditional)
             if new_node is None:
                 continue
     def optimize_iter(self, node, to_type):
         return self._optimize_iter(node, False, to_type)
 
+    def create_expr_list(self, node, new_node_list):
+        test = ast.Num(1)
+        copy_lineno(node, test)
+        new_node = ast.If(test=test, body=new_node_list, orelse=[])
+        copy_lineno(node, new_node)
+        return new_node
+
+    def try_unroll_loop(self, node):
+        if self.config.unroll_limit <= 0:
+            return UNSET
+
+        itercst = self.get_constant(node.iter)
+        if itercst is UNSET:
+            return UNSET
+        if (not isinstance(itercst, tuple)
+        or len(itercst) > self.config.unroll_limit):
+            return UNSET
+
+        target = node.target
+        if not isinstance(target, ast.Name):
+            return UNSET
+        target_id = target.id
+
+        if (ast_contains(node.body, (ast.Break, ast.Continue))
+        or ast_contains(node.orelse, (ast.Break, ast.Continue))):
+            return UNSET
+
+        # FIXME: don't unassign temporary the target?
+        was_unassigned = target_id in self.namespace._unassigned
+        self.namespace._unassigned.add(target_id)
+        node.body = self.visit_list(node.body, conditional=True)
+        node.orelse = self.visit_list(node.orelse, conditional=True)
+        if not was_unassigned:
+            self.namespace._unassigned.remove(target_id)
+
+        unroll = []
+        for cst in itercst:
+            value = new_constant(node, cst)
+            assign = ast.Assign(targets=[target], value=value)
+            copy_lineno(node.body[0], assign)
+            unroll.append(assign)
+            body = clone_node_list(node.body)
+            unroll.extend(body)
+        unroll.extend(node.orelse)
+
+        unroll = self.visit_list(unroll)
+        return self.if_block(node, unroll)
+
     def fullvisit_For(self, node):
         node.iter = self.visit(node.iter)
 
         elif new_iter is not None:
             node.iter = new_iter
 
-        #node.target = self.visit(node.target)
+        unroll = self.try_unroll_loop(node)
+        if unroll is not UNSET:
+            return unroll
+
         self.unassign(node.target)
-
         node.body = self.visit_list(node.body, conditional=True)
         node.orelse = self.visit_list(node.orelse, conditional=True)
+        return node
 
     def fullvisit_arguments(self, node):
         # Don't visit arguments

astoptimizer/tests.py

         self.check_not_optimized('1[1]')
         self.check_not_optimized('1[:1]')
 
-    def check_pass(self, code):
+    def check_pass(self, code, config=None):
         expected = code.replace("pass; pass", "pass")
         expected = expected.replace("\n pass\n pass", "\n pass")
-        self.check(code, self.text_ast(expected))
+        self.check(code, self.text_ast(expected), config=config)
 
     def test_Pass(self):
         self.check_not_optimized('pass')
 
         self.check_pass('class Klass:\n pass\n pass')
         self.check_pass('def f():\n pass\n pass')
-        self.check_pass('for i in (1, 2, 3):\n pass\n pass')
+        config = self.create_config()
+        config.unroll_limit = 0
+        self.check_pass('for i in (1, 2, 3):\n pass\n pass', config=config)
         self.check_pass('while x:\n pass\n pass')
 
         code = '\n'.join((
         # disable removal of dead code
         config = Config()
         config.remove_dead_code = False
+        config.unroll_limit = 0
         self.check_not_optimized('def f():\n return 1\n return 2', config)
         self.check('if False: print("log")',
                    self.text_ast('if 0: print("log")'),
         self.check_not_optimized('for x in range(n): pass')
 
         config = self.create_config('builtin_funcs')
+        config.unroll_limit = 0
         self.check('for x in range(3): pass',
                    self.text_ast('for x in (0, 1, 2): pass'),
                    config)
                        config)
 
         self.check('for x in [1, 2, 3]: pass',
-                   self.text_ast('for x in (1, 2, 3): pass'))
+                   self.text_ast('for x in (1, 2, 3): pass'),
+                   config)
         self.check('for x in "": pass', self.TEXT_PASS)
 
         config = self.create_config()
         config.max_tuple_length = 2
+        config.unroll_limit = 0
         self.check('for x in [1, 2]: pass',
                    self.text_ast('for x in (1, 2): pass'), config)
         self.check_not_optimized('for x in [1, 2, 3]: pass', config)
 
+    def test_For_unroll(self):
+        no_unroll = self.create_config()
+        no_unroll.unroll_limit = 0
+        self.check_not_optimized('for i in (1, 2, 3):\n print(i)',
+                                 config=no_unroll)
+        self.check('for i in (1, 2, 3):\n print(i)',
+                   self.text_ast('i=1; print(i); i=2; print(i); i=3; print(i)'))
+        self.check('for i in (1, 2, 3):\n pass\n pass',
+                   self.text_ast('i=1; i=2; i=3'))
+
     def test_max_tuple_length(self):
         config = self.create_config('builtin_funcs')
         config.max_tuple_length = 3
         config._constants['sys.maxint'] = 2147483647
         return config
 
-    def check_line(self, line):
+    def check_line(self, line_number, line):
         if not line.startswith('   - ``'):
             return
         line = line[7:].rstrip()
     def test_readme_file(self):
         filename = os.path.join(os.path.dirname(__file__), '..', 'README')
         with open(filename) as fp:
-            for line in fp:
-                self.check_line(line)
+            for line_number, line in enumerate(fp, 1):
+                self.check_line(line_number, line)
 
 
 class TestStruct(BaseTestCase):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.