Commits

Victor Stinner  committed e96e0fd

{key: value for key, value in {1: 2, 3: 4}} => {1: 2, 3: 4}

  • Participants
  • Parent commits 89fab93

Comments (0)

Files changed (5)

    - ``for x in range(3): pass`` => ``for x in (0, 1, 2): pass``
    - ``for x in range(1000): pass`` => ``for x in xrange(1000): pass`` (Python 2)
 
- * Optimize list comprehension and generators (need "builtin_funcs" feature).
-   Examples:
+ * Optimize iterators, list, set and dict comprehension, and generators (need
+   "builtin_funcs" feature). Examples:
 
    - ``iter(set())`` => ``iter(())``
    - ``frozenset("")`` => ``frozenset()``
 
  * Add the "struct" configuration feature: functions of the struct module
  * Optimize print() on Python 2 with "from __future__ import print_function"
- * Optimize iterators, list comprehension and generators
+ * Optimize iterators, list, set and dict comprehension, and generators
  * Replace list with tuple
 
 Version 0.3.1 (2012-09-12)
  - "i=0; while i < 10: print(i); i = i + 1": don't replace print(i) with print('0')
  - "for x in (): try: pass finally: continue" must raise a SyntaxError
  - "type(iter([]))"
+ - "list([1, 2])" and "dict({...})" should create a copy?
 
 major optimizations:
 
 
 other:
 
- - SetComp, DictComp:
-
-   * {x for x in "abc"} => {"a", "b", "c"}
-   * {x for x in a} => set(a)
-   * {x: y for x, y in a} => dict(a)
-
  - operator module:
 
    * lambda x: x[1] => operator.itemgetter(1)

File astoptimizer/ast_tools.py

         pass
     return elts
 
+def new_dict_elts(node, keys=None, values=None):
+    if keys is None:
+        keys = []
+    if values is None:
+        values = []
+    new_node = ast.Dict(keys=keys, values=values)
+    return copy_lineno(node, new_node)
+
 if sys.version_info >= (2, 7):
     def new_set_elts(node, elts=None):
         if elts is None:

File astoptimizer/optimizer.py

     copy_lineno,
     new_constant, new_literal, new_call, new_pass,
     sort_set_elts, new_tuple, new_tuple_elts, new_list, new_list_elts,
+    new_dict_elts,
     ast_contains, check_func_args)
 from astoptimizer.config import optimize_unicode
 from astoptimizer.compatibility import (
                 to_type = set
             elif qualname == 'frozenset':
                 to_type = tuple
+            elif qualname == 'dict':
+                to_type = dict
             else:
                 to_type = None
             new_arg = self.optimize_iter(arg, to_type)
             if qualname == 'list':
                 return new_list_elts(node)
             if qualname == 'dict':
-                new_node = ast.Dict(keys=[], values=[])
-                return copy_lineno(node, new_node)
+                return new_dict_elts(node)
             if (PYTHON27
             and qualname == 'set'):
                 return new_set(node)
             if isinstance(arg, ast.List):
                 # list([1, 2, 3]) => [1, 2, 3]
                 return arg
+        elif qualname == 'dict':
+            if isinstance(arg, ast.Dict):
+                # dict({1: 2, 3: 4}) => {1: 2, 3: 4}
+                return arg
         elif qualname in ('frozenset', 'set'):
             if (qualname == 'set'
             and PYTHON27
 
     def node_to_type(self, node, to_type):
         if PYTHON27:
-            ast_types = (ast.Tuple, ast.List, ast.Set)
+            ast_types = (ast.Tuple, ast.List, ast.Dict, ast.Set)
         else:
-            ast_types = (ast.Tuple, ast.List)
+            ast_types = (ast.Tuple, ast.List, ast.Dict)
         if not isinstance(node, ast_types):
             return
-        if len(node.elts) > self.config.max_tuple_length:
+        if isinstance(node, ast.Dict):
+            length = len(node.keys)
+            assert len(node.keys) == len(node.values)
+        else:
+            length = len(node.elts)
+        if length > self.config.max_tuple_length:
             return
 
         if isinstance(node, ast.Tuple):
                 return new_tuple_elts(node, node.elts)
             if to_type == set:
                 return self.node_to_set(node)
+        elif isinstance(node, ast.Dict):
+            if to_type == dict:
+                return node
+            # FIXME: support other types
         elif isinstance(node, ast.Set):
             if to_type == set:
                 return node
             return DROP_NODE
 
         if PYTHON27:
-            ast_types = (ast.Tuple, ast.List, ast.Set)
+            ast_types = (ast.Tuple, ast.List, ast.Dict, ast.Set)
         else:
-            ast_types = (ast.Tuple, ast.List)
+            ast_types = (ast.Tuple, ast.List, ast.Dict)
         if isinstance(node, ast_types):
             return self.node_to_type(node, to_type)
 
         if len(node.generators) != 1:
             return
         generator = node.generators[0]
-        if not isinstance(generator.target, ast.Name):
-            return
+
         if generator.ifs:
             if not self.config.remove_dead_code:
                 return
                 return
             # (x for x in data if False) => iter(())
             return DROP_NODE
+
+        if to_type == dict:
+            # dict comprehension
+            target = generator.target
+            if not (isinstance(target, ast.Tuple)
+                    and len(target.elts) == 2
+                    and isinstance(target.elts[0], ast.Name)
+                    and isinstance(target.elts[1], ast.Name)):
+                return
+            key = target.elts[0].id
+            value = target.elts[1].id
+            if not (isinstance(node.key, ast.Name)
+                    and isinstance(node.value, ast.Name)
+                    and node.key.id == key
+                    and node.value.id == value):
+                # {value: key for key, value in iterable}
+                if is_generator:
+                    new_iter = self.optimize_generator(generator.iter)
+                else:
+                    new_iter = self.optimize_iter(generator.iter, tuple)
+                if new_iter is DROP_NODE:
+                    # y for x in ()
+                    return DROP_NODE
+                if new_iter is not None:
+                    # y for x in [1, 2] => y for x in (1, 2)
+                    generator.iter = new_iter
+                return
         else:
+            # generator expression, list or set comprehension:
+            # to_type in (tuple, list, set)
+            if not isinstance(generator.target, ast.Name):
+                return
             name = generator.target.id
             if (not isinstance(node.elt, ast.Name)
             or node.elt.id != name):
                     # y for x in [1, 2] => y for x in (1, 2)
                     generator.iter = new_iter
                 return
-            iter_expr = generator.iter
+        iter_expr = generator.iter
 
         if is_generator:
             new_iter = self.optimize_generator(iter_expr)
             # {x for x in range(1000)} => set(xrange(1000))
             return new_call(node, 'set', iter_expr)
 
+    def fullvisit_DictComp(self, node):
+        for generator in node.generators:
+            self.unassign(generator.target)
+
+        node.generators = self.visit_list(node.generators)
+        node.key = self.visit(node.key)
+        node.value = self.visit(node.value)
+
+        iter_expr = self.optimize_comprehension(node, to_type=dict)
+        if iter_expr is None:
+            return
+        if iter_expr is DROP_NODE:
+            # [x*2 for x in "abc" if False] => {}
+            # [x*2 for x in []] => {}
+            return new_dict_elts(node)
+
+        if isinstance(iter_expr, ast.Dict):
+            # {x, y for x, y in {1: 2, 3: 4}} => {1: 2, 3: 4}
+            return iter_expr
+        elif self.can_use_builtin('dict'):
+            # {x, y for x, y in iterable} => dict(iterable)
+            return new_call(node, 'dict', iter_expr)
+
 
 class FunctionOptimizer(Optimizer):
     def __init__(self, config):

File astoptimizer/tests.py

                    config)
         self.check_not_optimized('(x for x in "abc" if 0)', config)
 
+        if PYTHON27:
+            self.check('{x*2 for x in (1, 2, 3) if False}',
+                       self.text_ast('{x*2 for x in (1, 2, 3) if 0}'),
+                       config)
+            self.check('{x: x*2 for x in (1, 2, 3) if False}',
+                       self.text_ast('{x: x*2 for x in (1, 2, 3) if 0}'),
+                       config)
+
     def test_IfExp(self):
         self.check('4 if "abc" else 5', self.text_num(4))
         self.check('4 if 0 else 5', self.text_num(5))
                    self.text_set(0, 1, 2),
                    config)
 
+    def test_DictComp(self):
+        if sys.version_info < (2, 7):
+            return self.skipTest("need python 2.7+")
+        config = self.create_config('builtin_funcs')
+
+        # dict comprehension
+        self.check('{key: value for key, value in {1: 2, 3: 4}}',
+                   self.text_ast('{1: 2, 3: 4}'))
+        self.check('{key: value for key, value in iterable}',
+                   self.text_ast('dict(iterable)'),
+                   config)
+        self.check_not_optimized('{value: key for key, value in iterable}')
+
+        # if
+        self.check('{x: x*2 for x in (1, 2, 3) if True}',
+                   self.text_ast('{x: x*2 for x in (1, 2, 3)}'))
+        self.check('{x: x*2 for x in (1, 2, 3) if True if 1}',
+                   self.text_ast('{x: x*2 for x in (1, 2, 3)}'))
+        self.check('{x: x*2 for x in (1, 2, 3) if False}',
+                   self.text_ast('{}'))
+        self.check('{x: x*2 for x in (1, 2, 3) if 0 if True}',
+                   self.text_ast('{}'))
+        self.check_not_optimized('{x*2 for x in (1, 2, 3) if x % 2}')
+
+
 
 class TestFrozenset(BaseTestCase):
     def create_default_config(self):
     def test_while(self):
         self.check_not_optimized('i=0\nwhile i < 10: i = i + 1')
 
+    def test_DictComp(self):
+        if sys.version_info < (2, 7):
+            return self.skipTest("need python 2.7+")
+        self.check_not_optimized('k=1\nv=2\n{v:k for k, v in iterable}')
+
     def test_FunctionDef(self):
         self.check('x=1\ndef f():\n x=2\n return x',
                    self.text_ast('x=1\ndef f():\n x=2\n return 2'))
         config.enable('builtin_funcs')
         return config
 
-    def test_list_to_tuple(self):
+    def test_optimize_iter_builtin(self):
         config = self.create_config()
         self.check_not_optimized('set([1, 2, 3])', config)
 
             self.check('set([1, 2, 3])',
                        self.text_ast('set((1, 2, 3))'))
 
+        self.check('dict({1: 2, 3: 4})',
+                   self.text_ast('{1: 2, 3: 4}'))
+
     def test_no_builtin_funcs(self):
         config = self.create_config()
         self.check_not_optimized('abs(-5)', config)