Commits

Victor Stinner committed a277a3c

Optimize iterators and generators

Comments (0)

Files changed (4)

    - ``not(x in y)`` => ``x not in y``
    - ``4 and 5 and x and 6`` => ``x and 6``
 
- * Replace ``range()`` with ``xrange()`` for iterators (on Python 2).
-   Examples:
+ * Optimize iterators and generators. Examples:
 
    - ``for x in range(3): ...`` => ``for x in xrange(3): ...``
    - ``[x for x in range(3)]`` => ``[x for x in xrange(3)]``
+   - ``tuple(x for x in "abc")`` => ``tuple(iter("abc"))``
+   - ``(x for x in "abc" if False)`` => ``iter(())``
+   - ``iter(set())`` => ``iter(())``
+   - ``frozenset("")`` => ``frozenset()``
 
  * Replace list with tuple. Examples:
 
 Changes:
 
  * Optimize print() on Python 2 with "from __future__ import print_function"
+ * Optimize iterators and generators
  * Replace list with tuple
 
 Version 0.3.1 (2012-09-12)

astoptimizer/ast_tools.py

     else:
         return any(isinstance(node, obj_type) for node in iter_all_ast(tree))
 
+def new_call(node, name, *args):
+    # name: str
+    # args: ast objects
+    name = ast.Name(id=name, ctx=ast.Load())
+    copy_lineno(node, name)
+    new_node = ast.Call(
+        func=name,
+        args=list(args),
+        keywords=[],
+        starargs=None,
+        kwargs=None)
+    return copy_lineno(node, new_node)
+
+def check_func_args(node, min_narg=None, max_narg=None):
+    keywords = node.keywords
+    starargs = node.starargs
+    kwargs = node.kwargs
+    # Don't support keywords, *args, **kw yet
+    if keywords or starargs or kwargs:
+        return False
+    if min_narg is not None and len(node.args) < min_narg:
+        return False
+    if max_narg is not None and len(node.args) > max_narg:
+        return False
+    return True
+
+def new_pass(node):
+    new_node = ast.Pass()
+    return copy_lineno(node, new_node)
+

astoptimizer/optimizer.py

 import ast
 import operator
 from astoptimizer import UNSET
-from astoptimizer.ast_tools import ast_contains, copy_lineno, new_constant
+from astoptimizer.ast_tools import (
+    ast_contains, copy_lineno,
+    new_constant, new_call, new_pass,
+    check_func_args)
 from astoptimizer.config import optimize_unicode
 from astoptimizer.compatibility import (
     u,
     IMMUTABLE_ITERABLE_TYPES)
 import sys
 
+DROP_NODE = object()
+
 def operator_not_in(data, item):
     return item not in data
 
 TUPLE_BINOPS = (ast.Add,)
 FROZENSET_BINOPS = (ast.BitAnd, ast.BitOr, ast.BitXor, ast.Sub)
 
+# builtin functions accepting an iterable as input
+BUILTIN_ACCEPTING_ITERABLE = ('dict', 'frozenset', 'list', 'set', 'tuple')
+
 class Namespace:
     def __init__(self):
         self._aliases_enabled = True
             return constant
         return UNSET
 
+    def check_func(self, node, name, min_narg=None, max_narg=None):
+        if not isinstance(node, ast.Call):
+            return False
+        if not isinstance(node.func, ast.Name):
+            return False
+        qualname = self.namespace.get_qualname(node.func.id)
+        if qualname is UNSET:
+            return False
+        if isinstance(name, str):
+            if qualname != name:
+                return False
+        else:
+            if qualname not in name:
+                return False
+        return check_func_args(node, min_narg, max_narg)
+
+    def check_builtin_func(self, node, name, min_narg=None, max_narg=None):
+        if 'builtin_funcs' not in self.config.features:
+            return False
+        return self.check_func(node, name, min_narg, max_narg)
+
     def get_constant(self, node):
         if node is None:
             return None
                     return UNSET
                 constants.append(constant)
             return tuple(constants)
-        if (isinstance(node, ast.Call)
-        and 'builtin_funcs' in self.config.features
-        and isinstance(node.func, ast.Name)
-        and node.func.id == 'frozenset'
-        and self.check_func_args(node)):
+        if self.check_builtin_func(node, 'frozenset', 0, 1):
             if len(node.args) == 1:
                 arg = self.get_constant(node.args[0])
                 if arg is UNSET:
         # Create a pass instruction if needed.
         # Example: if 0: print("debug") => pass
         if len(node_list) == 0:
-            new_node = ast.Pass()
-            return copy_lineno(parent, new_node)
+            return new_pass(parent)
         else:
             return node_list
 
     def print_func(self, node, args):
         constants = self.get_constant_list(args)
         if constants is UNSET:
-            return UNSET
+            return
         if PYTHON3:
             if any(isinstance(constant, BYTES_TYPE)
                    for constant in constants):
-                return UNSET
+                return
         else:
             # print(unicode) depends on the locale encoding
             if any(isinstance(constant, UNICODE_TYPE)
                    for constant in constants):
-                return UNSET
+                return
         text = ' '.join(str(constant) for constant in constants)
         new_arg = new_constant(node, text)
         node.args = [new_arg]
 
     def call_func(self, node, name):
-        if ('builtin_funcs' in self.config.features
-        and name == 'print'):
-            return self.print_func(node, node.args)
         if name not in self.config.functions:
             return UNSET
         func = self.config.functions[name]
             return UNSET
         return new_constant(node, result)
 
-    def check_func_args(self, node):
-        keywords = node.keywords
-        starargs = node.starargs
-        kwargs = node.kwargs
-        # Don't support keywords, *args, **kw yet
-        if keywords or starargs or kwargs:
-            return False
-        return True
+    def is_empty_iterable(self, node):
+        if isinstance(node, (ast.List, ast.Tuple)):
+            return len(node.elts) == 0
+
+        constant = self.get_constant(node)
+        if constant is not UNSET:
+            if isinstance(constant, IMMUTABLE_ITERABLE_TYPES):
+                return len(constant) == 0
+            else:
+                return False
+
+        if self.check_builtin_func(node, BUILTIN_ACCEPTING_ITERABLE, 0, 0):
+            return True
+
+    def call_name(self, node):
+        name = self.namespace.get_qualname(node.func.id)
+        if name is not UNSET:
+            new_node = self.call_func(node, name)
+            if new_node is not UNSET:
+                return new_node
+
+        if self.check_builtin_func(node, BUILTIN_ACCEPTING_ITERABLE,  1, 1):
+            arg = node.args[0]
+            new_arg = self.optimize_iter(arg, True)
+            if new_arg is DROP_NODE:
+                del node.args[0]
+            elif new_arg is not None:
+                node.args[0] = new_arg
+
+        elif (self.check_func(node, 'iter', 1, 1)
+        and self.is_empty_iterable(node.args[0])):
+            # iter(set()) => iter(())
+            node.args[0] = new_constant(node.args[0], ())
+
+        elif self.check_builtin_func(node, 'print'):
+            self.print_func(node, node.args)
+
+        else:
+            new_node = self.call_func(node, name)
+            if new_node is not UNSET:
+                return new_node
 
     def visit_Call(self, node):
         func = node.func
-        args = node.args
-        if not self.check_func_args(node):
+        if not check_func_args(node):
             return
         if isinstance(func, ast.Name):
-            name = self.namespace.get_qualname(func.id)
-            if name is not UNSET:
-                new_node = self.call_func(node, name)
-                if new_node is not UNSET:
-                    return new_node
-
-            if ('builtin_funcs' in self.config.features
-            and name in ('frozenset', 'list', 'set')
-            and len(node.args) == 1
-            and isinstance(node.args[0], ast.List)):
-                node.args[0] = self.list_to_tuple(node.args[0])
+            return self.call_name(node)
         elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
             name = "%s.%s" % (func.value.id, func.attr)
             qualname = self.namespace.get_qualname(name)
         else:
             self.disable_vars(node)
 
-    def _optimize_iter(self, node):
+    def optimize_iter(self, node, may_drop):
+        if not self.config.remove_dead_code:
+            may_drop = False
+
+        if may_drop and self.is_empty_iterable(node):
+            # set("") => set()
+            return DROP_NODE
+
+        if isinstance(node, ast.List):
+            # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
+            return self.list_to_tuple(node)
+
         if (PYTHON2
-        and 'builtin_funcs' in self.config.features
-        and isinstance(node, ast.Call)):
+        and self.check_builtin_func(node, 'range', 1, 3)):
             # range(int, [int[, int]]) => xrange(...)
-            if not (isinstance(node.func, ast.Name)
-            and node.func.id == 'range'
-            and self.check_func_args(node)):
-                return
-            if not(1 <= len(node.args) <= 3):
-                return
             args = self.get_constant_list(node.args)
             if args is UNSET:
                 return
                 return
             node.func.id = 'xrange'
             return node
-        elif isinstance(node, ast.List):
-            # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
-            return self.list_to_tuple(node)
 
-    def optimize_iter(self, node):
-        new_node = self._optimize_iter(node)
-        if new_node is not None:
-            return new_node
-        else:
-            return node
-
-    def fullvisit_comprehension(self, node):
-        self.unassign(node.target)
-        node.iter = self.visit(node.iter)
-        node.iter = self.optimize_iter(node.iter)
+        if may_drop and self.check_func(node, 'iter', 1, 1):
+            iter_arg = node.args[0]
+            if (isinstance(iter_arg, ast.Tuple)
+            and len(iter_arg.elts) == 0):
+                # set(iter([])) => set()
+                return DROP_NODE
 
     def list_to_tuple(self, node):
         if len(node.elts) > self.config.max_tuple_length:
         return copy_lineno(node, new_node)
 
     def fullvisit_For(self, node):
+        node.iter = self.visit(node.iter)
+
+        new_iter = self.optimize_iter(node.iter, True)
+        if new_iter is DROP_NODE:
+            return new_pass(node)
+        elif new_iter is not None:
+            node.iter = new_iter
+
         #node.target = self.visit(node.target)
         self.unassign(node.target)
 
-        node.iter = self.visit(node.iter)
         node.body = self.visit_list(node.body, conditional=True)
         node.orelse = self.visit_list(node.orelse, conditional=True)
 
-        node.iter = self.optimize_iter(node.iter)
-
     def fullvisit_arguments(self, node):
         # Don't visit arguments
         if PYTHON3:
             class_namespace = self.namespace.copy()
             return optimizer._optimize(node, class_namespace)
 
+    def fullvisit_comprehension(self, node):
+        self.unassign(node.target)
+        node.iter = self.visit(node.iter)
+        new_iter = self.optimize_iter(node.iter, False)
+        if new_iter is not None:
+            node.iter = new_iter
+
+        new_ifs = []
+        empty_gen = False
+        for ifexp in node.ifs:
+            constant = self.get_constant(ifexp)
+            if constant is UNSET:
+                new_ifs.append(ifexp)
+            elif not constant:
+                empty_gen = True
+                break
+        if empty_gen:
+            false_cst = new_constant(ifexp, False)
+            new_ifs = [false_cst]
+        node.ifs = new_ifs
+
+    def visit_GeneratorExp(self, node):
+        # use iter() builtin function
+        if 'builtin_funcs' not in self.config.features:
+            return
+        # FIXME: support more than 1 generator
+        if len(node.generators) != 1:
+            return
+        generator = node.generators[0]
+        if not isinstance(generator.target, ast.Name):
+            return
+        if generator.ifs:
+            if len(generator.ifs) != 1:
+                return
+            test_expr = generator.ifs[0]
+            constant = self.get_constant(test_expr)
+            if constant is UNSET:
+                return
+            if constant:
+                return
+            # (x for x in data if False) => iter(())
+            iter_expr = new_constant(test_expr, ())
+        else:
+            name = generator.target.id
+            if (not isinstance(node.elt, ast.Name)
+            or node.elt.id != name):
+                # x*2 for x in data
+                # y for x in data
+                return
+            iter_expr = generator.iter
+        return new_call(node, 'iter', iter_expr)
+
 
 class FunctionOptimizer(Optimizer):
     def __init__(self, config):

astoptimizer/tests.py

         self.check_not_optimized('def f():\n return 1\n return 2', config)
         self.check_not_optimized('if 0: print("log")', config)
         self.check_not_optimized('while 0: print("log")', config)
+        self.check_not_optimized('for x in (): pass', config)
 
     def test_IfExp(self):
         self.check('4 if "abc" else 5', self.text_num(4))
                        config)
         self.check('for x in [1, 2, 3]: pass',
                    self.text_ast('for x in (1, 2, 3): pass'))
+        self.check('for x in "": pass', self.TEXT_PASS)
 
         config = self.create_config()
         config.max_tuple_length = 2
         self.check('len(frozenset("abc"))', self.text_num(3), config)
         self.check_not_optimized('len(frozenset("abcd"))', config)
 
+    def test_iter_empty_iterable(self):
+        config = self.create_config('builtin_funcs')
+        self.check('dict(iter(()))', self.text_ast('dict()'), config)
+        self.check('frozenset(iter(()))', self.text_ast('frozenset()'), config)
+        self.check('list(iter(()))', self.text_ast('list()'), config)
+        self.check('set(iter(()))', self.text_ast('set()'), config)
+        self.check('tuple(iter(()))', self.text_ast('tuple()'), config)
+
+    def test_drop_empty_iterable(self):
+        config = self.create_config('builtin_funcs')
+        self.check('set(())', self.text_ast('set()'), config)
+        self.check('set([])', self.text_ast('set()'), config)
+        self.check('set(tuple())', self.text_ast('set()'), config)
+        self.check('set(list())', self.text_ast('set()'), config)
+        self.check('set(dict())', self.text_ast('set()'), config)
+        self.check('set(set())', self.text_ast('set()'), config)
+        self.check('set(frozenset())', self.text_ast('set()'), config)
+
+    def test_drop_iter_empty_iterable(self):
+        config = self.create_config('builtin_funcs')
+        self.check('tuple(iter(()))', self.text_ast('tuple()'), config)
+        self.check('tuple(iter([]))', self.text_ast('tuple()'), config)
+        self.check('tuple(iter(tuple()))', self.text_ast('tuple()'), config)
+        self.check('tuple(iter(list()))', self.text_ast('tuple()'), config)
+        self.check('tuple(iter(dict()))', self.text_ast('tuple()'), config)
+        self.check('tuple(iter(set()))', self.text_ast('tuple()'), config)
+        self.check('tuple(iter(frozenset()))', self.text_ast('tuple()'), config)
+
+    def test_GeneratorExp(self):
+        self.check('tuple(x*2 for x in (1, 2, 3) if True)',
+                   self.text_ast('tuple(x*2 for x in (1, 2, 3))'))
+        self.check('tuple(x*2 for x in (1, 2, 3) if 0)',
+                   self.text_ast('tuple(x*2 for x in (1, 2, 3) if False)'))
+        self.check_not_optimized('tuple(x*2 for x in (1, 2, 3) if x % 2)')
+
+        # replace with iter()
+        config = self.create_config('builtin_funcs')
+        self.check('tuple(x for x in "abc")',
+                   self.text_ast('tuple(iter("abc"))'),
+                   config)
+        self.check('tuple(x for x in "abc" if True)',
+                   self.text_ast('tuple(iter("abc"))'),
+                   config)
+        self.check('(x*2 for x in (1, 2, 3) if 0)',
+                   self.text_ast('iter(())'),
+                   config)
+        self.check('tuple(x*2 for x in (1, 2, 3) if 0)',
+                   self.text_ast('tuple()'),
+                   config)
+        self.check_not_optimized('tuple(x for x in "abc" if x)', config)
+        self.check_not_optimized('tuple([x for x in "abc"])', config)
+
 
 class TestFrozenset(BaseTestCase):
     def create_default_config(self):
                           {}, ['x', 'y'])
         self.check_values('x=1\nfor x in range(3): print(x)',
                           {}, ['x'])
+        self.check('x = 5\nfor x in "": x = 9\nprint(x)',
+                   self.text_ast('x = 5\nprint("5")'))
 
     def test_listcomp(self):
         self.check_values('[x for x in range(3)]',