Commits

Stephen Tu  committed 0ebe824

improve feature generation

  • Participants
  • Parent commits 8fb095a

Comments (0)

Files changed (1)

File pypy/interpreter/astcompiler/codegen.py

         'reversed'   : _AllArgsMask,
     }
 
+    _BitArith = (ast.BitOr, ast.BitXor, ast.BitAnd, ast.LShift, ast.RShift)
+    _RegArith = (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow)
+
     def __init__(self, symbols, module, mapping_file, feature_file):
         self.symbols = symbols
         self.stack = []
         if func.name == astbuilder.ASTBuilder._HIDDEN_NAME_:
             # ignore our instrumentation
             return
+        self._break_if_file_matches('default.py')
         args = func.args
         assert isinstance(args, ast.arguments)
         self.visit_sequence(args.defaults)
         if func.args.args:
             for (param, idx) in zip(func.args.args, xrange(len(func.args.args))):
                 self._record_binary_feature(param, 'reg_func_param_%d' % (idx))
+            if func.args.defaults:
+                for (param, default) in zip(func.args.args[::-1],
+                                            func.args.defaults[::-1]):
+                    self._assignment_like(param, default, aug=False)
         self.visit_sequence(args.args)
         self.visit_sequence(func.body)
         self._pop_scope()
 
     def visit_Assign(self, node):
         #self._break_if_file_matches('scope.py')
-        if len(node.targets) == 1 and isinstance(node.value, ast.Num):
-            # XXX: hacky- shouldn't do it like this
-            value = self.symbols.space.unwrap(node.value.n)
-            if type(value) == int:
-                self._record_binary_feature(node.targets[0], 'assigned_int_literal')
-            elif type(value) == float:
-                self._record_binary_feature(node.targets[0], 'assigned_float_literal')
-            else:
-                print '[WARNING] unknown numeric type: %s' % (str(type(value)))
-        elif len(node.targets) == 1 and isinstance(node.value, ast.Str):
-            self._record_binary_feature(node.targets[0], 'assigned_str_literal')
-        elif len(node.targets) == 1 and isinstance(node.value, ast.Name):
-            if node.value.id in ('True', 'False'):
-                self._record_binary_feature(node.targets[0], 'assigned_bool_literal')
-            elif node.value.id == 'None':
-                self._record_binary_feature(node.targets[0], 'assigned_none_literal')
+        if len(node.targets) == 1:
+            self._assignment_like(node.targets[0], node.value, aug=False)
         ast.GenericASTVisitor.visit_Assign(self, node)
 
+    def visit_AugAssign(self, node):
+        self._assignment_like(node.target, node.value, aug=True)
+        self._binop_like(node.target, node.value, node.op)
+        ast.GenericASTVisitor.visit_AugAssign(self, node)
+
     def visit_Print(self, node):
         for value in node.values:
             self._record_binary_feature(value, 'used_in_print_stmt')
         ast.GenericASTVisitor.visit_BoolOp(self, node)
 
     def visit_BinOp(self, node):
-        BitArith = (ast.BitOr, ast.BitXor, ast.BitAnd, ast.LShift, ast.RShift)
-        RegArith = (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod)
-
-        if node.op in BitArith:
-            self._record_binary_feature(node.left, 'used_in_bit_arith')
-            self._record_binary_feature(node.right, 'used_in_bit_arith')
-        elif node.op in RegArith:
-            self._record_binary_feature(node.left, 'used_in_reg_arith')
-            self._record_binary_feature(node.right, 'used_in_reg_arith')
-
+        self._binop_like(node.left, node.right, node.op)
         ast.GenericASTVisitor.visit_BinOp(self, node)
 
     def visit_UnaryOp(self, node):
         if node.ops[0] in (ast.In, ast.NotIn):
             self._record_binary_feature(node.left, 'used_as_search_key')
             self._record_binary_feature(node.comparators[0], 'used_as_searchable')
-        elif node.ops[0] in (ast.Lt, ast.LtE, ast.Gt, ast.GtE):
-            self._record_binary_feature(node.left, 'used_as_comparable')
-            self._record_binary_feature(node.comparators[0], 'used_as_comparable')
-        elif node.ops[0] in (ast.Is, ast.IsNot):
-            self._record_binary_feature(node.left, 'used_in_ref_eq')
-            self._record_binary_feature(node.comparators[0], 'used_in_ref_eq')
-
+        else:
+            self._binop_like(node.left, node.comparators[0], node.ops[0])
         ast.GenericASTVisitor.visit_Compare(self, node)
 
     def visit_Call(self, node):
             self._record_binary_feature(target, 'used_as_for_loop_target')
         self._record_binary_feature(itr, 'used_as_for_loop_iter')
 
+    def _literal_type(self, value):
+        if isinstance(value, ast.Num):
+            # XXX: hacky- shouldn't do it like this
+            konst = self.symbols.space.unwrap(value.n)
+            return type(konst)
+        elif isinstance(value, ast.Str):
+            return str
+        elif isinstance(value, ast.Name):
+            if value.id in ('True', 'False'):
+                return bool
+            elif value.id == 'None':
+                return type(None) # NoneType
+        elif isinstance(value, ast.List):
+            return list
+        elif isinstance(value, ast.Dict):
+            return dict
+        return None
+
+    _TypeToFeatureName = {
+        int        : 'int_literal',
+        float      : 'float_literal',
+        str        : 'str_literal',
+        bool       : 'bool_literal',
+        type(None) : 'none_literal',
+        list       : 'list_literal',
+        dict       : 'dict_literal',
+    }
+
+    def _assignment_like(self, target, value, aug):
+        def feature_name(suffix):
+            return ('aug_assigned_%s' % (suffix)) if aug else ('assigned_%s' % (suffix))
+        konst_type = self._literal_type(value)
+        if konst_type:
+            try:
+                k = self._TypeToFeatureName[konst_type]
+                self._record_binary_feature(target, feature_name(k))
+            except KeyError:
+                print '[WARNING]: unknown literal type: %s' % (str(konst_type))
+
+    # XXX: this is arbitrary how we assign which binops/cmps to distinguish
+    # between: we'll need to refine this
+    _BinopTypeToFeatureName = {
+        ast.Add      : 'plusminus',
+        ast.Sub      : 'plusminus',
+        ast.Div      : 'multdiv',
+        ast.FloorDiv : 'multdiv',
+        ast.Mod      : 'mod',
+        ast.LShift   : 'bitops',
+        ast.RShift   : 'bitops',
+        ast.BitAnd   : 'bitops',
+        ast.BitOr    : 'bitops',
+        ast.BitXor   : 'bitops',
+        ast.Mult     : 'multdiv',
+        ast.Pow      : 'pow',
+        ast.Eq       : 'eq',
+        ast.NotEq    : 'eq',
+        ast.Lt       : 'ieq',
+        ast.LtE      : 'ieq',
+        ast.Gt       : 'ieq',
+        ast.GtE      : 'ieq',
+        ast.Is       : 'ref_eq',
+        ast.IsNot    : 'ref_eq',
+    }
+
+    def _binop_like(self, lhs, rhs, op):
+        assert op not in (ast.In, ast.NotIn), 'not handled'
+        # check two special cases: lhs is const, and rhs is const
+        lhs_const_type = self._literal_type(lhs)
+        rhs_const_type = self._literal_type(rhs)
+        done = False
+        if lhs_const_type:
+            try:
+                k = self._TypeToFeatureName[lhs_const_type]
+                self._record_binary_feature(rhs, 'used_in_%s_%s' % (
+                    self._BinopTypeToFeatureName[op],
+                    k))
+                done = True
+            except KeyError:
+                pass
+        elif rhs_const_type:
+            try:
+                k = self._TypeToFeatureName[rhs_const_type]
+                self._record_binary_feature(lhs, 'used_in_%s_%s' % (
+                    self._BinopTypeToFeatureName[op],
+                    k))
+                done = True
+            except KeyError:
+                pass
+        if not done:
+            # the general case (which is also a fallback case if the const
+            # cases fail above)
+            self._record_binary_feature(
+                lhs, 'used_in_%s' % (self._BinopTypeToFeatureName[op]))
+
 ### XXX: avoid code duplication
 
 def get_feature_file_object():