Commits

Victor Stinner committed a882a9f

fix removal of dead code for generators

Comments (0)

Files changed (2)

astoptimizer/optimizer.py

             raise NotImplementedError("unable to create an AST object for constant: %r" % (value,))
         return copy_lineno(node, new_node)
 
-def optimize_node_list(node_list):
+def remove_duplicate_pass(node_list):
     # Remove duplicate pass instructions
     index = 1
     while index < len(node_list):
     and isinstance(node_list[0], ast.Pass)):
         del node_list[0]
 
-    # Remove dead code
-    # Example: "return 1; return 2" => "return 1"
-    truncate = None
-    for index, node in enumerate(node_list[:-1]):
-        if isinstance(node, (ast.Return, ast.Raise)):
-            truncate = index
-            break
-    if truncate is None:
-        return
-    if truncate == len(node_list) - 1:
-        return
-    del node_list[truncate+1:]
-
 def iter_all_ast(node):
     yield node
     for field, value in ast.iter_fields(node):
             constants.append(constant)
         return constants
 
+    def remove_dead_code(self, node_list):
+        # Remove dead code
+        # Example: "return 1; return 2" => "return 1"
+        truncate = None
+        for index, node in enumerate(node_list[:-1]):
+            if isinstance(node, (ast.Return, ast.Raise)):
+                truncate = index
+                break
+        if truncate is None:
+            return
+        if truncate == len(node_list) - 1:
+            return
+        del node_list[truncate+1:]
+
+    def optimize_node_list(self, node_list):
+        remove_duplicate_pass(node_list)
+        self.remove_dead_code(node_list)
+
     def generic_visit(self, node):
         for field, old_value in ast.iter_fields(node):
             old_value = getattr(node, field, None)
                             new_values.extend(value)
                             continue
                     new_values.append(value)
-                optimize_node_list(new_values)
+                self.optimize_node_list(new_values)
                 old_value[:] = new_values
             elif isinstance(old_value, ast.AST):
                 new_node = self.visit(old_value)
 
     def visit_list(self, node_list):
         new_node_list = [self.visit(node) for node in node_list]
-        optimize_node_list(new_node_list)
+        self.optimize_node_list(new_node_list)
         return new_node_list
 
     def visit_Name(self, node):
         else:
             return not ast_contains(node, (ast.Yield, ast.Return))
 
+    def remove_dead_code(self, node_list):
+        if not self.is_generator:
+            BaseOptimizer.remove_dead_code(self, node_list)
+            return
+
+        truncate = None
+        for index, node in enumerate(node_list[:-1]):
+            if ((isinstance(node, ast.Return) and node.value is None)
+            or  isinstance(node, ast.Raise)):
+                truncate = index
+                break
+        if truncate is None:
+            return
+        if truncate == len(node_list) - 1:
+            return
+        del node_list[truncate+1:]
+
     def visit_Yield(self, node):
         self.seen_yield = True
 

astoptimizer/tests.py

         tree = parse_ast(code)
         return ast.dump(tree)
 
-    def _optimize_ast(self, tree, config=None, check_bytecode=True):
+    def _optimize_ast(self, tree, config=None, catch_syntaxerror=False):
         if config is None:
             config = Config()
         optimizer = Optimizer(config)
         tree = optimizer.optimize(tree)
-        if (check_bytecode
-        and sys.version_info >= (2,6)):
+        if sys.version_info >= (2,6):
             # Ensure that the tree is compilable to bytecode
-            compile_ast(tree)
+            try:
+                compile_ast(tree)
+            except SyntaxError:
+                if not catch_syntaxerror:
+                    raise
         return tree, optimizer
 
     def check(self, code, expected, config=None):
         text = ast.dump(tree)
         self.assertEqual(text, expected)
 
-    def check_not_optimized(self, code, config=None, check_bytecode=True):
+    def check_not_optimized(self, code, config=None, catch_syntaxerror=False):
         old_tree = parse_ast(code)
         old = ast.dump(old_tree)
         old = re.sub(r"UnaryOp\(op=USub\(\), operand=Num\(n=([0-9]+(?:\.[0-9]*)?)\)\)", lambda regs: "Num(n=-%s)" % regs.group(1), old)
-        new_tree, optimizer = self._optimize_ast(old_tree, config, check_bytecode=check_bytecode)
+        new_tree, optimizer = self._optimize_ast(old_tree, config, catch_syntaxerror=catch_syntaxerror)
         new = ast.dump(new_tree)
         self.assertEqual(new, old)
 
         self.check_not_optimized('def f():\n if 0:\n  yield')
         self.check_not_optimized('def f():\n if 1:\n  pass\n else:\n  yield')
         self.check_not_optimized('def f():\n if 0:\n  yield\n yield 3')
-        self.check_not_optimized('def f():\n yield 3\n if 0:\n  return 3', check_bytecode=False)
+        self.check_not_optimized('def f():\n yield 3\n if 0:\n  return 3', catch_syntaxerror=True)
 
     def test_remove_dead_code(self):
         self.check('def f():\n return 1\n return 2',
                    self.text_ast('def f():\n g()\n return 1'))
         self.check('def f():\n g()\n raise ValueError("error")\n h()\n return 2',
                    self.text_ast('def f():\n g()\n raise ValueError("error")'))
+        self.check_not_optimized('def f():\n return 22\n yield 1', catch_syntaxerror=True)
 
     def test_IfExp(self):
         self.check('4 if "abc" else 5', self.text_num(4))
         self.check_not_optimized('def f():\n while 0:\n  yield')
         self.check('def f():\n yield 3\n while 0:\n  yield 5',
                    self.text_ast('def f():\n yield 3'))
-        self.check_not_optimized('while 1: print("log")', check_bytecode=False)
+        self.check_not_optimized('while 1: print("log")')
         self.check_not_optimized('def f():\n while 0:\n  yield 5')
         self.check_not_optimized('def f():\n while 0:\n  yield 5\n yield 3')
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.