Anonymous avatar Anonymous committed 6e84f5e

extract: extracting functions with only one return as their last stmt

Comments (0)

Files changed (2)

rope/refactor/extract.py

     def logical_lines(self):
         return self.pymodule.logical_lines
 
-
     def _init_scope(self):
         start_line = self.region_lines[0]
         scope = self.global_scope.get_inner_scope_for_line(start_line)
     def extracted(self):
         return self.source[self.region[0]:self.region[1]]
 
+    _returned = None
+    @property
+    def returned(self):
+        """Does the extracted piece contain return statement"""
+        if self._returned is None:
+            self._returned = _returns_last(self.extracted)
+        return self._returned
+
 
 class _ExtractCollector(object):
     """Collects information needed for performing the extract"""
         if end_scope != info.scope and end_scope.get_end() != end_line:
             raise RefactoringError('Bad region selected for extract method')
         try:
-            if _ReturnOrYieldFinder.does_it_return(
-                info.source[info.region[0]:info.region[1]]):
-                raise RefactoringError('Extracted piece should not '
-                                       'contain return statements.')
             if _UnmatchedBreakOrContinueFinder.has_errors(
-                info.source[info.region[0]:info.region[1]]):
+               info.source[info.region[0]:info.region[1]]):
                 raise RefactoringError('A break/continue without having a '
                                        'matching for/while loop.')
         except SyntaxError:
                                    'span multiple lines.')
 
     def multi_line_conditions(self, info):
+        code = info.source[info.region[0]:info.region[1]]
+        count = _return_count(code)
+        if count > 0 and not (count == 1 and _returns_last(code)):
+            raise RefactoringError('Extracted piece should not '
+                                   'contain more than one return statements.')
         if info.region != info.lines_region:
-            raise RefactoringError('Extracted piece should'
-                                   ' contain complete statements.')
+            raise RefactoringError('Extracted piece should '
+                                   'contain complete statements.')
 
     def _is_region_on_a_word(self, info):
         if info.region[0] > 0 and self._is_on_a_word(info, info.region[0] - 1) or \
         call_prefix = ''
         if returns:
             call_prefix = self._get_comma_form(returns) + ' = '
+        if self.info.returned:
+            call_prefix = 'return '
         return call_prefix + self._get_function_call(args)
 
     def _find_function_arguments(self):
         return list(self.info_collector.prewritten.intersection(read))
 
     def _find_function_returns(self):
-        if self.info.one_line:
+        if self.info.one_line or self.info.returned:
             return []
         return list(self.info_collector.written.
                     intersection(self.info_collector.postread))
 class _ReturnOrYieldFinder(object):
 
     def __init__(self):
-        self.returns = False
-        self.loop_count = 0
-
-    def check_loop(self):
-        if self.loop_count < 1:
-            self.error = True
+        self.returns = 0
 
     def _Return(self, node):
-        self.returns = True
+        self.returns += 1
 
     def _Yield(self, node):
-        self.returns = True
+        self.returns += 1
 
     def _FunctionDef(self, node):
         pass
     def _ClassDef(self, node):
         pass
 
-    @staticmethod
-    def does_it_return(code):
-        if code.strip() == '':
-            return False
-        indented_code = sourceutils.fix_indentation(code, 0)
-        node = _parse_text(indented_code)
-        visitor = _ReturnOrYieldFinder()
-        ast.walk(node, visitor)
-        return visitor.returns
+def _return_count(code):
+    if code.strip() == '':
+        return False
+    indented_code = sourceutils.fix_indentation(code, 0)
+    node = _parse_text(indented_code)
+    visitor = _ReturnOrYieldFinder()
+    ast.walk(node, visitor)
+    return visitor.returns
+
+def _returns_last(code):
+    if code.strip() == '':
+        return False
+    indented_code = sourceutils.fix_indentation(code, 0)
+    node = _parse_text(indented_code)
+    return node.body and isinstance(node.body[-1], ast.Return)
 
 
 class _UnmatchedBreakOrContinueFinder(object):

ropetest/refactor/extracttest.py

 
     @testutils.assert_raises(rope.base.exceptions.RefactoringError)
     def test_extract_method_containing_return(self):
-        code = "def a_func(arg):\n    return arg * 2\n"
-        start, end = self._convert_line_range_to_offset(code, 2, 2)
+        code = 'def a_func(arg):\n    if arg:\n        return arg * 2\n    return 1'
+        start, end = self._convert_line_range_to_offset(code, 2, 4)
         self.do_extract_method(code, start, end, 'new_func')
 
     @testutils.assert_raises(rope.base.exceptions.RefactoringError)
                    '\ndef two():\n    return next(1)\n\nvar = two()\n'
         self.assertEquals(expected, refactored)
 
+    def test_extracting_with_only_one_return(self):
+        code = 'def f():\n    var = 1\n    return var\n'
+        start, end = self._convert_line_range_to_offset(code, 2, 3)
+        refactored = self.do_extract_method(code, start, end, 'g')
+        expected = 'def f():\n    return g()\n\n' \
+                   'def g():\n    var = 1\n    return var\n'
+        self.assertEquals(expected, refactored)
+
 
 if __name__ == '__main__':
     unittest.main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.