Commits

Anonymous committed 0f3399b

Adding support for all nodes in patchedast

Comments (0)

Files changed (3)

docs/dev/workingon.txt

 Patched AST
 ===========
 
-- GenExpr and Tuple should consume surrounding parens if exists
-- Module not loosing blank lines in the begining and end of the file
-- Add `patchedast.write_ast()`
-- Ignoring comments
+- tuple parameter unpacking
+- handle Sliceobj
+- handle long slices
 
-* tuple parameter unpacking
-* handle long slices
-* handle Sliceobj
+* profiling
 * py3k: How to handle both versions?
 
   * metadata=? and classes
   * // and /
   * ...
 
-* profiling
-* report error on ``return``
 * adding custom source folders in ``config.py``
 
 

rope/refactor/patchedast.py

 
     def _handle(self, node, base_children, eat_parens=False, eat_spaces=False):
         children = []
+        formats = []
         suspected_start = self.source.offset
         start = suspected_start
         first_token = True
                 child = self.source.get(region[0], region[1])
                 token_start = region[0]
             if not first_token:
+                formats.append(self.source.get(offset, token_start))
                 children.append(self.source.get(offset, token_start))
             else:
                 first_token = False
                 start = token_start
             children.append(child)
-        start = self._handle_parens(children, start)
+        start = self._handle_parens(children, start, formats)
         if eat_parens:
             start = self._eat_surrounding_parens(
                 children, suspected_start, start)
         node.sorted_children = children
         node.region = (start, self.source.offset)
 
-    def _handle_parens(self, children, start):
+    def _handle_parens(self, children, start, formats):
         """Changes `children` and returns new start"""
-        opens, closes = self._count_needed_parens(children)
+        opens, closes = self._count_needed_parens(formats)
         old_end = self.source.offset
         new_end = None
         for i in range(closes):
             new_end = self.source.consume(')')[1]
         if new_end is not None:
             children.append(self.source.get(old_end, new_end))
-        new_start = None
+        new_start = start
         for i in range(opens):
-            new_start = self.source.find_backwards('(', start)
-        if new_start is not None:
+            new_start = self.source.rfind_token('(', 0, new_start)
+        if new_start != start:
             children.insert(0, self.source.get(new_start, start))
             start = new_start
         return start
 
     def _eat_surrounding_parens(self, children, suspected_start, start):
-        if '(' in self.source[suspected_start:start]:
+        index = self.source.rfind_token('(', suspected_start, start)
+        if index is not None:
             old_start = start
             old_offset = self.source.offset
-            start = self.source.find_backwards('(', start)
+            start = index
             children.insert(0, '(')
             children.insert(1, self.source[start + 1:old_start])
             token_start, token_end = self.source.consume(')')
         self._handle(node, ['break'])
 
     def visitCallFunc(self, node):
-        children = []
-        children.append(node.node)
-        children.append('(')
+        children = [node.node, '(']
         children.extend(self._child_nodes(node.args, ','))
         if node.star_args is not None:
             if node.args:
         self._handle(node, children)
 
     def visitDiscard(self, node):
-        self._handle(node, [node.expr])
+        children = []
+        if not self._is_none_or_const_none(node.expr):
+            children.append(node.expr)
+        self._handle(node, children)
 
     def visitDiv(self, node):
         self._handle(node, [node.left, '/', node.right])
         if flags & compiler.consts.CO_VARARGS:
             star_args = args.pop()
         defaults = [None] * (len(args) - len(defaults)) + list(defaults)
-        for arg, default in zip(args[:-1], defaults[:-1]):
+        for index, (arg, default) in enumerate(zip(args, defaults)):
+            if index > 0:
+                children.append(',')
             self._add_args_to_children(children, arg, default)
-            children.append(',')
-        if args:
-            self._add_args_to_children(children, args[-1], defaults[-1])
         if star_args is not None:
             if args:
                 children.append(',')
         return children
 
     def _add_args_to_children(self, children, arg, default):
-        children.append(arg)
+        if isinstance(arg, (list, tuple)):
+            self._add_tuple_parameter(children, arg)
+        else:
+            children.append(arg)
         if default is not None:
             children.append('=')
             children.append(default)
 
+    def _add_tuple_parameter(self, children, arg):
+        children.append('(')
+        for index, token in enumerate(arg):
+            if index > 0:
+                children.append(',')
+            if isinstance(token, (list, tuple)):
+                self._add_tuple_parameter(children, token)
+            else:
+                children.append(token)
+        children.append(')')
+
     def visitGenExpr(self, node):
         self._handle(node, [node.code], eat_parens=True)
 
             children.extend(['else', ':', node.else_])
         self._handle(node, children)
 
+    def visitIfExp(self, node):
+        return self._handle(node, [node.then, 'if', node.test,
+                                   'else', node.else_])
+
     def visitImport(self, node):
         children = ['import']
         for index, (name, alias) in enumerate(node.names):
     def _base_print(self, node):
         children = ['print']
         if node.dest:
-            children.extend(['>>', node.dest, ','])
+            children.extend(['>>', node.dest])
+            if node.nodes:
+                children.append(',')
         children.extend(self._child_nodes(node.nodes, ','))
         return children
 
 
     def visitReturn(self, node):
         children = ['return']
-        if node.value and not (isinstance(node.value, compiler.ast.Const) and
-                               node.value.value == None):
+        if not self._is_none_or_const_none(node.value):
             children.append(node.value)
         self._handle(node, children)
 
+    def _is_none_or_const_none(self, node):
+        return node is None or (isinstance(node, compiler.ast.Const) and
+                                node.value == None)
+
     def visitRightShift(self, node):
         self._handle(node, [node.left, '>>', node.right])
 
         if node.lower:
             children.append(node.lower)
         children.append(':')
-        if node.lower:
+        if node.upper:
             children.append(node.upper)
         children.append(']')
         self._handle(node, children)
 
+    def visitSliceobj(self, node):
+        children = []
+        for index, slice in enumerate(node.nodes):
+            if index > 0:
+                children.append(':')
+            if not self._is_none_or_const_none(slice):
+                children.append(slice)
+        self._handle(node, children)
+
     def visitStmt(self, node):
         self._handle(node, node.nodes)
 
         self._handle(node, children    )
 
     def _handle_tuple(self, node):
-        self._handle(node, self._child_nodes(node.nodes, ','), eat_parens=True)
+        if node.nodes:
+            self._handle(node, self._child_nodes(node.nodes, ','),
+                         eat_parens=True)
+        else:
+            self._handle(node, ['(', ')'])
 
     def visitUnaryAdd(self, node):
         self._handle(node, ['+', node.expr])
         self.offset = new_offset + len(token)
         return (new_offset, self.offset)
 
-    def _good_token(self, token, offset):
+    def _good_token(self, token, offset, start=None):
         """Checks whether consumed token is in comments"""
-        return not '#' in self.source[self.offset:offset]
+        if start is None:
+            start = self.offset
+        try:
+            comment_index = self.source.rindex('#', start, offset)
+        except ValueError:
+            return True
+        try:
+            new_line_index = self.source.rindex('\n', start, offset)
+        except ValueError:
+            return False
+        return comment_index < new_line_index
 
     def _skip_comment(self):
         self.offset = self.source.index('\n', self.offset + 1)
     def consume_string(self):
         if _Source._string_pattern is None:
             original = codeanalyze.get_string_pattern()
-            pattern = r'(%s)((\s|\\\n)*(%s))*' % (original, original)
+            pattern = r'(%s)((\s|\\\n|#[^\n]*\n)*(%s))*' % (original, original)
             _Source._string_pattern = re.compile(pattern)
         repattern = _Source._string_pattern
         return self._consume_pattern(repattern)
     def get(self, start, end):
         return self.source[start:end]
 
+    def rfind_token(self, token, start, end):
+        index = start
+        while True:
+            try:
+                index = self.source.rindex(token, start, end)
+                if self._good_token(token, index, start=start):
+                    return index
+                else:
+                    end = index
+            except ValueError:
+                return None
+
     def from_offset(self, offset):
         return self.get(offset, self.offset)
 

ropetest/refactor/patchedasttest.py

                          ' ', '**', '', 'p2', '', ')', '', ':', '\n    ',
                          '"""docs"""', '\n    ', 'Stmt'])
 
+    def test_function_node_and_tuple_parameters(self):
+        source = 'def f(a, (b, c)):\n    pass\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_region('Function', 0, len(source) - 1)
+        checker.check_children(
+            'Function', ['def', ' ', 'f', '', '(', '', 'a', '', ',', ' ', '(',
+                         '', 'b', '', ',', ' ', 'c', '', ')', '', ')' , '',
+                         ':', '\n    ', 'Stmt'])
+
     def test_dict_node(self):
         source = '{1: 2, 3: 4}\n'
         ast = patchedast.get_patched_ast(source)
         checker.check_children(
             'Tuple', ['(', '', 'Const(1)', '', ',', ' ', 'Const(2)', '', ')'])
 
+    def test_tuple_node(self):
+        source = '#(\n1, 2\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_children('Tuple', ['Const(1)', '', ',', ' ', 'Const(2)'])
+
     def test_one_item_tuple_node(self):
         source = '(1,)\n'
         ast = patchedast.get_patched_ast(source)
         checker = _ResultChecker(self, ast)
         checker.check_children('Tuple', ['(', '', 'Const(1)', ',', ')'])
 
+    def test_empty_tuple_node(self):
+        source = '()\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_children('Tuple', ['(', '', ')'])
+
     def test_yield_node(self):
         source = 'def f():\n    yield None\n'
         ast = patchedast.get_patched_ast(source)
         start = source.rindex('1')
         checker.check_region('Const(1)', start, start + 1)
 
+    def test_simple_sliceobj(self):
+        source = 'a[1::3]\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_children(
+            'Sliceobj', ['Const(1)', '', ':', '', ':', '', 'Const(3)'])
+
+    def test_ignoring_strings_that_start_with_a_char(self):
+        source = 'r"""("""\n1\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_children(
+            'Module', ['', 'r"""("""', '\n', 'Stmt', '\n'])
+
+    # XXX: ``<>`` will be removed in Python 3.0
+    def xxx_test_how_to_handle_old_not_equals(self):
+        source = '1 <> 2\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_children(
+            'Module', ['Const(1)', ' ', '<>', ' ', 'Const(2)'])
+
+    def test_semicolon(self):
+        source = '1;\n'
+        ast = patchedast.get_patched_ast(source)
+
+    @testutils.run_only_for_25
+    def test_if_exp_node(self):
+        source = '1 if True else 2\n'
+        ast = patchedast.get_patched_ast(source)
+        checker = _ResultChecker(self, ast)
+        checker.check_children(
+            'IfExp', ['Const(1)', ' ', 'if', ' ', 'Name', ' ', 'else',
+                      ' ', 'Const(2)'])
+
 
 class _ResultChecker(object):