Commits

Anonymous committed 23b8fec

Enhancing extract methods in staticmethods and classmethods

  • Participants
  • Parent commits 1077f27

Comments (0)

Files changed (4)

File docs/dev/done.txt

 ===========
 
 
+- Enhancing extract method on staticmethods/classmethods : June 2, 2007
+
+
 - Extracting similar expressions/statements : May 30, 2007
 
 

File docs/dev/workingon.txt

 Small Stories
 =============
 
-* Not relying on ``self`` name in extract
-* Handling `staticmethod` and `classmethod` in extract?
+- Not relying on ``self`` name in extract
+- Handling `staticmethod` in extract method
+- Handling `classmethod` in extract method
+
+* Using get_method_kind wherever possible
 * Using `LogicalLineFinder` in `_HoldingScopeFinder.find_scope_end()`?
-
 * Handling strings in following lines in `patchedast`
 * Split tuple assignment refactoring; ``a, b = 1, 2`` with ``a = 1\nb = 2``
   Or ``a, b = x`` with ``a = x[0]\nb = x[1]``

File rope/refactor/extract.py

 import re
 
 import rope.base.pyobjects
-from rope.base import ast, codeanalyze
+from rope.base import ast, builtins, codeanalyze, evaluate
 from rope.base.change import ChangeSet, ChangeContents
 from rope.base.exceptions import RefactoringError
 from rope.refactor import sourceutils, similarfinder, patchedast, suites
         args = self._find_function_arguments()
         returns = self._find_function_returns()
         result = []
+        if self.info.method and self._get_method_kind() != 'normal':
+            result.append('@staticmethod\n')
         result.append('def %s:\n' % self._get_function_signature(args))
         unindented_body = self._get_unindented_function_body(returns)
         indents = sourceutils.get_indent(self.info.pycore)
 
     def _get_function_signature(self, args):
         args = list(args)
-        if self.info.method:
+        prefix = ''
+        if self.info.method and self._get_method_kind() == 'normal':
             self_name = self._get_self_name()
             if self_name in args:
                 args.remove(self_name)
             args.insert(0, self_name)
-        return self.info.new_name + '(%s)' % self._get_comma_form(args)
+        return prefix + self.info.new_name + \
+               '(%s)' % self._get_comma_form(args)
+
+    def _get_method_kind(self):
+        """Get the type of a method
+
+        It returns 'normal', 'static', or 'class'
+
+        """
+        ast = self.info.scope.pyobject.get_ast()
+        for decorator in ast.decorators:
+            pyname = evaluate.get_statement_result(self.info.scope.parent,
+                                                   decorator)
+            if pyname == builtins.builtins['staticmethod']:
+                return 'static'
+            if pyname == builtins.builtins['classmethod']:
+                return 'class'
+        return 'normal'
 
     def _get_self_name(self):
         param_names = self.info.scope.pyobject.get_param_names()
         if param_names:
             return param_names[0]
-        else:
-            raise RefactoringError(
-                'Extracting from a non-method in class body is not supported yet')
 
     def _get_function_call(self, args):
         prefix = ''
         if self.info.method:
-            self_name = self._get_self_name()
-            if  self_name in args:
-                args.remove(self_name)
-            prefix = self_name + '.'
+            if self._get_method_kind() == 'normal':
+                self_name = self._get_self_name()
+                if  self_name in args:
+                    args.remove(self_name)
+                prefix = self_name + '.'
+            else:
+                prefix = self.info.scope.parent.pyobject.get_name() + '.'
         return prefix + '%s(%s)' % (self.info.new_name, self._get_comma_form(args))
 
     def _get_comma_form(self, names):

File ropetest/refactor/extracttest.py

                    '    def func2(self):\n        b = 1\n'
         self.assertEquals(expected, refactored)
 
+    def test_extract_method_in_staticmethods(self):
+        code = 'class AClass(object):\n\n' \
+               '    @staticmethod\n    def func2():\n        b = 1\n'
+        start = code.index(' 1') + 1
+        refactored = self.do_extract_method(code, start, start + 1,
+                                            'one', similar=True)
+        expected = 'class AClass(object):\n\n' \
+                   '    @staticmethod\n    def func2():\n        b = AClass.one()\n\n' \
+                   '    @staticmethod\n    def one():\n        return 1\n'
+        self.assertEquals(expected, refactored)
+
 
 if __name__ == '__main__':
     unittest.main()