Commits

Ali Gholami Rudi  committed 4b6524b

Not changing static and class methods when extracting normal methods

  • Participants
  • Parent commits 0049e73

Comments (0)

Files changed (9)

File docs/dev/stories.txt

 * Extract class
 
 
+* Split tuple assignment refactoring
+
+
 * Supporting templates in text modes
 
 
 * Inferring the object, list comprehensions or generator expressions hold
 
 
+* Renaming and moving normal files/folders
+
+
 > Public Release 0.6m2 : June 3, 2007
-
-
-* Renaming and moving normal files/folders

File docs/dev/workingon.txt

 Small Stories
 =============
 
-- Adding `PyScope.get_body_start()`
+- Not changing static and class methods when extracting normal methods
 
-* Using get_method_kind wherever possible
-* Using `LogicalLineFinder` in `_HoldingScopeFinder.find_scope_end()`?
 * Handling strings in following lines in `patchedast`
-* Split tuple assignment refactoring; ``a, b = 1, 2`` with ``a = 1\nb = 2``
-  Or ``a, b = x`` with ``a = x[0]\nb = x[1]``
 * Extracting subexpressions; look at `extracttest` for more info
 * Create ... and implicit interfaces
 * Add custom refactoring section to ``overview.txt`` file;

File rope/base/oi/transform.py

 
     def transform(self, textual):
         """Transform an object from textual form to `PyObject`"""
+        if textual is None:
+            return None
         type = textual[0]
         try:
             method = getattr(self, type + '_to_pyobject')

File rope/base/pyscopes.py

         if not scope.parent:
             return self.lines.length()
         end = scope.pyobject.get_ast().body[-1].lineno
+        # IDEA: Can we use LogicalLineFinder here?
         body_indents = self._get_body_indents(scope)
         for l in range(end + 1, self.lines.length() + 1):
             if not self._is_empty_line(l):

File rope/refactor/extract.py

         return extract_collector
 
     def _find_matches(self, collector):
+        regions = self._where_to_search()
+        finder = similarfinder.CheckingFinder(self.info.pymodule, {})
+        matches = []
+        for start, end in regions:
+            matches.extend((finder.get_matches(
+                            collector.body_pattern, start, end)))
+        collector.matches = matches
+
+    def _where_to_search(self):
         if self.info.similar:
             if self.info.method and not self.info.variable:
                 class_scope = self.info.scope.parent
-                start = self.info.lines.get_line_start(class_scope.get_start())
-                end = self.info.lines.get_line_end(class_scope.get_end())
+                regions = []
+                method_kind = _get_method_kind(self.info.scope)
+                for scope in class_scope.get_scopes():
+                    if method_kind == 'normal' and \
+                       _get_method_kind(scope) != 'normal':
+                        continue
+                    start = self.info.lines.get_line_start(scope.get_start())
+                    end = self.info.lines.get_line_end(scope.get_end())
+                    regions.append((start, end))
+                return regions
             else:
-                start, end = self.info.scope_region
+                return [self.info.scope_region]
         else:
-            start, end = self.info.region
-        finder = similarfinder.CheckingFinder(self.info.pymodule,
-                                              {}, start, end)
-        collector.matches = list(finder.get_matches(collector.body_pattern))
+            return [self.info.region]
 
     def _find_definition_location(self, collector):
         matched_lines = []
         args = self._find_function_arguments()
         returns = self._find_function_returns()
         result = []
-        if self.info.method and self._get_method_kind() != 'normal':
+        if self.info.method and _get_method_kind(self.info.scope) != 'normal':
             result.append('@staticmethod\n')
         result.append('def %s:\n' % self._get_function_signature(args))
         unindented_body = self._get_unindented_function_body(returns)
     def _get_function_signature(self, args):
         args = list(args)
         prefix = ''
-        if self.info.method and self._get_method_kind() == 'normal':
+        if self.info.method and _get_method_kind(self.info.scope) == 'normal':
             self_name = self._get_self_name()
             if self_name in args:
                 args.remove(self_name)
         return prefix + self.info.new_name + \
                '(%s)' % self._get_comma_form(args)
 
-    def _get_method_kind(self):
-        """Get the type of a method
-
-        It returns 'normal', 'static', or 'class'
-
-        """
-        ast = self.info.scope.pyobject.get_ast()
-        for decorator in ast.decorators:
-            pyname = evaluate.get_statement_result(self.info.scope.parent,
-                                                   decorator)
-            if pyname == builtins.builtins['staticmethod']:
-                return 'static'
-            if pyname == builtins.builtins['classmethod']:
-                return 'class'
-        return 'normal'
-
     def _get_self_name(self):
         param_names = self.info.scope.pyobject.get_param_names()
         if param_names:
     def _get_function_call(self, args):
         prefix = ''
         if self.info.method:
-            if self._get_method_kind() == 'normal':
+            if _get_method_kind(self.info.scope) == 'normal':
                 self_name = self._get_self_name()
                 if  self_name in args:
                     args.remove(self_name)
         ast.walk(node, visitor)
         return visitor.error
 
+def _get_method_kind(scope):
+    """Get the type of a method
+
+    It returns 'normal', 'static', or 'class'
+
+    """
+    ast = scope.pyobject.get_ast()
+    for decorator in ast.decorators:
+        pyname = evaluate.get_statement_result(scope.parent,
+                                               decorator)
+        if pyname == builtins.builtins['staticmethod']:
+            return 'static'
+        if pyname == builtins.builtins['classmethod']:
+            return 'class'
+    return 'normal'
 
 def _parse_text(body):
     if isinstance(body, unicode):

File rope/refactor/similarfinder.py

 class SimilarFinder(object):
     """A class for finding similar expressions and statements"""
 
-    def __init__(self, source, start=0, end=None):
+    def __init__(self, source):
         node = ast.parse(source)
-        self._init_using_ast(node, source, start, end)
+        self._init_using_ast(node, source)
 
-    def _init_using_ast(self, node, source, start, end):
-        self.start = start
-        self.end = len(source)
-        if end is not None:
-            self.end = end
+    def _init_using_ast(self, node, source):
+        self.source = source
+        self._matched_asts = {}
         if not hasattr(node, 'sorted_children'):
             self.ast = patchedast.patch_ast(node, source)
 
-    def get_matches(self, code):
+    def get_matches(self, code, start=0, end=None):
         """Search for `code` in source and return a list of `Match`\es
 
         `code` can contain wildcards.  ``${name}`` matches normal
         You can use `Match.get_ast()` for getting the node that has
         matched a given pattern.
         """
-        wanted = self._create_pattern(code)
-        matches = _ASTMatcher(self.ast, wanted).find_matches()
-        for match in matches:
-            start, end = match.get_region()
-            if self.start <= start and end <= self.end:
+        if end is None:
+            end = len(self.source)
+        for match in self._get_matched_asts(code):
+            match_start, match_end = match.get_region()
+            if start <= match_start and match_end <= end:
                 yield match
 
-    def get_match_regions(self, code):
-        for match in self.get_matches(code):
+    def _get_matched_asts(self, code):
+        if code not in self._matched_asts:
+            wanted = self._create_pattern(code)
+            matches = _ASTMatcher(self.ast, wanted).find_matches()
+            self._matched_asts[code] = matches
+        return self._matched_asts[code]
+
+    def get_match_regions(self, code, start=0, end=None):
+        for match in self.get_matches(code, start=start, end=end):
             yield match.get_region()
 
     def _create_pattern(self, expression):
 
     """
 
-    def __init__(self, pymodule, checks, start=0, end=None):
+    def __init__(self, pymodule, checks):
         super(CheckingFinder, self)._init_using_ast(
-            pymodule.get_ast(), pymodule.source_code, start, end)
+            pymodule.get_ast(), pymodule.source_code)
         self.pymodule = pymodule
         self.checks = checks
 
-    def get_matches(self, code):
-        for match in SimilarFinder.get_matches(self, code):
+    def get_matches(self, code, start=0, end=None):
+        if end is None:
+            end = len(self.source)
+        for match in SimilarFinder.get_matches(self, code,
+                                               start=start, end=end):
             matched = True
             for check, expected in self.checks.items():
                 name, kind = self._split_name(check)

File rope/ui/refactor.py

         self.similar = Tkinter.IntVar()
         self.similar.set(1)
         similar = Tkinter.Checkbutton(
-            frame, text='Match similar expressions/statements',
+            frame, text='Extract similar expressions/statements',
             variable=self.similar)
         similar.grid(row=1, column=0, columnspan=2)
 

File ropetest/refactor/extracttest.py

                    '    @staticmethod\n    def one():\n        return 1\n'
         self.assertEquals(expected, refactored)
 
+    def test_extract_normal_method_with_staticmethods(self):
+        code = 'class AClass(object):\n\n' \
+               '    @staticmethod\n    def func1():\n        b = 1\n' \
+               '    def func2(self):\n        b = 1\n'
+        start = code.rindex(' 1') + 1
+        refactored = self.do_extract_method(code, start, start + 1,
+                                            'one', similar=True)
+        expected = 'class AClass(object):\n\n' \
+                   '    @staticmethod\n    def func1():\n        b = 1\n' \
+                   '    def func2(self):\n        b = self.one()\n\n' \
+                   '    def one(self):\n        return 1\n'
+        self.assertEquals(expected, refactored)
+
 
 if __name__ == '__main__':
     unittest.main()

File ropetest/refactor/similarfindertest.py

 
     def test_restricting_the_region_to_search(self):
         source = '1\n\n1\n'
-        finder = similarfinder.SimilarFinder(source, start=2)
-        result = list(finder.get_match_regions('1'))
+        finder = similarfinder.SimilarFinder(source)
+        result = list(finder.get_match_regions('1', start=2))
         start = source.rfind('1')
         self.assertEquals([(start, start + 1)], result)