Commits

Victor Stinner committed 8595340

get_constant() takes an optional want_type argument

Comments (0)

Files changed (2)

astoptimizer/optimizer.py

     u,
     PYTHON2, PYTHON3,
     is_bytes_ascii, is_singleton,
-    INT_TYPES, FLOAT_TYPES, COMPLEX_TYPES,
+    INT_TYPES, FLOAT_TYPES, COMPLEX_TYPES, NONE_TYPE,
     BYTES_TYPE, UNICODE_TYPE, STR_TYPES,
     IMMUTABLE_ITERABLE_TYPES, ITERABLE_TYPES)
 
             return False
         return self.check_func(node, name, min_narg, max_narg)
 
-    def get_constant(self, node):
+    def get_constant(self, node, want_type=None):
         if node is None:
-            return None
-        if isinstance(node, ast.Num):
-            return node.n
-        if isinstance(node, ast.Str):
-            text = node.s
-            if (isinstance(text, UNICODE_TYPE)
-            and not optimize_unicode(self.config, text)):
+            constant = None
+        elif isinstance(node, ast.Num):
+            constant = node.n
+        elif isinstance(node, ast.Str):
+            constant = node.s
+            if (isinstance(constant, UNICODE_TYPE)
+            and not optimize_unicode(self.config, constant)):
                 return UNSET
-            return text
-        if PYTHON3:
-            if isinstance(node, ast.Bytes):
-                return node.s
-        if isinstance(node, ast.Name):
-            return self.load_name(node.id)
-        if isinstance(node, ast.Tuple):
+        elif PYTHON3 and isinstance(node, ast.Bytes):
+            constant = node.s
+        elif isinstance(node, ast.Name):
+            constant = self.load_name(node.id)
+        elif isinstance(node, ast.Tuple):
+            if (want_type is not None
+            and not issubclass(tuple, want_type)):
+                return UNSET
             elts = node.elts
             if len(elts) > self.config.max_tuple_length:
                 return UNSET
                     return UNSET
                 constants.append(constant)
             return tuple(constants)
-        if self.check_builtin_func(node, 'frozenset', 0, 1):
+        elif self.check_builtin_func(node, 'frozenset', 0, 1):
+            if (want_type is not None
+            and not issubclass(frozenset, want_type)):
+                return UNSET
             if len(node.args) == 1:
-                arg = self.get_literal(node.args[0])
+                arg = self.get_literal(node.args[0], ITERABLE_TYPES)
                 if arg is UNSET:
                     return UNSET
-                if not isinstance(arg, ITERABLE_TYPES):
-                    return UNSET
                 if len(arg) > self.config.max_tuple_length:
                     return UNSET
                 return frozenset(arg)
-            elif len(node.args) == 0:
-                return frozenset()
-        return UNSET
 
-    def get_literal(self, node):
+            # else: len(node.args) == 0:
+            return frozenset()
+        else:
+            return UNSET
+
+        if (want_type is not None
+        and not isinstance(constant, want_type)):
+            return UNSET
+        return constant
+
+    def get_literal(self, node, want_type=None):
         if isinstance(node, ast.List):
+            if (want_type is not None
+            and not issubclass(list, want_type)):
+                return UNSET
             return [self.get_literal(elt) for elt in node.elts]
-        return self.get_constant(node)
+        else:
+            return self.get_constant(node, want_type)
 
-    def get_constant_list(self, nodes):
+    def get_constant_list(self, nodes, want_type=None):
         constants = []
         for node in nodes:
-            constant = self.get_constant(node)
+            constant = self.get_constant(node, want_type)
             if constant is UNSET:
                 return UNSET
             constants.append(constant)
         return new_node_list
 
     def visit_Name(self, node):
-        constant = self.get_constant(node)
+        constant = self.load_name(node.id)
         if constant is UNSET:
             return
         return new_constant(node, constant)
         return False
 
     def binop(self, node, eval_binop, left, left_cst, right, right_cst):
-        if left_cst is UNSET:
-            return UNSET
-        if right_cst is UNSET:
-            return UNSET
-
         if (isinstance(node.op, ast.Mod)
         and isinstance(left_cst, STR_TYPES)):
             # str % args
         eval_binop = EVAL_BINOP.get(node.op.__class__)
         if not eval_binop:
             return
+
         left = node.left
         left_cst = self.get_constant(left)
+        if left_cst is UNSET:
+            return
+
         right = node.right
         right_cst = self.get_constant(right)
+        if right_cst is UNSET:
+            return
+
         new_node = self.binop(node, eval_binop, left, left_cst, right, right_cst)
         if new_node is not UNSET:
             return new_node
-        return node
 
     def visit_BoolOp(self, node):
         if (not hasattr(node, 'op')
         if eval_unaryop is None:
             return
         operand = node.operand
-        constant = self.get_constant(operand)
+        constant = self.get_constant(operand, COMPLEX_TYPES)
         if constant is not UNSET:
             if not self.check_unary_op(node.op, constant):
                 return
                 return new_node
 
     def subscript_slice(self, node, value, lower, upper, step):
-        lower_cst = self.get_constant(lower)
+        lower_cst = self.get_constant(lower, INT_TYPES + (NONE_TYPE,))
         if lower_cst is UNSET:
             return UNSET
-        upper_cst = self.get_constant(upper)
+        upper_cst = self.get_constant(upper, INT_TYPES + (NONE_TYPE,))
         if upper_cst is UNSET:
             return UNSET
-        step_cst = self.get_constant(step)
+        step_cst = self.get_constant(step, INT_TYPES + (NONE_TYPE,))
         if step_cst is UNSET:
             return UNSET
         myslice = slice(lower_cst, upper_cst, step_cst)
         return new_constant(node, value[myslice])
 
     def subscript_index(self, node, value, index):
-        index_constant = self.get_constant(index)
+        index_constant = self.get_constant(index, INT_TYPES)
         if index_constant is UNSET:
             return UNSET
-        if not isinstance(index_constant, INT_TYPES):
-            return UNSET
         if index_constant >= 0:
             if index_constant >= len(value):
                 return UNSET
 
     def visit_Subscript(self, node):
         value = node.value
-        value_constant = self.get_constant(value)
+        value_constant = self.get_constant(value, IMMUTABLE_ITERABLE_TYPES)
         if value_constant is UNSET:
             return
         if isinstance(node.slice, ast.Slice):
             if isinstance(data, ast.List):
                 return new_set_elts(data, data.elts)
 
-            constant = self.get_constant(data)
+            constant = self.get_constant(data, (tuple, frozenset))
             if constant is UNSET:
                 return
-            if isinstance(constant, (tuple, list, frozenset)):
-                # x in (1, 2) => x in {1, 2}
-                # x in frozenset((1, 2)) => x in {1, 2}
-                return new_set(data, constant)
+            # x in (1, 2) => x in {1, 2}
+            # x in frozenset((1, 2)) => x in {1, 2}
+            return new_set(data, constant)
         else:
             if isinstance(data, ast.List):
                 # x in [1, 2] => x in (1, 2)
         constant = self.get_constant(node.test)
         if constant is UNSET:
             return
+        if not self.config.remove_dead_code:
+            if constant is not UNSET:
+                node.test = self.constant_to_test(node.test, constant)
+            return
         if constant:
             return node.body
         else:
                 # list([1, 2, 3]) => [1, 2, 3]
                 return arg
         elif qualname in ('frozenset', 'set'):
-            constant = self.get_constant(arg)
-            if isinstance(constant, IMMUTABLE_ITERABLE_TYPES):
+            constant = self.get_constant(arg, IMMUTABLE_ITERABLE_TYPES)
+            if constant is not UNSET:
                 elts = frozenset(constant)
                 use_literal_set = (qualname == 'set' and sys.version_info >= (2, 7))
                 if (len(elts) != len(constant)) or use_literal_set:

astoptimizer/tests.py

         self.check('"abc"[2]', self.text_native_str('c'))
         self.check_not_optimized('"abc"[3]')
         self.check('"abc"[-3]', self.text_native_str('a'))
+        self.check('"abc"[None:]', self.text_native_str('abc'))
         self.check_not_optimized('"abc"[-4]')
         self.check_not_optimized('"abc"[None]')
+        self.check_not_optimized('"abc"[1.0]')
+        self.check_not_optimized('"abc"[:1.0]')
+        self.check_not_optimized('1[1]')
+        self.check_not_optimized('1[:1]')
 
     def check_pass(self, code):
         expected = code.replace("pass; pass", "pass")
         self.check('while "": print("log")',
                    self.text_ast('while 0: print("log")'),
                    config)
+        self.check('2 if "abc" else 3',
+                   self.text_ast('2 if 1 else 3'),
+                   config)
+        self.check('2 if "" else 3',
+                   self.text_ast('2 if 0 else 3'),
+                   config)
         self.check_not_optimized('for x in (): pass', config)
         self.check_not_optimized('for x in (1, 2, 3): pass', config)
         self.check_not_optimized('[x for x in () if 1]', config)