Commits

Anonymous committed d353b38

extract: handle conditional variable writes in extracted region

Reported by Issac Trotts <issac.trotts@gmail.com>

  • Participants
  • Parent commits 8af0a3d

Comments (0)

Files changed (2)

File rope/refactor/extract.py

         if self.info.global_ and not self.info.make_global:
             return ()
         if not self.info.one_line:
-            return list(self.info_collector.prewritten.
-                        intersection(self.info_collector.read))
+            result = (self.info_collector.prewritten &
+                      self.info_collector.read)
+            result |= (self.info_collector.maybe_written &
+                       self.info_collector.postread)
+            return list(result)
         start = self.info.region[0]
         if start == self.info.lines_region[0]:
             start = start + re.search('\S', self.info.extracted).start()
     def _find_function_returns(self):
         if self.info.one_line or self.info.returned:
             return []
-        return list(self.info_collector.written.
-                    intersection(self.info_collector.postread))
+        written = self.info_collector.written | \
+                  self.info_collector.maybe_written
+        return list(written & self.info_collector.postread)
 
     def _get_unindented_function_body(self, returns):
         if self.info.one_line:
         self.end = end
         self.is_global = is_global
         self.prewritten = set()
+        self.maybe_written = set()
         self.written = set()
         self.read = set()
         self.postread = set()
         self.postwritten = set()
         self.host_function = True
+        self.conditional = False
 
     def _read_variable(self, name, lineno):
         if self.start <= lineno <= self.end:
 
     def _written_variable(self, name, lineno):
         if self.start <= lineno <= self.end:
-            self.written.add(name)
+            if self.conditional:
+                self.maybe_written.add(name)
+            else:
+                self.written.add(name)
         if self.start > lineno:
             self.prewritten.add(name)
         if self.end < lineno:
     def _ClassDef(self, node):
         self._written_variable(node.name, node.lineno)
 
+    def _handle_conditional_node(self, node):
+        self.conditional = True
+        try:
+            for child in ast.get_child_nodes(node):
+                ast.walk(child, self)
+        finally:
+            self.conditional = False
+
+    def _If(self, node):
+        self._handle_conditional_node(node)
+
+    def _While(self, node):
+        self._handle_conditional_node(node)
+
+    def _For(self, node):
+        self._handle_conditional_node(node)
+
+
 
 def _get_argnames(arguments):
     result = [node.id for node in arguments.args

File ropetest/refactor/extracttest.py

         expected = '\ndef f():\n    return "1" "2"\n\ns = (f())\n'
         self.assertEquals(expected, refactored)
 
+    def test_passing_conditional_updated_vars_in_extracted(self):
+        code = 'def f(a):\n' \
+               '    if 0:\n' \
+               '        a = 1\n' \
+               '    print(a)\n'
+        start, end = self._convert_line_range_to_offset(code, 2, 4)
+        refactored = self.do_extract_method(code, start, end, 'g')
+        expected = 'def f(a):\n' \
+                   '    g(a)\n\n' \
+                   'def g(a):\n' \
+                   '    if 0:\n' \
+                   '        a = 1\n' \
+                   '    print(a)\n'
+        self.assertEquals(expected, refactored)
+
+    def test_returning_conditional_updated_vars_in_extracted(self):
+        code = 'def f(a):\n' \
+               '    if 0:\n' \
+               '        a = 1\n' \
+               '    print(a)\n'
+        start, end = self._convert_line_range_to_offset(code, 2, 3)
+        refactored = self.do_extract_method(code, start, end, 'g')
+        expected = 'def f(a):\n' \
+                   '    a = g(a)\n' \
+                   '    print(a)\n\n' \
+                   'def g(a):\n' \
+                   '    if 0:\n' \
+                   '        a = 1\n' \
+                   '    return a\n'
+        self.assertEquals(expected, refactored)
+
 
 if __name__ == '__main__':
     unittest.main()