Commits

Victor Stinner  committed 05e1f59

x in (1, 2, 3) => x in {1, 2, 3}

  • Participants
  • Parent commits 891fe65

Comments (0)

Files changed (5)

    - ``list(x for x in range(3))`` => ``[0, 1, 2]``
    - ``[x for x in ""]`` => ``[]``
    - ``[x for x in iterable]`` => ``list(iterable)``
-   - ``set([x for x in "abc"])`` => ``set(("a", "b", "c"))``
+   - ``set([x for x in "abc"])`` => ``{"a", "b", "c"}`` (Python 2.7+) or ``set(("a", "b", "c"))``
 
  * Replace list with tuple (need "builtin_funcs" feature). Examples:
 
    - ``frozenset("ab") | frozenset("bc")`` => ``frozenset("abc")``
    - ``None is None`` => ``True``
    - ``"2" in "python2.7"`` => ``True``
+   - ``x in [1, 2, 3]`` => ``x in {1, 2, 3}`` (Python 3) or ``x in (1, 2, 3)`` (Python 2)
    - ``def f(): return 2 if 4 < 5 else 3`` => ``def f(): return 2``
 
  * Remove dead code. Examples:
 
 other:
 
- - Python 3.2+: x in (1, 2) => x in {1, 2} ? (create a frozenset).
  - SetComp, DictComp:
 
    * {x for x in "abc"} => {"a", "b", "c"}

File astoptimizer/ast_tools.py

     else:
         return new_constant(node, value)
 
+def new_list(node, elts=None):
+    if elts is None:
+        elts = []
+    new_node = ast.List(elts=elts, ctx=ast.Load())
+    return copy_lineno(node, new_node)
+
 if sys.version_info >= (2, 7):
+    def new_set_elts(node, elts=None):
+        if elts is None:
+            elts = []
+        new_node = ast.Set(elts=elts)
+        return copy_lineno(node, new_node)
+
     def new_set(node, iterable=()):
         elts = [new_constant(node, elt) for elt in iterable]
-        new_node = ast.Set(elts=elts)
-        return copy_lineno(node, new_node)
+        return new_set_elts(node, elts)
 
 def iter_all_ast(node):
     yield node

File astoptimizer/optimizer.py

 from astoptimizer import UNSET
 from astoptimizer.ast_tools import (
     copy_lineno,
-    new_constant, new_literal, new_call, new_pass,
+    new_constant, new_literal, new_call, new_pass, new_list,
     ast_contains, check_func_args)
 if sys.version_info >= (2, 7):
-    from astoptimizer.ast_tools import new_set
+    from astoptimizer.ast_tools import new_set, new_set_elts
 from astoptimizer.config import optimize_unicode
 from astoptimizer.compatibility import (
     u,
 
         return self.compare_cst(node, op, left_cst, right_cst)
 
+    def compare_in(self, data):
+        if sys.version_info >= (3, 2):
+            # Python 3.2+ bytecode peepholer replaces x in {1, 2} with x in
+            # frozenset({1, 2}), where frozenset({1, 2}) is a constant.
+            if isinstance(data, ast.List):
+                return new_set_elts(data, data.elts)
+
+            constant = self.get_constant(data)
+            if constant is UNSET:
+                return
+            if isinstance(constant, (tuple, list, frozenset)):
+                # x in (1, 2) => x in {1, 2}
+                # x in frozenset((1, 2)) => x in {1, 2}
+                return new_set(data, constant)
+        else:
+            if isinstance(data, ast.List):
+                # x in [1, 2] => x in (1, 2)
+                return self.list_to_tuple(data)
+
     def visit_Compare(self, node):
         # FIXME: implement 1 < 2 < 3
         if len(node.ops) != 1:
         if new_node is not UNSET:
             return new_node
 
-        if (isinstance(op, ast.In)
-        and isinstance(node.comparators[0], ast.List)):
-            node.comparators[0] = self.list_to_tuple(node.comparators[0])
+        data = node.comparators[0]
+        if isinstance(op, (ast.In, ast.NotIn)):
+            new_data = self.compare_in(data)
+            if new_data is not None:
+                node.comparators[0] = new_data
 
     def if_block(self, parent, node_list):
         # Create a pass instruction if needed.
             if qualname == 'tuple':
                 return new_constant(node, ())
             if qualname == 'list':
-                new_node = ast.List(elts=[], ctx=ast.Load())
-                return copy_lineno(node, new_node)
+                return new_list(node)
             if qualname == 'dict':
                 new_node = ast.Dict(keys=[], values=[])
                 return copy_lineno(node, new_node)

File astoptimizer/tests.py

         self.check('0 in (1, 2, 3)', self.text_bool(False))
         self.check('2 in (1, 2, 3)', self.text_bool(True))
 
-        self.check("x in [1, 2, 3]", self.text_ast("x in (1, 2, 3)"))
+        if sys.version_info >= (3, 2):
+            config = self.create_config('builtin_funcs')
+
+            self.check("x in (1, 2, 3)", self.text_ast('x in {1, 2, 3}'))
+            self.check("x in [1, 2, 3]", self.text_ast('x in {1, 2, 3}'))
+            self.check("x in frozenset((1, 2, 3))", self.text_ast('x in {1, 2, 3}'), config)
+            self.check_not_optimized("x in {1, 2, 3}")
+
+            self.check("x not in (1, 2, 3)", self.text_ast('x not in {1, 2, 3}'))
+        else:
+            self.check_not_optimized("x in (1, 2, 3)")
+            self.check("x in [1, 2, 3]", self.text_ast("x in (1, 2, 3)"))
+            self.check_not_optimized("x in set((1, 2, 3))")
+            self.check_not_optimized("x in frozenset((1, 2, 3))")
+
+            self.check("x not in [1, 2, 3]", self.text_ast('x not in (1, 2, 3)'))
         self.check("not(x in y)", self.text_ast("x not in y"))
         self.check("not(x not in y)", self.text_ast("x in y"))
 
         config.enable('pythonenv')
         config._constants['__debug__'] = True
         config._constants['sys.flags.optimize'] = 0
+        config._constants['sys.maxunicode'] = 0x10ffff
         return config
 
     def check_line(self, line):
         elif (PYTHON3
         and 'string.atoi' in before):
             return self.skipTest("specific to Python 2")
+        if PYTHON3 and before.startswith('u"'):
+            before = before[1:]
 
         if 'if DEBUG:' in before:
             config = self.create_default_config()
             config.add_constant('DEBUG', False)
         else:
             config = None
-        if len(parts) >= 5 and parts[3] == ' or ':
+        if len(parts) >= 5 and ' or ' in parts[3]:
             after2 = parts[4]
         else:
             after2 = None
         try:
             self.check(before, self.text_ast(after), config)
-        except AssertionError:
+        except (AssertionError, SyntaxError):
             if after2 is not None:
                 self.check(before, self.text_ast(after2), config)
             else: