| commit 1546: | d353b3842661 |
| parent 1545: | 8af0a3d05bc1 |
| branch: | trunk |
extract: handle conditional variable writes in extracted region
Reported by Issac Trotts <issac.trotts@gmail.com>
13 months ago
Changed (Δ2.0 KB):
raw changeset »
rope/refactor/extract.py (32 lines added, 5 lines removed)
ropetest/refactor/extracttest.py (31 lines added, 0 lines removed)
Up to file-list rope/refactor/extract.py:
| … | … | @@ -538,8 +538,11 @@ class _ExtractMethodParts(object): |
538 |
538 |
if self.info.global_ and not self.info.make_global: |
539 |
539 |
return () |
540 |
540 |
if not self.info.one_line: |
541 |
return list(self.info_collector.prewritten. |
|
542 |
intersection(self.info_collector.read)) |
|
541 |
result = (self.info_collector.prewritten & |
|
542 |
self.info_collector.read) |
|
543 |
result |= (self.info_collector.maybe_written & |
|
544 |
self.info_collector.postread) |
|
545 |
return list(result) |
|
543 |
546 |
start = self.info.region[0] |
544 |
547 |
if start == self.info.lines_region[0]: |
545 |
548 |
start = start + re.search('\S', self.info.extracted).start() |
| … | … | @@ -551,8 +554,9 @@ class _ExtractMethodParts(object): |
551 |
554 |
def _find_function_returns(self): |
552 |
555 |
if self.info.one_line or self.info.returned: |
553 |
556 |
return [] |
554 |
return list(self.info_collector.written. |
|
555 |
intersection(self.info_collector.postread)) |
|
557 |
written = self.info_collector.written | \ |
|
558 |
self.info_collector.maybe_written |
|
559 |
return list(written & self.info_collector.postread) |
|
556 |
560 |
|
557 |
561 |
def _get_unindented_function_body(self, returns): |
558 |
562 |
if self.info.one_line: |
| … | … | @@ -591,11 +595,13 @@ class _FunctionInformationCollector(obje |
591 |
595 |
self.end = end |
592 |
596 |
self.is_global = is_global |
593 |
597 |
self.prewritten = set() |
598 |
self.maybe_written = set() |
|
594 |
599 |
self.written = set() |
595 |
600 |
self.read = set() |
596 |
601 |
self.postread = set() |
597 |
602 |
self.postwritten = set() |
598 |
603 |
self.host_function = True |
604 |
self.conditional = False |
|
599 |
605 |
|
600 |
606 |
def _read_variable(self, name, lineno): |
601 |
607 |
if self.start <= lineno <= self.end: |
| … | … | @@ -607,7 +613,10 @@ class _FunctionInformationCollector(obje |
607 |
613 |
|
608 |
614 |
def _written_variable(self, name, lineno): |
609 |
615 |
if self.start <= lineno <= self.end: |
610 |
|
|
616 |
if self.conditional: |
|
617 |
self.maybe_written.add(name) |
|
618 |
else: |
|
619 |
self.written.add(name) |
|
611 |
620 |
if self.start > lineno: |
612 |
621 |
self.prewritten.add(name) |
613 |
622 |
if self.end < lineno: |
| … | … | @@ -642,6 +651,24 @@ class _FunctionInformationCollector(obje |
642 |
651 |
def _ClassDef(self, node): |
643 |
652 |
self._written_variable(node.name, node.lineno) |
644 |
653 |
|
654 |
def _handle_conditional_node(self, node): |
|
655 |
self.conditional = True |
|
656 |
try: |
|
657 |
for child in ast.get_child_nodes(node): |
|
658 |
ast.walk(child, self) |
|
659 |
finally: |
|
660 |
self.conditional = False |
|
661 |
||
662 |
def _If(self, node): |
|
663 |
self._handle_conditional_node(node) |
|
664 |
||
665 |
def _While(self, node): |
|
666 |
self._handle_conditional_node(node) |
|
667 |
||
668 |
def _For(self, node): |
|
669 |
self._handle_conditional_node(node) |
|
670 |
||
671 |
||
645 |
672 |
|
646 |
673 |
def _get_argnames(arguments): |
647 |
674 |
result = [node.id for node in arguments.args |
Up to file-list ropetest/refactor/extracttest.py:
| … | … | @@ -826,6 +826,37 @@ class ExtractMethodTest(unittest.TestCas |
826 |
826 |
expected = '\ndef f():\n return "1" "2"\n\ns = (f())\n' |
827 |
827 |
self.assertEquals(expected, refactored) |
828 |
828 |
|
829 |
def test_passing_conditional_updated_vars_in_extracted(self): |
|
830 |
code = 'def f(a):\n' \ |
|
831 |
' if 0:\n' \ |
|
832 |
' a = 1\n' \ |
|
833 |
' print(a)\n' |
|
834 |
start, end = self._convert_line_range_to_offset(code, 2, 4) |
|
835 |
refactored = self.do_extract_method(code, start, end, 'g') |
|
836 |
expected = 'def f(a):\n' \ |
|
837 |
' g(a)\n\n' \ |
|
838 |
'def g(a):\n' \ |
|
839 |
' if 0:\n' \ |
|
840 |
' a = 1\n' \ |
|
841 |
' print(a)\n' |
|
842 |
self.assertEquals(expected, refactored) |
|
843 |
||
844 |
def test_returning_conditional_updated_vars_in_extracted(self): |
|
845 |
code = 'def f(a):\n' \ |
|
846 |
' if 0:\n' \ |
|
847 |
' a = 1\n' \ |
|
848 |
' print(a)\n' |
|
849 |
start, end = self._convert_line_range_to_offset(code, 2, 3) |
|
850 |
refactored = self.do_extract_method(code, start, end, 'g') |
|
851 |
expected = 'def f(a):\n' \ |
|
852 |
' a = g(a)\n' \ |
|
853 |
' print(a)\n\n' \ |
|
854 |
'def g(a):\n' \ |
|
855 |
' if 0:\n' \ |
|
856 |
' a = 1\n' \ |
|
857 |
' return a\n' |
|
858 |
self.assertEquals(expected, refactored) |
|
859 |
||
829 |
860 |
|
830 |
861 |
if __name__ == '__main__': |
831 |
862 |
unittest.main() |
