Commits

Anonymous committed f742770

extract: better global function extractions

Comments (0)

Files changed (2)

rope/refactor/extract.py

             self.matched_lines.append(self.info.region_lines[0])
 
     def find_lineno(self):
-        if self.info.make_global and not self.info.global_:
+        if self.info.variable and not self.info.make_global:
+            return self._get_before_line()
+        if self.info.make_global or self.info.global_:
             toplevel = self._find_toplevel(self.info.scope)
             ast = self.info.pymodule.get_ast()
             newlines = sorted(self.matched_lines + [toplevel.get_end() + 1])
             return suites.find_visible(ast, newlines)
-        if self.info.global_ or self.info.variable:
-            return self._get_before_line()
         return self._get_after_scope()
 
     def _find_toplevel(self, scope):
         if toplevel.parent is not None:
             while toplevel.parent.parent is not None:
                 toplevel = toplevel.parent
-            return toplevel
+        return toplevel
 
     def find_indents(self):
-        if self.info.make_global:
-            return 0
-        if self.info.global_ or self.info.variable:
+        if self.info.variable and not self.info.make_global:
             return sourceutils.get_indents(self.info.lines,
                                            self._get_before_line())
+        else:
+            if self.info.global_ or self.info.make_global:
+                return 0
         return self.info.scope_indents
 
     def _get_before_line(self):

ropetest/refactor/extracttest.py

         code = 'if True:\n    a = 10\n'
         start, end = self._convert_line_range_to_offset(code, 2, 2)
         refactored = self.do_extract_method(code, start, end, 'new_func')
-        expected = 'if True:\n\n    ' \
-                   'def new_func():\n        a = 10\n\n    new_func()\n'
+        expected = '\ndef new_func():\n    a = 10\n\nif True:\n' \
+                   '    new_func()\n'
         self.assertEquals(expected, refactored)
 
     def test_extract_function_while_inner_function_reads(self):
                    'def one():\n    return 1\n'
         self.assertEquals(expected, refactored)
 
+    def test_extracting_methods_in_global_functions_should_be_global(self):
+        code = 'if 1:\n    var = 2\n'
+        start = code.rindex('2')
+        refactored = self.do_extract_method(code, start, start + 1, 'two',
+                                            similar=True, global_=False)
+        expected = '\ndef two():\n    return 2\n\nif 1:\n    var = two()\n'
+        self.assertEquals(expected, refactored)
+
 
 if __name__ == '__main__':
     unittest.main()