Commits

Ali Gholami Rudi  committed f99d91c

extract: not passing global names when extracting global functions

  • Participants
  • Parent commits bc520c6

Comments (0)

Files changed (2)

File rope/refactor/extract.py

         return call_prefix + self._get_function_call(args)
 
     def _find_function_arguments(self):
+        # if not make_global, do not pass any global names; they are
+        # all visible.
+        if self.info.global_ and not self.info.make_global:
+            return ()
         if not self.info.one_line:
             return list(self.info_collector.prewritten.
                         intersection(self.info_collector.read))

File ropetest/refactor/extracttest.py

         code = 'a = 1\na = 1 + a\n'
         start = code.index('\n') + 1
         end = len(code)
-        refactored = self.do_extract_method(code, start, end, 'new_f')
+        refactored = self.do_extract_method(code, start, end, 'new_f',
+                                            global_=True)
         expected = 'a = 1\n\ndef new_f(a):\n    a = 1 + a\n\nnew_f(a)\n'
         self.assertEquals(expected, refactored)
 
         code = 'a = 1\na += 1\n'
         start = code.index('\n') + 1
         end = len(code)
-        refactored = self.do_extract_method(code, start, end, 'new_f')
+        refactored = self.do_extract_method(code, start, end, 'new_f',
+                                            global_=True)
         expected = 'a = 1\n\ndef new_f():\n    a += 1\n\nnew_f()\n'
         self.assertEquals(expected, refactored)
 
                    '    except Exception:\n        pass\n'
         self.assertEquals(expected, refactored)
 
+    def test_extract_and_not_passing_global_functions(self):
+        code = 'def next(p):\n    return p + 1\nvar = next(1)\n'
+        start = code.rindex('next')
+        refactored = self.do_extract_method(code, start, len(code) - 1, 'two')
+        expected = 'def next(p):\n    return p + 1\n' \
+                   '\ndef two():\n    return next(1)\n\nvar = two()\n'
+        self.assertEquals(expected, refactored)
+
 
 if __name__ == '__main__':
     unittest.main()