Commits

Anonymous committed 36635a6

Extract method and multi-line header host functions

Comments (0)

Files changed (11)

 To Be Discussed
 ===============
 
-* Having a persistant form for PyNames and PyObjects for saving
+* Having a persistant format for PyNames and PyObjects for saving
   hard to compute information like class hierarchies.
 
 
 > Public Release 0.3m2 : September 17, 2006
 
 
-* Rename method/attribute
-
-
-* Dynamic type inference using Cartesian algorithm
-
-
 * Subversion support using pysvn
 
 
-* Renaming attributes in a hierarchy
+* Renaming methods in a hierarchy
 
 
 * Transform module to package refactoring
 Renaming Attributes In Class Hierarchies
 ========================================
 
-* Think about special cases for renaming
 * Refactor `rope.refactor.rename`
 * Refactor `rope.refactor.extract`
+* Think about special cases for renaming
 
+* Extract methos in a function with multi-line header
 * Adding `ropetest.classestest` to `runtests`
 * Saving active editor before "transform module to package" refactoring
 * PyCore.get_all_subclasses

rope/refactor/extract.py

 
 class _ExtractMethodPerformer(object):
     
-    def __init__(self, refactoring, resource, start_offset,
-                 end_offset, extracted_name):
+    def __init__(self, refactoring, resource, start_offset, end_offset, extracted_name):
         self.refactoring = refactoring
         source_code = resource.read()
         self.source_code = source_code
         self.holding_scope = self.scope.get_inner_scope_for_line(start_line)
         if self.holding_scope.pyobject.get_type() != \
            rope.pyobjects.PyObject.get_base_type('Module') and \
-           self.holding_scope.get_start()  == start_line:
+           self.holding_scope.get_start() == start_line:
             self.holding_scope = self.holding_scope.parent
         self.scope_start = self.lines.get_line_start(self.holding_scope.get_start())
         self.scope_end = self.lines.get_line_end(self.holding_scope.get_end()) + 1
 
-        self.is_method = self.holding_scope.parent is not None and \
-                         self.holding_scope.parent.pyobject.get_type() == \
-                         rope.pyobjects.PyObject.get_base_type('Type')
-        self.is_global = self.holding_scope.pyobject.get_type() == \
-                         rope.pyobjects.PyObject.get_base_type('Module')
         self.scope_indents = self._get_indents(self.holding_scope.get_start()) + 4
-        if self.is_global:
+        if self._is_global():
             self.scope_indents = 0
         self._check_exceptional_conditions()
+        self.info_collector = self._create_info_collector()
+
+    def _is_global(self):
+        return self.holding_scope.pyobject.get_type() == \
+               rope.pyobjects.PyObject.get_base_type('Module')
+
+    def _is_method(self):
+        return self.holding_scope.parent is not None and \
+               self.holding_scope.parent.pyobject.get_type() == \
+               rope.pyobjects.PyObject.get_base_type('Type')
     
     def _check_exceptional_conditions(self):
         if self.holding_scope.pyobject.get_type() == rope.pyobjects.PyObject.get_base_type('Type'):
             raise RefactoringException('Bad range selected for extract method')
         if _ReturnFinder.does_it_return(self.source_code[self.start_offset:self.end_offset]):
             raise RefactoringException('Extracted piece should not contain return statements')
-        
+
+    def _create_info_collector(self):
+        zero = self.holding_scope.get_start() - 1
+        start_line = self.lines.get_line_number(self.start_offset) - zero
+        end_line = self.lines.get_line_number(self.end_offset) - 1 - zero
+        info_collector = _FunctionInformationCollector(start_line, end_line,
+                                                       self._is_global())
+        indented_body = self.source_code[self.scope_start:self.scope_end]
+        body = _indent_lines(indented_body, -_find_minimum_indents(indented_body))
+        ast = compiler.parse(body)
+        compiler.walk(ast, info_collector)
+        return info_collector
+
     def extract(self):
         args = self._find_function_arguments()
         returns = self._find_function_returns()
         
         result = []
         result.append(self.source_code[:self.start_offset])
-        if self.is_global:
+        if self._is_global():
             result.append('\n%s\n' % self._get_function_definition())
         call_prefix = ''
         if returns:
         result.append(' ' * self.first_line_indents + call_prefix
                       + self._get_function_call(args) + '\n')
         result.append(self.source_code[self.end_offset:self.scope_end])
-        if not self.is_global:
+        if not self._is_global():
             result.append('\n%s' % self._get_function_definition())
         result.append(self.source_code[self.scope_end:])
         return ''.join(result)
     def _get_function_definition(self):
         args = self._find_function_arguments()
         returns = self._find_function_returns()
-        if not self.is_global:
+        if not self._is_global():
             function_indents = self.scope_indents
         else:
             function_indents = 4
     
     def _get_function_signature(self, args):
         args = list(args)
-        if self.is_method:
+        if self._is_method():
             if 'self' in args:
                 args.remove('self')
             args.insert(0, 'self')
     
     def _get_function_call(self, args):
         prefix = ''
-        if self.is_method:
+        if self._is_method():
             if  'self' in args:
                 args.remove('self')
             prefix = 'self.'
             result += names[0]
             for name in names[1:]:
                 result += ', ' + name
-        return result        
+        return result
     
     def _find_function_arguments(self):
-        start1 = self.lines.get_line_start(self.holding_scope.get_start() + 1)
-        code1 = self.source_code[start1:self.start_offset] + \
-                '%spass' % (' ' * self.first_line_indents)
-        read1, written1 = _VariableReadsAndWritesFinder.find_reads_and_writes(code1)
-        if self.holding_scope.pyobject.get_type() == rope.pyobjects.PyObject.get_base_type('Function'):
-            written1.update(self._get_function_arg_names())
-        
-        code2 = self.source_code[self.start_offset:self.end_offset]
-        read2, written2 = _VariableReadsAndWritesFinder.find_reads_and_writes(code2)
-        return list(written1.intersection(read2))
-    
-    def _get_function_arg_names(self):
-        indents = self._get_indents(self.holding_scope.get_start())
-        function_header_end = min(self.source_code.index('):\n', self.scope_start) + 1,
-                                  self.scope_end)
-        function_header = _indent_lines(self.source_code[self.scope_start:
-                                                              function_header_end], -indents) + \
-                                                              ':\n' + ' ' * 4 + 'pass'
-        ast = compiler.parse(function_header)
-        visitor = _FunctionArgnamesCollector()
-        compiler.walk(ast, visitor)
-        return visitor.argnames
-        
+        return list(self.info_collector.prewritten.intersection(self.info_collector.read))
     
     def _find_function_returns(self):
-        code2 = self.source_code[self.start_offset:self.end_offset]
-        read2, written2 = _VariableReadsAndWritesFinder.find_reads_and_writes(code2)
-        code3 = self.source_code[self.end_offset:self.scope_end]
-        read3, written3 = _VariableReadsAndWritesFinder.find_reads_and_writes(code3)
-        return list(written2.intersection(read3))
+        return list(self.info_collector.written.intersection(self.info_collector.postread))
         
     def _choose_closest_line_end(self, source_code, offset):
         lineno = self.lines.get_line_number(offset)
                 break
         return indents
     
+
 def _find_minimum_indents(source_code):
     result = 80
     lines = source_code.split('\n')
     return '\n'.join(result)
     
 
+class _FunctionInformationCollector(object):
+    
+    def __init__(self, start, end, is_global):
+        self.start = start
+        self.end = end
+        self.is_global = is_global
+        self.prewritten = set()
+        self.written = set()
+        self.read = set()
+        self.postread = set()
+        self.host_function = True
+    
+    def _read_variable(self, name, lineno):
+        if self.start <= lineno <= self.end:
+            self.read.add(name)
+        if self.end < lineno:
+            self.postread.add(name)
+    
+    def _written_variable(self, name, lineno):
+        if self.start <= lineno <= self.end:
+            self.written.add(name)
+        if self.start > lineno:
+            self.prewritten.add(name)
+        
+    def visitFunction(self, node):
+        if not self.is_global and self.host_function:
+            self.host_function = False
+            for name in node.argnames:
+                self._written_variable(name, node.lineno)
+            compiler.walk(node.code, self)
+        else:
+            self._written_variable(node.name, node.lineno)
+            visitor = _VariableReadsAndWritesFinder()
+            compiler.walk(node.code, visitor)
+            for name in visitor.read - visitor.written:
+                self._read_variable(name, node.lineno)
+
+    def visitAssName(self, node):
+        self._written_variable(node.name, node.lineno)
+    
+    def visitName(self, node):
+        self._read_variable(node.name, node.lineno)
+    
+    def visitClass(self, node):
+        self._written_variable(node.name, node.lineno)
+    
+
 class _VariableReadsAndWritesFinder(object):
     
     def __init__(self):
     def visitReturn(self, node):
         self.returns = True
 
+    def visitYield(self, node):
+        self.returns = True
+
     def visitFunction(self, node):
         pass
     
         compiler.walk(ast, visitor)
         return visitor.returns
 
-
-class _FunctionArgnamesCollector(object):
-    
-    def __init__(self):
-        self.argnames = []
-    
-    def visitFunction(self, node):
-        self.argnames = node.argnames

rope/ui/codeassist.py

 core = rope.ui.core.Core.get_core()
 actions = []
 
-actions.append(SimpleAction('Correct Line Indentation', do_correct_line_indentation,
-                            'C-i',
-                            MenuAddress(['Code', 'Correct Line Indentation'], 'i')))
-actions.append(SimpleAction('Quick Outline', do_quick_outline, 'C-o',
-                            MenuAddress(['Code', 'Quick Outline'], 'q')))
 actions.append(SimpleAction('Code Assist', do_code_assist, 'M-slash',
                             MenuAddress(['Code', 'Code Assist'], 'c')))
 actions.append(SimpleAction('Goto Definition', do_goto_definition, 'F3',
                             MenuAddress(['Code', 'Goto Definition'], 'g')))
 actions.append(SimpleAction('Show Doc', do_show_doc, 'F2',
                             MenuAddress(['Code', 'Show Doc'], 's')))
+
+actions.append(SimpleAction('Correct Line Indentation',
+                            do_correct_line_indentation, 'C-i',
+                            MenuAddress(['Code', 'Correct Line Indentation'], 'i', 1)))
+actions.append(SimpleAction('Quick Outline', do_quick_outline, 'C-o',
+                            MenuAddress(['Code', 'Quick Outline'], 'q', 2)))
 actions.append(SimpleAction('Run Module', do_run_module, 'C-F11',
-                            MenuAddress(['Code', 'Run Module'], 'm')))
+                            MenuAddress(['Code', 'Run Module'], 'm', 2)))
 
 for action in actions:
     core.register_action(action)
 actions = []
 actions.append(SimpleAction('Rename Refactoring', ConfirmAllEditorsAreSaved(rename), 'M-R',
                             MenuAddress(['Refactor', 'Rename'], 'r')))
+actions.append(SimpleAction('Extract Method', ConfirmAllEditorsAreSaved(extract_method), 'M-M',
+                            MenuAddress(['Refactor', 'Extract Method'], 'e')))
+actions.append(SimpleAction('Rename Local Variable', ConfirmAllEditorsAreSaved(local_rename), None,
+                            MenuAddress(['Refactor', 'Rename Local Variable'], 'l')))
 actions.append(SimpleAction('Transform Module To Package', 
                             ConfirmAllEditorsAreSaved(transform_module_to_package), None,
-                            MenuAddress(['Refactor', 'Transform Module To Package'], 't')))
+                            MenuAddress(['Refactor', 'Transform Module To Package'], 't', 1)))
 actions.append(SimpleAction('Undo Last Refactoring', 
                             ConfirmAllEditorsAreSaved(undo_last_refactoring), None,
-                            MenuAddress(['Refactor', 'Undo Last Refactoring'], 'u')))
-
-actions.append(SimpleAction('Rename Local Variable', ConfirmAllEditorsAreSaved(local_rename), None,
-                            MenuAddress(['Refactor', 'Rename Local Variable'], 'e', 1)))
-actions.append(SimpleAction('Extract Method', ConfirmAllEditorsAreSaved(extract_method), 'M-M',
-                            MenuAddress(['Refactor', 'Extract Method'], 'e', 1)))
+                            MenuAddress(['Refactor', 'Undo Last Refactoring'], 'u', 2)))
 
 core = rope.ui.core.Core.get_core()
 for action in actions:

ropetest/codeassisttest.py

         self.assertEquals(['name_var'], template.variables())
         self.assertEquals('Name = Ali', template.substitute({'name_var': 'Ali'}))
 
+    @testutils.assert_raises(KeyError)
     def test_unmapped_variable(self):
         template = Template('Name = ${name}')
-        try:
-            template.substitute({})
-            self.fail('Expected keyError')
-        except KeyError:
-            pass
+        template.substitute({})
 
     def test_double_dollar_sign(self):
         template = Template('Name = $${name}')

ropetest/projecttest.py

         projectFile = self.project.get_resource(self.projectMaker.get_sample_file_name())
         self.assertEquals(self.projectMaker.get_sample_file_contents(), projectFile.read())
     
+    @testutils.assert_raises(RopeException)
     def test_getting_not_existing_project_file(self):
-        try:
-            projectFile = self.project.get_resource('DoesNotExistFile.txt')
-            self.fail('Should have failed')
-        except RopeException:
-            pass
+        projectFile = self.project.get_resource('DoesNotExistFile.txt')
+        self.fail('Should have failed')
 
     def test_writing_in_project_files(self):
         projectFile = self.project.get_resource(self.projectMaker.get_sample_file_name())
         newFile = self.project.get_resource(projectFile)
         self.assertTrue(newFile is not None)
 
+    @testutils.assert_raises(RopeException)
     def test_creating_files_that_already_exist(self):
-        try:
-            self.project.get_root_folder().create_file(self.projectMaker.get_sample_file_name())
-            self.fail('Should have failed')
-        except RopeException:
-            pass
+        self.project.get_root_folder().create_file(self.projectMaker.get_sample_file_name())
+        self.fail('Should have failed')
 
     def test_making_root_folder_if_it_does_not_exist(self):
         projectRoot = 'SampleProject2'
         finally:
             testutils.remove_recursively(projectRoot)
 
+    @testutils.assert_raises(RopeException)
     def test_failure_when_project_root_exists_and_is_a_file(self):
-        projectRoot = 'SampleProject2'
-        open(projectRoot, 'w').close()
         try:
+            projectRoot = 'SampleProject2'
+            open(projectRoot, 'w').close()
             project = Project(projectRoot)
-            self.fail('Should have failed')
-        except RopeException:
+        finally:
             os.remove(projectRoot)
 
     def test_creating_folders(self):
         folderPath = os.path.join(self.project.get_root_address(), folderName)
         self.assertTrue(os.path.exists(folderPath) and os.path.isdir(folderPath))
 
+    @testutils.assert_raises(RopeException)
     def test_making_folder_that_already_exists(self):
         folderName = 'SampleFolder'
         self.project.get_root_folder().create_folder(folderName)
-        try:
-            self.project.get_root_folder().create_folder(folderName)
-            self.fail('Should have failed')
-        except RopeException:
-            pass
+        self.project.get_root_folder().create_folder(folderName)
 
+    @testutils.assert_raises(RopeException)
     def test_failing_if_creating_folder_while_file_already_exists(self):
         folderName = 'SampleFolder'
         self.project.get_root_folder().create_file(folderName)
-        try:
-            self.project.get_root_folder().create_folder(folderName)
-            self.fail('Should have failed')
-        except RopeException:
-            pass
+        self.project.get_root_folder().create_folder(folderName)
 
     def test_creating_file_inside_folder(self):
         folder_name = 'sampleFolder'
         file = self.project.get_resource(file_path)
         file.write('sample notes')
         self.assertEquals(file_path, file.get_path())
-        self.assertEquals('sample notes',
-                          open(os.path.join(self.project.get_root_address(),
-                                            file_path))
-                          .read())
+        self.assertEquals('sample notes', open(os.path.join(self.project.get_root_address(),
+                                                            file_path)).read())
 
+    @testutils.assert_raises(RopeException)
     def test_failing_when_creating_file_inside_non_existant_folder(self):
-        try:
-            self.project.get_root_folder().create_file('NonexistantFolder/SomeFile.txt')
-            self.fail('Should have failed')
-        except RopeException:
-            pass
+        self.project.get_root_folder().create_file('NonexistantFolder/SomeFile.txt')
 
     def test_nested_directories(self):
         folder_name = 'SampleFolder'
         self.assertFalse(os.path.exists(os.path.join(self.project.get_root_address(),
                                                      self.projectMaker.get_sample_folder_name())))
 
+    @testutils.assert_raises(RopeException)
     def test_removing_non_existant_files(self):
-        try:
-            self.project.get_resource('NonExistantFile.txt').remove()
-            self.fail('Should have failed')
-        except RopeException:
-            pass
+        self.project.get_resource('NonExistantFile.txt').remove()
 
     def test_removing_nested_files(self):
         fileName = self.projectMaker.get_sample_folder_name() + '/SampleFile.txt'

ropetest/pycoretest.py

         var = sample_class.get_attribute('InnerClass').get_object()
         self.assertEquals(PyObject.get_base_type('Type'), var.get_type())
 
+    @testutils.assert_raises(ModuleNotFoundException)
     def test_non_existant_module(self):
-        try:
-            self.pycore.get_module('mod')
-            self.fail('And exception should have been raised')
-        except ModuleNotFoundException:
-            pass
+        self.pycore.get_module('mod')
 
     def test_imported_names(self):
         self.pycore.create_module(self.project.get_root_folder(), 'mod1')

ropetest/refactortest.py

                    "    return inner_func\n"
         self.assertEquals(expected, refactored)
 
+    @testutils.assert_raises(RefactoringException)
     def test_extract_method_bad_range(self):
         code = "def a_func():\n    pass\na_var = 10\n"
         start, end = self._convert_line_range_to_offset(code, 2, 3)
-        try:
-            self.do_extract_method(code, start, end, 'new_func')
-        except RefactoringException:
-            pass
-        else:
-            self.fail('Should have thrown exception')
+        self.do_extract_method(code, start, end, 'new_func')
 
+    @testutils.assert_raises(RefactoringException)
     def test_extract_method_bad_range2(self):
         code = "class AClass(object):\n    pass\n"
         start, end = self._convert_line_range_to_offset(code, 1, 1)
-        try:
-            self.do_extract_method(code, start, end, 'new_func')
-        except RefactoringException:
-            pass
-        else:
-            self.fail('Should have thrown exception')
+        self.do_extract_method(code, start, end, 'new_func')
 
+    @testutils.assert_raises(RefactoringException)
     def test_extract_method_containing_return(self):
         code = "def a_func(arg):\n    return arg * 2\n"
         start, end = self._convert_line_range_to_offset(code, 2, 2)
-        try:
-            self.do_extract_method(code, start, end, 'new_func')
-        except RefactoringException:
-            pass
-        else:
-            self.fail('Should have thrown exception')
+        self.do_extract_method(code, start, end, 'new_func')
+
+    @testutils.assert_raises(RefactoringException)
+    def test_extract_method_containing_yield(self):
+        code = "def a_func(arg):\n    yield arg * 2\n"
+        start, end = self._convert_line_range_to_offset(code, 2, 2)
+        self.do_extract_method(code, start, end, 'new_func')
 
     def test_extract_function_and_argument_as_paramenter(self):
         code = "def a_func(arg):\n    print arg\n"
                    "def new_func(arg):\n    if True:\n        print arg\n"
         self.assertEquals(expected, refactored)
     
+    def test_extract_method_and_multi_line_headers(self):
+        code = "def a_func(\n           arg):\n    print arg\n"
+        start, end = self._convert_line_range_to_offset(code, 3, 3)
+        refactored = self.do_extract_method(code, start, end, 'new_func')
+        expected = "def a_func(\n           arg):\n    new_func(arg)\n\n" \
+                   "def new_func(arg):\n    print arg\n"
+        self.assertEquals(expected, refactored)
+    
     def test_transform_module_to_package(self):
         mod1 = self.pycore.create_module(self.project.get_root_folder(), 'mod1')
         mod1.write('import mod2\nfrom mod2 import AClass\n')

ropetest/testutils.py

 
 
 def run_only_for_25(func):
+    """Should be used as a decorator for a unittest.TestCase test"""
     if sys.version.startswith('2.5'):
         return func
     else:
         def do_nothing(self):
             pass
         return do_nothing
+
+
+def assert_raises(exception_class):
+    """Should be used as a decorator for a unittest.TestCase test"""
+    def _assert_raises(func):
+        def call_func(self, *args, **kws):
+            self.assertRaises(exception_class, func, self, *args, **kws)
+        return call_func
+    return _assert_raises