Commits

Victor Stinner committed 0de3010

{x for x in "abc"} => set(["a", "b", "c"])

Comments (0)

Files changed (3)

astoptimizer/ast_tools.py

         raise NotImplementedError("unable to create an AST object for constant: %r" % (value,))
     return copy_lineno(node, new_node)
 
-def new_literal(node, value):
-    if isinstance(value, list):
-        elts = [new_constant(node, elt) for elt in value]
-        new_node = ast.List(elts=elts, ctx=ast.Load())
-        return copy_lineno(node, new_node)
-    else:
-        return new_constant(node, value)
+def _new_constant_list(node, elts):
+    return [new_constant(node, elt) for elt in elts]
 
-def new_list(node, elts=None):
+def new_list_elts(node, elts=None):
     if elts is None:
         elts = []
     new_node = ast.List(elts=elts, ctx=ast.Load())
     return copy_lineno(node, new_node)
 
+def new_list(node, iterable=()):
+    elts = _new_constant_list(node, iterable)
+    return new_list_elts(node, elts)
+
+def sort_set_elts(elts):
+    elts = list(elts)
+    try:
+        # sort elements for astoptimizer unit tests
+        elts.sort()
+    except TypeError:
+        # elements may be unsortable
+        pass
+    return elts
+
 if sys.version_info >= (2, 7):
     def new_set_elts(node, elts=None):
         if elts is None:
         return copy_lineno(node, new_node)
 
     def new_set(node, iterable=()):
-        elts = [new_constant(node, elt) for elt in iterable]
+        elts = sort_set_elts(iterable)
+        elts = _new_constant_list(node, elts)
         return new_set_elts(node, elts)
 
+def new_literal(node, value):
+    if isinstance(value, list):
+        return new_list(node, value)
+    elif sys.version_info >= (2, 7) and isinstance(value, set):
+        return new_set(node, value)
+    else:
+        return new_constant(node, value)
+
 def iter_all_ast(node):
     yield node
     for field, value in ast.iter_fields(node):

astoptimizer/optimizer.py

 
 from astoptimizer import UNSET
 from astoptimizer.ast_tools import (
-    copy_lineno,
-    new_constant, new_literal, new_call, new_pass, new_list,
+    copy_lineno, sort_set_elts,
+    new_constant, new_literal, new_call, new_pass, new_list_elts,
     ast_contains, check_func_args)
 if sys.version_info >= (2, 7):
     from astoptimizer.ast_tools import new_set, new_set_elts
             and not issubclass(list, want_type)):
                 return UNSET
             return [self.get_literal(elt) for elt in node.elts]
+        elif sys.version_info >= (2, 7) and isinstance(node, ast.Set):
+            if (want_type is not None
+            and not issubclass(set, want_type)):
+                return UNSET
+            return set(self.get_literal(elt) for elt in node.elts)
         else:
             return self.get_constant(node, want_type)
 
         qualname = self.namespace.get_qualname(node.func.id)
         if len(node.args) == 1:
             arg = node.args[0]
-            new_arg = self.optimize_iter(arg, want_list=(qualname == 'list'))
+            if qualname == 'list':
+                want_type = list
+            elif qualname == 'set':
+                want_type = set
+            else:
+                want_type = None
+            new_arg = self.optimize_iter(arg, want_type=want_type)
         else:
             arg = UNSET
             new_arg = DROP_NODE
             if qualname == 'tuple':
                 return new_constant(node, ())
             if qualname == 'list':
-                return new_list(node)
+                return new_list_elts(node)
             if qualname == 'dict':
                 new_node = ast.Dict(keys=[], values=[])
                 return copy_lineno(node, new_node)
                 # list([1, 2, 3]) => [1, 2, 3]
                 return arg
         elif qualname in ('frozenset', 'set'):
-            constant = self.get_constant(arg, IMMUTABLE_ITERABLE_TYPES)
+            constant = self.get_literal(arg, ITERABLE_TYPES)
             if constant is not UNSET:
                 elts = frozenset(constant)
                 use_literal_set = (qualname == 'set' and sys.version_info >= (2, 7))
                 if (len(elts) != len(constant)) or use_literal_set:
-                    try:
-                        # sort elements for astoptimizer unit tests
-                        elts = list(elts)
-                        elts.sort()
-                    except TypeError:
-                        # elements may be unsortable
-                        pass
-
+                    elts = sort_set_elts(elts)
                     if len(elts) <= self.config.max_tuple_length:
                         if use_literal_set:
                             # set((1, 2, 3)) => {1, 2, 3}
         node.func.id = 'xrange'
         return node
 
-    def optimize_iter(self, node, is_generator=False, want_list=False):
+    def optimize_iter(self, node, is_generator=False, want_type=None):
         if (self.config.remove_dead_code
         and self.is_empty_iterable(node)):
             # set("") => set()
             return DROP_NODE
 
         if isinstance(node, ast.List):
-            if not want_list:
+            if want_type != list:
                 # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
                 return self.list_to_tuple(node)
             else:
         else:
             is_range = self.check_builtin_func(node, 'range', 1, 3)
         if is_range:
-            return self.optimize_range(node, want_list)
+            return self.optimize_range(node, want_type == list)
 
         if (self.config.remove_dead_code
         and self.check_builtin_func(node, 'iter', 1, 1)):
         if (not is_generator
         and self.config.remove_dead_code
         and self.check_builtin_func(node, ('list', 'tuple'),  1, 1)):
-            # set(list(iterable)) => set(iterable)
+            # list(iterable) => iterable
             return node.args[0]
 
         if isinstance(node, ast.GeneratorExp):
             # (x for x in "abc") => "abc"
             if is_generator:
                 node_is_generator = True
+                node_want_type = None
             else:
-                node_is_generator = not want_list
-            new_iter = self.optimize_comprehension(node, node_is_generator)
+                node_is_generator = (want_type != list)
+                node_want_type = list
+            new_iter = self.optimize_comprehension(node,
+                                                   is_generator=node_is_generator,
+                                                   want_type=node_want_type)
             if new_iter is not None:
                 return new_iter
 
-        if want_list and isinstance(node, ast.Tuple):
-            new_node = ast.List(elts=node.elts, ctx=ast.Load())
-            return copy_lineno(node, new_node)
+        if want_type == list and isinstance(node, ast.Tuple):
+            # (1, 2, 3) => [1, 2, 3]
+            return new_list_elts(node, node.elts)
+
         constant = self.get_literal(node)
         if (constant is not UNSET
         and isinstance(constant, ITERABLE_TYPES)):
             if len(constant) <= self.config.max_tuple_length:
-                if want_list:
+                if want_type == list:
+                    # "abc" => ["a", "b", "c"]
                     constant = list(constant)
+                elif want_type == set and sys.version_info >= (2, 7):
+                    # "abc" => {"a", "b", "c"}
+                    constant = set(constant)
                 else:
+                    # "abc" => ("a", "b", "c")
                     constant = tuple(constant)
                 return new_literal(node, constant)
 
             new_ifs = [false_cst]
         node.ifs = new_ifs
 
-    def optimize_comprehension(self, node, is_generator):
+    def optimize_comprehension(self, node, is_generator=False, want_type=None):
         # FIXME: support more than 1 generator
         if len(node.generators) != 1:
             return
             if (not isinstance(node.elt, ast.Name)
             or node.elt.id != name):
                 if is_generator:
-                    new_iter = self.optimize_iter(generator.iter, True)
+                    new_iter = self.optimize_iter(generator.iter, is_generator=True, want_type=tuple)
                 else:
-                    new_iter = self.optimize_iter(generator.iter)
+                    new_iter = self.optimize_iter(generator.iter, want_type=tuple)
                 if new_iter is DROP_NODE:
                     # y for x in ()
                     return DROP_NODE
             iter_expr = generator.iter
 
         if is_generator:
-            new_iter = self.optimize_iter(iter_expr, True)
+            new_iter = self.optimize_iter(iter_expr, is_generator=True, want_type=tuple)
         else:
-            new_iter = self.optimize_iter(iter_expr, want_list=True)
+            new_iter = self.optimize_iter(iter_expr, want_type=want_type)
         if new_iter is not None:
             return new_iter
         else:
 
     def visit_GeneratorExp(self, node):
         # use iter() builtin function
-        iter_expr = self.optimize_comprehension(node, True)
+        iter_expr = self.optimize_comprehension(node, is_generator=True)
         if iter_expr is None:
             return
 
         node.generators = self.visit_list(node.generators)
         node.elt = self.visit(node.elt)
 
-        # use list() builtin function
-        iter_expr = self.optimize_comprehension(node, False)
+        iter_expr = self.optimize_comprehension(node, want_type=list)
         if iter_expr is None:
             return
-        if iter_expr is not DROP_NODE:
-            if isinstance(iter_expr, ast.List):
-                # [x for x in "abc"] => ["a", "b", "c"]
-                return iter_expr
-            elif self.can_use_builtin('list'):
-                # [x for x in range(1000)] => list(xrange(1000))
-                list_iter = new_call(node, 'list', iter_expr)
-                return list_iter
-        else:
+        if iter_expr is DROP_NODE:
             # [x*2 for x in "abc" if False] => []
             # [x*2 for x in []] => []
-            new_node = ast.List(elts=[], ctx=ast.Load())
-            return copy_lineno(node, new_node)
+            return new_list_elts(node)
+
+        if isinstance(iter_expr, ast.List):
+            # [x for x in "abc"] => ["a", "b", "c"]
+            return iter_expr
+        elif self.can_use_builtin('list'):
+            # [x for x in range(1000)] => list(xrange(1000))
+            return new_call(node, 'list', iter_expr)
+
+    def fullvisit_SetComp(self, node):
+        for generator in node.generators:
+            self.unassign(generator.target)
+
+        node.generators = self.visit_list(node.generators)
+        node.elt = self.visit(node.elt)
+
+        iter_expr = self.optimize_comprehension(node, want_type=set)
+        if iter_expr is None:
+            return
+        if iter_expr is DROP_NODE:
+            # {x*2 for x in "abc" if False} => empty set
+            # {x*2 for x in []} => empty set
+            return new_set(node)
+
+        if isinstance(iter_expr, ast.Set):
+            # {x for x in "abc"} => {"a", "b", "c"}
+            return iter_expr
+        elif self.can_use_builtin('set'):
+            # {x for x in range(1000)} => set(xrange(1000))
+            return new_call(node, 'set', iter_expr)
 
 
 class FunctionOptimizer(Optimizer):

astoptimizer/tests.py

                    self.text_ast('["a", "b", "c"]'))
         self.check('[x for x in "abc" if False]',
                    self.text_ast('[]'))
+        self.check_not_optimized('[x for x in iterable]')
         self.check('[x for x in iterable]',
                    self.text_ast('list(iterable)'),
                    config)
                    self.text_set(0, 1, 2),
                    config)
 
+    def test_SetComp(self):
+        if sys.version_info < (2, 7):
+            return self.skipTest("need python 2.7+")
+        config = self.create_config('builtin_funcs')
+
+        # set comprehension
+        self.check('{x for x in "abc"}',
+                   self.text_set('a', 'b', 'c'))
+        self.check('{x for x in (0, 1, 2)}',
+                   self.text_set(0, 1, 2))
+        self.check('{x*2 for x in ""}',
+                   self.text_set())
+        self.check('{x*2 for x in ""}',
+                   self.text_set(),
+                   config)
+        if PYTHON2:
+            self.check('{x*2 for x in range(1000)}',
+                       self.text_ast('{x*2 for x in xrange(1000)}'),
+                       config)
+        else:
+            self.check_not_optimized('{x*2 for x in range(1000)}', config)
+        self.check('{x*2 for x in range(3)}',
+                   self.text_ast('{x*2 for x in (0, 1, 2)}'),
+                   config)
+
+        # if
+        self.check('{x*2 for x in (1, 2, 3) if True}',
+                   self.text_ast('{x*2 for x in (1, 2, 3)}'))
+        self.check('{x*2 for x in (1, 2, 3) if True if 1}',
+                   self.text_ast('{x*2 for x in (1, 2, 3)}'))
+        self.check('{x*2 for x in (1, 2, 3) if False}',
+                   self.text_set())
+        self.check('{x*2 for x in (1, 2, 3) if 0 if True}',
+                   self.text_set())
+        self.check_not_optimized('{x*2 for x in (1, 2, 3) if x % 2}')
+
+        # replace with set()
+        self.check('{x * 2 for x in ""}',
+                   self.text_set())
+        self.check('{x for x in "abc"}',
+                   self.text_set('a', 'b', 'c'))
+        self.check('{x for x in "abc" if True}',
+                   self.text_set('a', 'b', 'c'))
+        self.check('{x for x in "abc" if False}',
+                   self.text_set())
+        self.check_not_optimized('{x for x in iterable}')
+        self.check('{x for x in iterable}',
+                   self.text_ast('set(iterable)'),
+                   config)
+
+        # tuple(set comprehension)
+        self.check('tuple({x for x in "abc"})',
+                   self.text_tuple("a", "b", "c"),
+                   config)
+        self.check('tuple({x for x in "abc" if False})',
+                   self.text_tuple(),
+                   config)
+        self.check('tuple({x for x in iterable})',
+                   self.text_ast('tuple(iterable)'),
+                   config)
+        if PYTHON2:
+            self.check('set({x for x in range(1000)})',
+                       self.text_ast('set(xrange(1000))'),
+                       config)
+        else:
+            self.check('set({x for x in range(1000)})',
+                       self.text_ast('set(range(1000))'),
+                       config)
+
+        self.check('set({x for x in range(3)})',
+                   self.text_set(0, 1, 2),
+                   config)
+
+
 class TestFrozenset(BaseTestCase):
     def create_default_config(self):
         config = BaseTestCase.create_default_config(self)