Commits

Victor Stinner committed e1535ec

Fix visit_Attribute() and removal of dead code in generators

  • Participants
  • Parent commits 1505645

Comments (0)

Files changed (2)

File astoptimizer/optimizer.py

             constants.append(constant)
         return constants
 
+    def can_drop(self, node_list):
+        return True
+
     def remove_dead_code(self, node_list):
         # Remove dead code
         # Example: "return 1; return 2" => "return 1"
         truncate = None
         for index, node in enumerate(node_list[:-1]):
-            if isinstance(node, (ast.Return, ast.Raise)):
-                truncate = index
-                break
+            if not isinstance(node, (ast.Return, ast.Raise)):
+                continue
+            if not self.can_drop(node_list[index+1:]):
+                continue
+            truncate = index
+            break
         if truncate is None:
             return
         if truncate == len(node_list) - 1:
                         if value is None:
                             continue
                         elif not isinstance(value, ast.AST):
+                            assert isinstance(value, list), value
                             new_values.extend(value)
                             continue
                     new_values.append(value)
             return
         return new_constant(node, constant)
 
-    def attribute_name(self, node):
+    def get_attribute_name(self, node):
         if isinstance(node.value, ast.Name):
             name = node.value.id
         elif isinstance(node.value, ast.Attribute):
-            name = self.attribute_name(node.value)
+            name = self.get_attribute_name(node.value)
             if name is UNSET:
                 return UNSET
         else:
         return "%s.%s" % (name, node.attr)
 
     def visit_Attribute(self, node):
-        name = self.attribute_name(node)
+        name = self.get_attribute_name(node)
         if name is UNSET:
             return
         name = self.namespace.get_qualname(name)
         if name is UNSET:
-            return UNSET
+            return
         constant = self.config.get_constant(name)
         if constant is UNSET:
             return
         else:
             return nodes
 
-    def can_drop(self, node):
-        return True
-
     def fullvisit_If(self, node):
         if (not hasattr(node, 'test')
         and 'cpython_tests' in self.config.features):
         self.is_generator = None
         self.seen_yield = None
 
-    def can_drop(self, node):
+    def can_drop(self, node_list):
         if not self.is_generator:
             return True
         # Without "yield", a function is no more a generator.
         # "return 3" is a SyntaxError in a generator.
         # http://tomlee.co/2008/04/the-internals-of-python-generator-functions-in-the-ast/
         if self.seen_yield:
-            return not ast_contains(node, ast.Return)
+            # FIXME: drop "return" but not "return 3"
+            return not ast_contains(node_list, ast.Return)
         else:
-            return not ast_contains(node, (ast.Yield, ast.Return))
+            # FIXME: handle seen_yield=None case for remove_dead_code()
+            return not ast_contains(node_list, (ast.Yield, ast.Return))
 
     def remove_dead_code(self, node_list):
-        if not self.is_generator:
+        seen = self.seen_yield
+        try:
+            self.seen_yield = None
             BaseOptimizer.remove_dead_code(self, node_list)
-            return
-
-        truncate = None
-        for index, node in enumerate(node_list[:-1]):
-            if ((isinstance(node, ast.Return) and node.value is None)
-            or  isinstance(node, ast.Raise)):
-                truncate = index
-                break
-        if truncate is None:
-            return
-        if truncate == len(node_list) - 1:
-            return
-        del node_list[truncate+1:]
+        finally:
+            self.seen_yield = seen
 
     def visit_Yield(self, node):
         self.seen_yield = True

File astoptimizer/tests.py

                    self.text_ast('def f():\n g()\n return 1'))
         self.check('def f():\n g()\n raise ValueError("error")\n h()\n return 2',
                    self.text_ast('def f():\n g()\n raise ValueError("error")'))
+
+        # generators
+        self.check_not_optimized('def f():\n return\n yield 1')
         self.check_not_optimized('def f():\n return 22\n yield 1', catch_syntaxerror=True)
 
+        code = """
+def f():
+    if 0:
+        lambda x:  x        # shouldn't trigger here
+        return              # or here
+        def f(i):
+            return 2*i      # or here
+        if 0:
+            return 3        # but *this* sucks (line 8)
+    if 0:
+        yield 2             # because it's a generator (line 10)
+        """.strip()
+        self.check_not_optimized(code, catch_syntaxerror=True)
+
     def test_IfExp(self):
         self.check('4 if "abc" else 5', self.text_num(4))
         self.check('4 if 0 else 5', self.text_num(5))