Commits

Victor Stinner committed 559493f

eval small range(n) and xrange(n), ex: range(3)

Comments (0)

Files changed (7)

    shadowed at runtime.
  - x in [a, b, c] => x in {a, b, c}.
    A list can contain non-hashable objects, a set cannot.
+ - [x*2 for x in "abc"] => (x*2 for x in "abc").
+   The generator expression doesn't set the local variable x, and generators
+   are slower than list comprehensions.
+ - range(n) => xrange(n): xrange() raises an OverflowError if n does not fit
+   in a C long.
 
  * Optimize loops. Examples:
 
    - ``while True: ...`` => ``while 1: ...``
-   - ``for x in range(3): ...`` => ``for x in xrange(3): ...`` (Python 2)
+   - ``for x in range(3): ...`` => ``for x in (0, 1, 2): ...``
+   - ``for x in range(1000): ...`` => ``for x in xrange(1000): ...`` (Python 2)
 
  * Optimize list comprehension and generators. Examples:
 
 bugs:
 
  - "i=0; while i < 10: print(i); i = i + 1": don't replace print(i) with print('0')
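+ - "from math import pow as xrange; list(range(n))"
+    => "from math import pow as xrange; list(xrange(n))" is wrong:
+    the name xrange has been rebound by the import and no longer refers to
+    the builtin, so range() must not be rewritten to xrange() here.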
 
 major optimizations:
 

astoptimizer/config_builtin_funcs.py

 def setup_config(config):
     # pure builtin functions
     config.add_func('abs', Function(abs, 1, COMPLEX_TYPES))
+    config.add_func('bin', Function(bin, 1, INT_TYPES))
     config.add_func('bool', Function(bool, 1, FLOAT_TYPES + STR_TYPES))
     if PYTHON3:
         config.add_func('chr', Function(chr, 1, INT_TYPES, check_args=check_unichr))
     config.add_func('sum', Function(sum, (1, 2), (tuple, frozenset), COMPLEX_TYPES, check_args=check_sum_args))
     config.add_func('tuple', Function(tuple, (0, 1), IMMUTABLE_ITERABLE_TYPES, check_args=check_tuple_args))
 
-    config.add_func('bin', Function(bin, 1, INT_TYPES))
     if PYTHON2:
         config.add_func('long', Function(long, 1, FLOAT_TYPES))
         config.add_func('unichr', Function(unichr, 1, INT_TYPES, check_args=check_unichr))

astoptimizer/optimizer.py

                 return False
         return check_func_args(node, min_narg, max_narg)
 
-    def check_builtin_func(self, node, name, min_narg=None, max_narg=None):
+    def check_builtin_func(self, node, name, min_narg, max_narg):
         if 'builtin_funcs' not in self.config.features:
             return False
         return self.check_func(node, name, min_narg, max_narg)
                     new_node = ast.Dict(keys=[], values=[])
                     return copy_lineno(node, new_node)
                 del node.args[0]
-            elif new_arg is not None:
+                return
+            if new_arg is not None:
                 node.args[0] = new_arg
+                return
+
+            if (self.config.remove_dead_code
+            and self.check_builtin_func(arg, BUILTIN_ACCEPTING_ITERABLE,  1, 1)):
+                # set(list(iterable)) => set(iterable)
+                node.args[0] = arg.args[0]
+                return
 
         elif (self.check_builtin_func(node, 'iter', 1, 1)
         and self.is_empty_iterable(node.args[0])):
             # iter(set()) => iter(())
             node.args[0] = new_constant(node.args[0], ())
 
-        elif self.check_builtin_func(node, 'print'):
+        elif self.check_builtin_func(node, 'print', None, None):
             self.print_func(node, node.args)
 
         else:
         else:
             self.disable_vars(node)
 
+    def optimize_range(self, node):
+        args = self.get_constant_list(node.args)
+        if args is UNSET:
+            return
+        if len(args) == 1:
+            start = 0
+            stop = args[0]
+            step = 1
+        elif len(args) == 2:
+            start = args[0]
+            stop = args[1]
+            step = 1
+        elif len(args) == 3:
+            start = args[0]
+            stop = args[1]
+            step = args[2]
+
+        if step == 0:
+            return
+        if not all(isinstance(arg, INT_TYPES) for arg in args):
+            return
+
+        if PYTHON2:
+            minval = self.config.min_c_long
+            maxval = self.config.max_c_long
+            if not all(minval <= arg <= maxval for arg in args):
+                return
+
+        if PYTHON3:
+            range_tuple = range(*args)
+        else:
+            range_tuple = xrange(*args)
+        try:
+            range_len = len(range_tuple)
+        except OverflowError:
+            # OverflowError: Python int too large to convert to C ssize_t
+            pass
+        else:
+            if range_len <= self.config.max_tuple_length:
+                # range(3) => (0, 1, 2)
+                return new_constant(node, tuple(range_tuple))
+
+        if PYTHON3:
+            return
+        qualname = self.namespace.get_qualname(node.func.id)
+        if qualname != 'range':
+            return
+
+        # range(int, [int[, int]]) => xrange(...)
+        node.func.id = 'xrange'
+        return node
+
     def optimize_iter(self, node):
         if (self.config.remove_dead_code
         and self.is_empty_iterable(node)):
             # for x in [1, 2, 3]: ... => for x in (1, 2, 3): ...
             return self.list_to_tuple(node)
 
-        if (PYTHON2
-        and self.check_builtin_func(node, 'range', 1, 3)):
-            # range(int, [int[, int]]) => xrange(...)
-            args = self.get_constant_list(node.args)
-            if args is UNSET:
-                return
-            minval = self.config.min_c_long
-            maxval = self.config.max_c_long
-            if not all((isinstance(arg, INT_TYPES) and minval <= arg <= maxval)
-                       for arg in args):
-                return
-            node.func.id = 'xrange'
-            return node
+        if PYTHON2:
+            is_range = (self.check_builtin_func(node, 'range', 1, 3)
+                        or self.check_builtin_func(node, 'xrange', 1, 3))
+        else:
+            is_range = self.check_builtin_func(node, 'range', 1, 3)
+        if is_range:
+            return self.optimize_range(node)
 
         if (self.config.remove_dead_code
         and self.check_builtin_func(node, 'iter', 1, 1)):
             # set(iter(iterable)) => set(iterable)
             return iter_arg
 
-        if (self.config.remove_dead_code
-        and self.check_builtin_func(node, 'list', 1, 1)):
-            # set(list(arg)) => set(arg)
-            iterable = node.args[0]
-            if self.is_empty_iterable(iterable):
-                # set("") => set()
-                return DROP_NODE
-            return iterable
-
     def list_to_tuple(self, node):
         if len(node.elts) > self.config.max_tuple_length:
             return node

astoptimizer/tests.py

 
     def __init__(self, *args):
         MyTestCase.__init__(self, *args)
+        self.maxDiff = 4096
         self.warnings = []
 
     def setUp(self):
         self.check('print >>f, "x =", 2', self.text_ast('print >> f, %r' % ('x = 2',)))
 
     def test_optimize_iter(self):
+        config = self.create_config('builtin_funcs')
+
+        # optimized
+        self.check('[x for x in range(3)]',
+                   self.text_ast('list((0, 1, 2))'),
+                   config)
+        self.check('[x*2 for x in range(0, 2 ** 20, 2 ** 18)]',
+                   self.text_ast('[x*2 for x in (0, %s, %s, %s)]' % (2 ** 18, 2 * 2 ** 18, 3 * 2 ** 18)),
+                   config)
+
         if PYTHON2:
-            config = self.create_config('builtin_funcs')
-
-            # optimized
-            self.check('[x for x in range(3)]',
-                       self.text_ast('list(xrange(3))'),
+            self.check('[x for x in range(1000)]',
+                       self.text_ast('list(xrange(1000))'),
                        config)
-            self.check('[x*2 for x in range(1, 10)]',
-                       self.text_ast('[x*2 for x in xrange(1, 10)]'),
+            self.check('[x*2 for x in range(1, 1000)]',
+                       self.text_ast('[x*2 for x in xrange(1, 1000)]'),
                        config)
-            self.check('[x*2 for x in range(0, 10, 4)]',
-                       self.text_ast('[x*2 for x in xrange(0, 10, 4)]'),
+            self.check('[x*2 for x in range(0, 1000, 4)]',
+                       self.text_ast('[x*2 for x in xrange(0, 1000, 4)]'),
                        config)
 
-            # not optimized
-            self.check('[x*2 for x in range(0, 2 ** 100, 2 ** 99)]',
-                       self.text_ast('[x*2 for x in range(0, %s, %s)]' % (2 ** 100, 2 ** 99)),
-                       config)
-            self.check_not_optimized('[x*2 for x in range(1, 2, 3, 4)]',
-                       config)
-            self.check_not_optimized('[x*2 for x in range(1, 2, step=3)]',
-                       config)
-            self.check_not_optimized('[x*2 for x in range(0, 5, 0.1)]',
+        # not optimized
+        self.check('[x*2 for x in range(0, 2 ** 100)]',
+                   self.text_ast('[x*2 for x in range(0, %s)]' % (2 ** 100,)),
+                   config)
+        self.check_not_optimized('[x*2 for x in range(1, 2, 3, 4)]',
+                   config)
+        self.check_not_optimized('[x*2 for x in range(1, 2, step=3)]',
+                   config)
+        self.check_not_optimized('[x*2 for x in range(0, 5, 0.1)]',
+                   config)
+
+    def test_For(self):
+        # range => xrange
+        self.check_not_optimized('for x in range(n): pass')
+
+        config = self.create_config('builtin_funcs')
+        self.check('for x in range(3): pass',
+                   self.text_ast('for x in (0, 1, 2): pass'),
+                   config)
+        if PYTHON2:
+            self.check('for x in range(1000): pass',
+                       self.text_ast('for x in xrange(1000): pass'),
                        config)
 
-    def test_For(self):
-        if PYTHON2:
-            # range => xrange
-            self.check_not_optimized('for x in range(n): pass')
-
-            config = self.create_config('builtin_funcs')
-            self.check('for x in range(3): pass',
-                       self.text_ast('for x in xrange(3): pass'),
-                       config)
         self.check('for x in [1, 2, 3]: pass',
                    self.text_ast('for x in (1, 2, 3): pass'))
         self.check('for x in "": pass', self.TEXT_PASS)
         # generators
         self.check('(x*2 for x in "")',
                    self.text_ast('(None for x in ())'))
+        if PYTHON2:
+            self.check('(x*2 for x in range(1000))',
+                       self.text_ast('(x*2 for x in xrange(1000))'),
+                       config)
+        else:
+            self.check_not_optimized('(x*2 for x in range(1000))', config)
         self.check('(x*2 for x in range(3))',
-                   self.text_ast('(x*2 for x in xrange(3))'),
+                   self.text_ast('(x*2 for x in (0, 1, 2))'),
                    config)
 
         # if
         # list comprehension
         self.check('[x*2 for x in ""]',
                    self.text_ast('[]'))
+        if PYTHON2:
+            self.check('[x*2 for x in range(1000)]',
+                       self.text_ast('[x*2 for x in xrange(1000)]'),
+                       config)
+        else:
+            self.check_not_optimized('[x*2 for x in range(1000)]', config)
         self.check('[x*2 for x in range(3)]',
-                   self.text_ast('[x*2 for x in xrange(3)]'),
+                   self.text_ast('[x*2 for x in (0, 1, 2)]'),
                    config)
 
         # if
         self.check('tuple([x for x in iterable])',
                    self.text_ast('tuple(iterable)'),
                    config)
+        if PYTHON2:
+            self.check('set([x for x in range(1000)])',
+                       self.text_ast('set(xrange(1000))'),
+                       config)
+        else:
+            self.check('set([x for x in range(1000)])',
+                       self.text_ast('set(range(1000))'),
+                       config)
 
+        self.check('set([x for x in range(3)])',
+                   self.text_ast('set((0, 1, 2))'),
+                   config)
 
 class TestFrozenset(BaseTestCase):
     def create_default_config(self):
 
     config = Config('builtin_funcs', 'pythonenv')
     config.use_experimental_vars = True
-    config.remove_dead_code = False
-    print("Config features: %s" % ', '.join(sorted(config.features)))
+#    config.remove_dead_code = False
+    print("Features: %s" % ', '.join(sorted(config.features)))
+    print("Experimental varaibles? %s" % config.use_experimental_vars)
+    print("Remove code? %s" % config.remove_dead_code)
     print("")
 
     tree = parse_ast(code_str)