Commits

Ali Gholami Rudi  committed 8bb9d18

usefunction: merged return count parts with extract

  • Participants
  • Parent commits 90d67b7

Comments (0)

Files changed (2)

File rope/refactor/extract.py

     def returned(self):
         """Does the extracted piece contain return statement"""
         if self._returned is None:
-            self._returned = _returns_last(self.extracted)
+            node = _parse_text(self.extracted)
+            self._returned = usefunction._returns_last(node)
         return self._returned
 
 
                                    'span multiple lines.')
 
     def multi_line_conditions(self, info):
-        code = info.source[info.region[0]:info.region[1]]
-        count = _return_count(code)
+        node = _parse_text(info.source[info.region[0]:info.region[1]])
+        count = usefunction._return_count(node)
         if count > 1:
             raise RefactoringError('Extracted piece can have only one '
                                    'return statement.')
-        if count == 1 and not _returns_last(code):
+        if usefunction._yield_count(node):
+            raise RefactoringError('Extracted piece cannot '
+                                   'have yield statements.')
+        if count == 1 and not usefunction._returns_last(node):
             raise RefactoringError('Return should be the last statement.')
         if info.region != info.lines_region:
             raise RefactoringError('Extracted piece should '
         end_line = self.info.region_lines[1] - zero
         info_collector = _FunctionInformationCollector(start_line, end_line,
                                                        self.info.global_)
-        indented_body = self.info.source[self.info.scope_region[0]:
-                                         self.info.scope_region[1]]
-        body = sourceutils.fix_indentation(indented_body, 0)
+        body = self.info.source[self.info.scope_region[0]:
+                                self.info.scope_region[1]]
         node = _parse_text(body)
         ast.walk(node, info_collector)
         return info_collector
     def find_reads_and_writes(code):
         if code.strip() == '':
             return set(), set()
-        indented_code = sourceutils.fix_indentation(code, 0)
-        if isinstance(indented_body, unicode):
-            indented_body = indented_body.encode('utf-8')
-        node = _parse_text(indented_code)
+        if isinstance(code, unicode):
+            code = code.encode('utf-8')
+        node = _parse_text(code)
         visitor = _VariableReadsAndWritesFinder()
         ast.walk(node, visitor)
         return visitor.read, visitor.written
         return visitor.read
 
 
-class _ReturnOrYieldFinder(object):
-
-    def __init__(self):
-        self.returns = 0
-
-    def _Return(self, node):
-        self.returns += 1
-
-    def _Yield(self, node):
-        self.returns += 1
-
-    def _FunctionDef(self, node):
-        pass
-
-    def _ClassDef(self, node):
-        pass
-
-def _return_count(code):
-    if code.strip() == '':
-        return False
-    indented_code = sourceutils.fix_indentation(code, 0)
-    node = _parse_text(indented_code)
-    visitor = _ReturnOrYieldFinder()
-    ast.walk(node, visitor)
-    return visitor.returns
-
-def _returns_last(code):
-    if code.strip() == '':
-        return False
-    indented_code = sourceutils.fix_indentation(code, 0)
-    node = _parse_text(indented_code)
-    return node.body and isinstance(node.body[-1], ast.Return)
-
-
 class _UnmatchedBreakOrContinueFinder(object):
 
     def __init__(self):
     def has_errors(code):
         if code.strip() == '':
             return False
-        indented_code = sourceutils.fix_indentation(code, 0)
-        node = _parse_text(indented_code)
+        node = _parse_text(code)
         visitor = _UnmatchedBreakOrContinueFinder()
         ast.walk(node, visitor)
         return visitor.error
 
 
 def _parse_text(body):
+    body = sourceutils.fix_indentation(body, 0)
     node = ast.parse(body)
     return node
 

File rope/refactor/usefunction.py

         self._check_returns()
 
     def _check_returns(self):
-        class CountReturns(object):
-            returns = 0
-            yields = 0
-            def __call__(self, node):
-                if isinstance(node, ast.Return):
-                    self.returns += 1
-                if isinstance(node, ast.Yield):
-                    self.yields += 1
-        counter = CountReturns()
         node = self.pyfunction.get_ast()
-        ast.call_for_nodes(node, counter, recursive=True)
-        if counter.yields:
+        if _yield_count(node):
             raise exceptions.RefactoringError('Use function should not '
                                               'be used on generators.')
-        if counter.returns > 1:
-            raise exceptions.RefactoringError(
-                'usefunction: Function has more than '
-                'one return statement.')
-        if counter.returns == 1 and not isinstance(node.body[-1], ast.Return):
-            raise exceptions.RefactoringError(
-                'usefunction: return should be the last statement.')
+        returns = _return_count(node)
+        if returns > 1:
+            raise exceptions.RefactoringError('usefunction: Function has more '
+                                              'than one return statement.')
+        if returns == 1 and not _returns_last(node):
+            raise exceptions.RefactoringError('usefunction: return should '
+                                              'be the last statement.')
 
     def get_changes(self, resources=None,
                     task_handle=taskhandle.NullTaskHandle()):
         if isinstance(pyname, pynames.AssignedName):
             result.append(name)
     return result
+
+
+def _returns_last(node):
+    return node.body and isinstance(node.body[-1], ast.Return)
+
+def _yield_count(node):
+    visitor = _ReturnOrYieldFinder()
+    visitor.start_walking(node)
+    return visitor.yields
+
+def _return_count(node):
+    visitor = _ReturnOrYieldFinder()
+    visitor.start_walking(node)
+    return visitor.returns
+
+class _ReturnOrYieldFinder(object):
+
+    def __init__(self):
+        self.returns = 0
+        self.yields = 0
+
+    def _Return(self, node):
+        self.returns += 1
+
+    def _Yield(self, node):
+        self.yields += 1
+
+    def _FunctionDef(self, node):
+        pass
+
+    def _ClassDef(self, node):
+        pass
+
+    def start_walking(self, node):
+        nodes = [node]
+        if isinstance(node, ast.FunctionDef):
+            nodes = ast.get_child_nodes(node)
+        for child in nodes:
+            ast.walk(child, self)