Commits

Stephen Tu committed 8ef5e3e

WIP: building feature vectors

Comments (0)

Files changed (3)

pypy/interpreter/astcompiler/astbuilder.py

         if not hasattr(node, 'lineno'):
             return node
         #print "setting lineno=%d for node:" % (self.lineno), node
-        node.lineno += self.lineno
+        node.lineno = self.lineno
         return node
 
 def ast_from_node(space, node, compile_info):
             if not self.do_instrument_file():
               return ast.Module(stmts)
 
-#            my_source = \
-#('def %s' % (self._HIDDEN_NAME_)) + '''(filename, lineno, column, varname, var):
-#  print "filename=%s, lineno=%d, column=%d, varname=%s, type=%s.%s" % (
-#    filename, lineno, column, varname, type(var).__module__, type(var).__name__)
+            my_source = \
+('def %s' % (self._HIDDEN_NAME_)) + '''(filename, lineno, column, varname, var):
+  print "filename=%s, lineno=%d, column=%d, varname=%s, type=%s.%s" % (
+    filename, lineno, column, varname, type(var).__module__, type(var).__name__)
+  return var
+'''
+
+#            my_source = '''
+#def %s(filename, lineno, column, varname, var):
+#  import instrumentation
+#  import json
+#  rec = {
+#    'filename' : filename,
+#    'lineno'   : lineno,
+#    'column'   : column,
+#    'varname'  : varname,
+#    'type'     : type(var).__module__ + '.' + type(var).__name__,
+#  }
+#  fp = instrumentation.get_output_file_object()
+#  print >>fp, json.dumps(rec)
 #  return var
-#'''
-
-            my_source = '''
-def %s(filename, lineno, column, varname, var):
-  import instrumentation
-  import json
-  rec = {
-    'filename' : filename,
-    'lineno'   : lineno,
-    'column'   : column,
-    'varname'  : varname,
-    'type'     : type(var).__module__ + '.' + type(var).__name__,
-  }
-  fp = instrumentation.get_output_file_object()
-  print >>fp, json.dumps(rec)
-  return var
-''' % (self._HIDDEN_NAME_)
+#''' % (self._HIDDEN_NAME_)
 
             ci = pyparse.CompileInfo(self.compile_info.filename,
                                      self.compile_info.mode,

pypy/interpreter/astcompiler/codegen.py

 # you figure out a way to remove them, great, but try a translation first,
 # please.
 
-from pypy.interpreter.astcompiler import ast, assemble, symtable, consts, misc
+from pypy.interpreter.astcompiler import ast, astbuilder, assemble, symtable, consts, misc
 from pypy.interpreter.astcompiler import optimize # For side effects
 from pypy.interpreter.pyparser.error import SyntaxError
 from pypy.tool import stdlib_opcode as ops
 from pypy.interpreter.error import OperationError
 
+class FeatureExtractorVisitor(ast.GenericASTVisitor):
+
+    _AllArgsMask = ~0x0
+    _Arg0Mask    = 0x1
+    _Arg1Mask    = 0x1 << 1
+    _Arg2Mask    = 0x1 << 1
+
+    # see pypy/module/__builtin__/__init__.py
+    _BuiltinsTrackArgs = {
+        'apply'      : _AllArgsMask,
+        'sorted'     : _AllArgsMask,
+        'any'        : _AllArgsMask,
+        'all'        : _AllArgsMask,
+        'sum'        : _AllArgsMask,
+        'map'        : _AllArgsMask,
+        'reduce'     : _AllArgsMask,
+        'filter'     : _AllArgsMask,
+        'zip'        : _AllArgsMask,
+        'open'       : _AllArgsMask,
+        'abs'        : _AllArgsMask,
+        'chr'        : _AllArgsMask,
+        'unichr'     : _AllArgsMask,
+        'len'        : _AllArgsMask,
+        'ord'        : _AllArgsMask,
+        'pow'        : _AllArgsMask,
+        'repr'       : _AllArgsMask,
+        'hash'       : _AllArgsMask,
+        'oct'        : _AllArgsMask,
+        'hex'        : _AllArgsMask,
+        'round'      : _AllArgsMask,
+        'cmp'        : _AllArgsMask,
+        'coerce'     : _AllArgsMask,
+        'divmod'     : _AllArgsMask,
+        'format'     : _AllArgsMask,
+        'issubclass' : _AllArgsMask,
+        'isinstance' : _AllArgsMask,
+        'getattr'    : _AllArgsMask,
+        'setattr'    : _AllArgsMask,
+        'delattr'    : _AllArgsMask,
+        'hasattr'    : _AllArgsMask,
+        'iter'       : _AllArgsMask,
+        'next'       : _AllArgsMask,
+        'id'         : _AllArgsMask,
+        'intern'     : _AllArgsMask,
+        'callable'   : _AllArgsMask,
+        'range'      : _AllArgsMask,
+        'xrange'     : _AllArgsMask,
+        'enumerate'  : _AllArgsMask,
+        'min'        : _AllArgsMask,
+        'max'        : _AllArgsMask,
+        'reversed'   : _AllArgsMask,
+    }
+
+    def __init__(self, symbols, module):
+        self.symbols = symbols
+        self.stack = []
+        self._push_scope(module)
+
+        # symbol features maps each unique, logical symbol to a
+        # set of features encountered for that symbol
+        #   1) a "logical" symbol is a pair ("originating" scope, identifier name)
+        #   2) a feature is a 2-tuple of (feature name, feature value)
+        #
+        # right now, we just hardcode a bunch of random features
+        self.symbol_features = {}
+
+    def _feature_list(self, sym):
+        # sym is (scope, ident), where ident is a string
+        key = sym
+        if key in self.symbol_features:
+            return self.symbol_features[key]
+        value = []
+        self.symbol_features[key] = value
+        return value
+
+    def _logical_symbol(self, node):
+        if isinstance(node, ast.Call):
+            # unwrap our hidden instrumentation
+            if isinstance(node.func, ast.Name) and \
+               node.func.id == astbuilder.ASTBuilder._HIDDEN_NAME_:
+                node = node.args[4]
+        if not isinstance(node, ast.Name):
+            return None
+        name = node
+        if name.id == astbuilder.ASTBuilder._HIDDEN_NAME_:
+            # ignore our instrumentation
+            return None
+        scope = self._get_originating_scope(self._current_scope(), name.id)
+        if not scope:
+            return None
+        return (scope, name.id)
+
+    def _record_feature(self, node, feature_name, feature_value):
+        if not node:
+            return False
+        sym = self._logical_symbol(node)
+        if not sym:
+            return False
+        features = self._feature_list(sym)
+        features.append((feature_name, feature_value))
+
+        # output a particular feature vector for this node location
+        print '<filename=%s, lineno=%d, varname=%s, features=%s>' % (
+            self.symbols.compile_info.filename,
+            node.lineno,
+            sym[1],
+            str(features)
+        )
+        return True
+
+    def _record_binary_feature(self, node, feature_name, feature_value=True):
+        return self._record_feature(node, feature_name, feature_value)
+
+    def _push_scope(self, node):
+        self.stack.append(self.symbols.find_scope(node))
+
+    def _pop_scope(self):
+        ret = self._current_scope()
+        self.stack.pop()
+        return ret
+
+    def _current_scope(self):
+        return self.stack[-1]
+
+    def _module_scope(self):
+        return self.stack[0]
+
+    def _scope_no_str(self, no):
+        if no == symtable.SCOPE_UNKNOWN:
+            return 'SCOPE_UNKNOWN'
+        if no == symtable.SCOPE_GLOBAL_IMPLICIT:
+            return 'SCOPE_GLOBAL_IMPLICIT'
+        if no == symtable.SCOPE_GLOBAL_EXPLICIT:
+            return 'SCOPE_GLOBAL_EXPLICIT'
+        if no == symtable.SCOPE_LOCAL:
+            return 'SCOPE_LOCAL'
+        if no == symtable.SCOPE_FREE:
+            return 'SCOPE_FREE'
+        if no == symtable.SCOPE_CELL:
+            return 'SCOPE_CELL'
+        return 'unknown(%d)' % (no)
+
+    def _get_corresponding_cell_scope_in_parents(self, scope, name):
+        '''
+        precondition: name is a free symbol in scope
+
+        looks for a corresponding CELL variable with the same name
+        in parent scopes. returns None if none exists
+        '''
+        assert scope.lookup(name) == symtable.SCOPE_FREE
+        cur = scope.parent
+        while cur:
+            if cur.lookup(name) == symtable.SCOPE_CELL:
+                return cur
+        return None
+
+    def _get_originating_scope(self, scope, name):
+        '''get the originating scope for symbol name'''
+        no = scope.lookup(name)
+        if no == symtable.SCOPE_GLOBAL_EXPLICIT:
+            # assume module level scope
+            return self._module_scope()
+        if no == symtable.SCOPE_LOCAL or \
+           no == symtable.SCOPE_CELL:
+            # easy case- current scope
+            return scope
+        if no == symtable.SCOPE_FREE:
+            return self._get_corresponding_cell_scope_in_parents(scope, name)
+        # unhandled case
+        return None
+
+    ### Visitors ###
+
+    def visit_Module(self, node):
+        #if self.symbols.compile_info.filename.find('scope.py') != -1:
+        #    import pdb
+        #    pdb.set_trace()
+        self.visit_sequence(node.body)
+
+    def visit_FunctionDef(self, func):
+        if func.name == astbuilder.ASTBuilder._HIDDEN_NAME_:
+            # ignore our instrumentation
+            return
+        args = func.args
+        assert isinstance(args, ast.arguments)
+        self.visit_sequence(args.defaults)
+        self.visit_sequence(func.decorator_list)
+        self._push_scope(func)
+        if func.args.args:
+            for (param, idx) in zip(func.args.args, xrange(len(func.args.args))):
+                self._record_binary_feature(param, 'reg_func_param_%d' % (idx))
+        self.visit_sequence(args.args)
+        self.visit_sequence(func.body)
+        self._pop_scope()
+
+    def visit_ClassDef(self, clsdef):
+        self.visit_sequence(clsdef.bases)
+        self.visit_sequence(clsdef.decorator_list)
+        self._push_scope(clsdef)
+        self.visit_sequence(clsdef.body)
+        self._pop_scope()
+
+    def visit_Return(self, node):
+        if node.value:
+            self._record_binary_feature(node, 'used_as_ret_val')
+
+    def visit_Print(self, node):
+        for value in node.values:
+            self._record_binary_feature(value, 'used_in_print_stmt')
+        ast.GenericASTVisitor.visit_Print(self, node)
+
+    def visit_For(self, node):
+        self._visit_for_like(node.target, node.iter)
+        ast.GenericASTVisitor.visit_For(self, node)
+
+    def visit_While(self, node):
+        self._record_binary_feature(node.test, 'used_as_while_loop_test')
+        ast.GenericASTVisitor.visit_While(self, node)
+
+    def visit_If(self, node):
+        #if self.symbols.compile_info.filename.find('scope.py') != -1:
+        #    import pdb
+        #    pdb.set_trace()
+        self._record_binary_feature(node.test, 'used_as_if_test')
+        ast.GenericASTVisitor.visit_If(self, node)
+
+    def visit_With(self, node):
+        if node.optional_vars:
+            self._record_binary_feature(node.optional_vars, 'used_as_with_target')
+        ast.GenericASTVisitor.visit_With(self, node)
+
+    def visit_Assert(self, node):
+        self._record_binary_feature(node, 'used_in_truth_test')
+        ast.GenericASTVisitor.visit_Assert(self, node)
+
+    def visit_BoolOp(self, node):
+        for value in node.values:
+            self._record_binary_feature(value, 'used_in_truth_test')
+        ast.GenericASTVisitor.visit_BoolOp(self, node)
+
+    def visit_BinOp(self, node):
+        BitArith = (ast.BitOr, ast.BitXor, ast.BitAnd, ast.LShift, ast.RShift)
+        RegArith = (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod)
+
+        if node.op in BitArith:
+            self._record_binary_feature(node.left, 'used_in_bit_arith')
+            self._record_binary_feature(node.right, 'used_in_bit_arith')
+        elif node.op in RegArith:
+            self._record_binary_feature(node.left, 'used_in_reg_arith')
+            self._record_binary_feature(node.right, 'used_in_reg_arith')
+
+        ast.GenericASTVisitor.visit_BinOp(self, node)
+
+    def visit_UnaryOp(self, node):
+        if node.op == ast.UAdd or node.op == ast.USub:
+            self._record_binary_feature(node.operand, 'used_in_reg_arith')
+        elif node.op == ast.Invert:
+            self._record_binary_feature(node.operand, 'used_in_bit_arith')
+        elif node.op == ast.Not:
+            self._record_binary_feature(node.operand, 'used_in_truth_test')
+        ast.GenericASTVisitor.visit_UnaryOp(self, node)
+
+    def visit_Lambda(self, lamb):
+        args = lamb.args
+        assert isinstance(args, ast.arguments)
+        self.visit_sequence(args.defaults)
+        self._push_scope(lamb)
+        if lamb.args.args:
+            for (param, idx) in zip(lamb.args.args, xrange(len(lamb.args.args))):
+                self._record_binary_feature(param, 'lambda_func_param_%d' % (idx))
+        lamb.args.walkabout(self)
+        lamb.body.walkabout(self)
+        self._pop_scope()
+
+    def visit_IfExp(self, node):
+        self._record_binary_feature(node.test, 'used_as_if_test')
+        ast.GenericASTVisitor.visit_IfExp(self, node)
+
+    def visit_Dict(self, node):
+        if node.keys:
+            for key in node.keys:
+                self._record_binary_feature(key, 'used_as_dict_key')
+        ast.GenericASTVisitor.visit_Dict(self, node)
+
+    def visit_SetComp(self, setcomp):
+        self._visit_comprehension(setcomp, setcomp.generators, setcomp.elt)
+
+    def visit_DictComp(self, dictcomp):
+        self._visit_comprehension(dictcomp, dictcomp.generators,
+                                  dictcomp.value, dictcomp.key)
+
+    def visit_GeneratorExp(self, genexp):
+        self._visit_comprehension(genexp, genexp.generators, genexp.elt)
+
+    def visit_Yield(self, node):
+        if node.value:
+            self._record_binary_feature(node.value, 'used_in_yield')
+        ast.GenericASTVisitor.visit_Yield(self, node)
+
+    def visit_Compare(self, node):
+        assert len(node.ops) >= 1
+        # only look at ops[0]
+        if node.ops[0] in (ast.In, ast.NotIn):
+            self._record_binary_feature(node.left, 'used_as_search_key')
+            self._record_binary_feature(node.comparators[0], 'used_as_searchable')
+        elif node.ops[0] in (ast.Lt, ast.LtE, ast.Gt, ast.GtE):
+            self._record_binary_feature(node.left, 'used_as_comparable')
+            self._record_binary_feature(node.comparators[0], 'used_as_comparable')
+        elif node.ops[0] in (ast.Is, ast.IsNot):
+            self._record_binary_feature(node.left, 'used_in_ref_eq')
+            self._record_binary_feature(node.comparators[0], 'used_in_ref_eq')
+
+        ast.GenericASTVisitor.visit_Compare(self, node)
+
+    def visit_Call(self, node):
+        # check builtin arg
+        if isinstance(node.func, ast.Name):
+            name = node.func.id
+            try:
+                argmask = self._BuiltinsTrackArgs[name]
+                for (arg, idx) in zip(node.args, xrange(len(node.args))):
+                    if not (idx & argmask):
+                        continue
+                    self._record_binary_feature(arg, 'used_as_builtin_%s_arg%d' % (name, idx))
+            except KeyError:
+                self._record_binary_feature(node.func, 'used_as_callable')
+        ast.GenericASTVisitor.visit_Call(self, node)
+
+    def visit_Attribute(self, node):
+        self._record_binary_feature(node.value, 'used_as_object')
+        ast.GenericASTVisitor.visit_Attribute(self, node)
+
+    def visit_Subscript(self, node):
+        self._record_binary_feature(node.value, 'used_as_subscriptable')
+        self._record_binary_feature(node.slice, 'used_as_subscript_idx')
+        ast.GenericASTVisitor.visit_Subscript(self, node)
+
+    def visit_Slice(self, node):
+        if node.lower:
+            self._record_binary_feature(node.lower, 'used_as_slice_idx')
+        if node.upper:
+            self._record_binary_feature(node.upper, 'used_as_slice_idx')
+        if node.step:
+            self._record_binary_feature(node.step, 'used_as_slice_idx')
+        ast.GenericASTVisitor.visit_Slice(self, node)
+
+    def visit_comprehension(self, node):
+        self._visit_for_like(node.target, node.iter)
+        # XXX: node.ifs
+        ast.GenericASTVisitor.visit_comprehension(self, node)
+
+    ### Visitor helpers ###
+
+    def _visit_comprehension(self, node, comps, *consider):
+        outer = comps[0]
+        assert isinstance(outer, ast.comprehension)
+        outer.iter.walkabout(self)
+        self._push_scope(node)
+        outer.target.walkabout(self)
+        self.visit_sequence(outer.ifs)
+        self.visit_sequence(comps[1:])
+        for item in list(consider):
+            item.walkabout(self)
+        self._pop_scope()
+
+    def _visit_for_like(self, target, itr):
+        if isinstance(target, ast.Tuple):
+            # unwrap tuple 1 level
+            for elt in target.elts:
+                self._record_binary_feature(elt, 'used_as_for_loop_target')
+        else:
+            self._record_binary_feature(target, 'used_as_for_loop_target')
+        self._record_binary_feature(itr, 'used_as_for_loop_iter')
 
 def compile_ast(space, module, info):
     """Generate a code object from AST."""
     symbols = symtable.SymtableBuilder(space, module, info)
+    v = FeatureExtractorVisitor(symbols, module)
+    module.walkabout(v)
     return TopLevelCodeGenerator(space, module, symbols, info).assemble()
 
 

pypy/interpreter/astcompiler/optimize.py

             assert type(my_ast) == ast.Call
             v = astbuilder.LineNumberVisitor(name.lineno)
             my_ast.mutate_over(v)
-            #print "instrumented: %s:%d %s" % (self.compile_info.filename, name.lineno, name.id)
+            assert my_ast.lineno == name.lineno
+            print "instrumented: %s:%d %s" % (self.compile_info.filename, name.lineno, name.id)
             return my_ast
         else:
             # don't bother w/ stores, delets, etc