pypy / pypy / annotation / specialize.py

# specialization support
import py
from pypy.tool.uid import uid
from pypy.tool.sourcetools import func_with_new_name
from pypy.tool.algo.unionfind import UnionFind
from pypy.objspace.flow.model import Block, Link, Variable, SpaceOperation
from pypy.objspace.flow.model import Constant, checkgraph
from pypy.annotation import model as annmodel
from pypy.objspace.flow.argument import Signature

def flatten_star_args(funcdesc, args_s):
    """Replace a trailing *arg tuple annotation by its individual items.

    Returns a triple (args_s, key, builder).  For a function without a
    *arg the input list is handed back untouched and both key and builder
    are None.  For a *arg function the last annotation must be a SomeTuple
    of constant length; its items are appended to the other annotations,
    key is ('star', <number of extra args>) and builder is a callable
    (translator, func) -> graph producing a graph that takes the extra
    arguments individually.
    """
    argnames, vararg, kwarg = funcdesc.signature
    assert not kwarg, "functions with ** arguments are not supported"
    if not vararg:
        return args_s, None, None

    # calls to *arg functions: create one version per number of args
    assert len(args_s) == len(argnames) + 1
    s_startuple = args_s[-1]
    assert isinstance(s_startuple, annmodel.SomeTuple), (
        "calls f(..., *arg) require 'arg' to be a tuple")
    s_count = s_startuple.len()
    assert s_count.is_constant(), "calls require known number of args"
    nb_extra_args = s_count.const
    flattened_s = list(args_s[:-1]) + list(s_startuple.items)

    def builder(translator, func):
        # build a hacked graph that doesn't take a *arg any more, but
        # individual extra arguments
        graph = translator.buildflowgraph(func)
        graph_argnames, graph_vararg, graph_kwarg = graph.signature
        assert graph_vararg, "graph should have a *arg at this point"
        assert not graph_kwarg, "where does this **arg come from??"
        oldargs = [Variable(v) for v in graph.getargs()]
        extras = [Variable('stararg%d' % i) for i in range(nb_extra_args)]
        # new entry block: rebuild the tuple from the extra arguments,
        # then jump to the original start block
        entry = Block(oldargs[:-1] + extras)
        entry.operations.append(
            SpaceOperation('newtuple', extras, oldargs[-1]))
        entry.closeblock(Link(oldargs, graph.startblock))
        graph.startblock = entry
        starnames = ['.star%d' % i for i in range(nb_extra_args)]
        graph.signature = Signature(graph_argnames + starnames)
        # note that we can mostly ignore defaults: if nb_extra_args > 0,
        # then defaults aren't applied.  if nb_extra_args == 0, then this
        # just removes the *arg and the defaults keep their meaning.
        if nb_extra_args > 0:
            graph.defaults = None   # shouldn't be used in this case
        checkgraph(graph)
        return graph

    return flattened_s, ('star', nb_extra_args), builder

def default_specialize(funcdesc, args_s):
    """Default specialization: flatten *args, then pick the graph.

    Two versions of a graph may exist: the regular one, and one used when
    some argument is a SomeInstance carrying the 'access_directly' flag.
    If the function has _jit_look_inside_ set to a false value, the flag
    is instead stripped from the argument annotations (in place) and the
    regular graph is used.
    """
    # first flatten the *args
    args_s, key, builder = flatten_star_args(funcdesc, args_s)
    look_inside = getattr(funcdesc.pyobj, '_jit_look_inside_', True)
    direct = False
    # change args_s in place, "official" interface
    for index, s_arg in enumerate(args_s):
        if not isinstance(s_arg, annmodel.SomeInstance):
            continue
        if 'access_directly' not in s_arg.flags:
            continue
        if look_inside:
            # one flagged argument is enough to select the direct version
            direct = True
            key = (AccessDirect, key)
            break
        # not looked inside: replace the annotation by one without the flag
        stripped_flags = s_arg.flags.copy()
        del stripped_flags['access_directly']
        args_s[index] = annmodel.SomeInstance(s_arg.classdef,
                                              s_arg.can_be_None,
                                              flags=stripped_flags)

    graph = funcdesc.cachedgraph(key, builder=builder)
    if direct:
        graph.access_directly = True
    return graph

class AccessDirect(object):
    """Marker used as the first item of a specialization key when at
    least one argument is a SomeInstance with the 'access_directly'
    flag set."""

def getuniquenondirectgraph(desc):
    """Return the single cached graph of 'desc' that is not an
    AccessDirect variant; asserts that there is exactly one."""
    def is_direct(key):
        # keys of the form (AccessDirect, <anything>) mark direct variants
        return (type(key) is tuple and len(key) == 2 and
                key[0] is AccessDirect)

    graphs = [graph for key, graph in desc._cache.items()
              if not is_direct(key)]
    assert len(graphs) == 1
    return graphs[0]
        

# ____________________________________________________________________________
# specializations

class MemoTable(object):
    """Table of precomputed results for a family of memo-call arguments.

    'table' maps argument tuples to the value recorded for them.  Tables
    can be merged with absorb().  finish() turns the table into a generated
    lookup function, builds its flow graph (stored in 'graph') and
    schedules that graph for annotation.
    """
    def __init__(self, funcdesc, args, value):
        # start with a single known entry; absorb() may add more later
        self.funcdesc = funcdesc
        self.table = {args: value}
        self.graph = None
        self.do_not_process = False

    def register_finish(self):
        """Queue self.finish on the bookkeeper's pending_specializations."""
        bookkeeper = self.funcdesc.bookkeeper
        bookkeeper.pending_specializations.append(self.finish)

    def absorb(self, other):
        """Merge the entries of 'other' into this table.

        'other' is flagged so that its own finish() becomes a no-op.
        """
        self.table.update(other.table)
        self.graph = None   # just in case
        other.do_not_process = True

    # class-level counter shared by all MemoTables; it is the numeric
    # suffix that keeps the generated field names distinct
    fieldnamecounter = 0

    def getuniquefieldname(self):
        """Return a fresh '$memofield_<funcname>_<n>' name and advance the
        shared counter."""
        name = self.funcdesc.name
        fieldname = '$memofield_%s_%d' % (name, MemoTable.fieldnamecounter)
        MemoTable.fieldnamecounter += 1
        return fieldname

    def finish(self):
        """Generate the lookup helper for this table and schedule its graph.

        Does nothing if this table was absorbed into another one.
        Otherwise builds a chain of small Python functions (one level per
        argument position), builds the flow graph of the outermost one,
        stores it in self.graph and adds it to the annotator's pending
        graphs together with the union of the possible argument values.
        """
        if self.do_not_process:
            return
        from pypy.annotation.model import unionof
        assert self.graph is None, "MemoTable already finished"
        # take an arbitrary entry: it gives the number of arguments, and its
        # value serves as the dummy result for combinations not in the table
        example_args, example_value = self.table.iteritems().next()
        nbargs = len(example_args)
        # list of sets of possible argument values -- one set per argument index
        sets = [set() for i in range(nbargs)]
        for args in self.table:
            for i in range(nbargs):
                sets[i].add(args[i])

        bookkeeper = self.funcdesc.bookkeeper
        annotator = bookkeeper.annotator
        name = self.funcdesc.name
        argnames = ['a%d' % i for i in range(nbargs)]

        def make_helper(firstarg, stmt, miniglobals):
            # compile "def f(<argnames from firstarg on>): <stmt>" with
            # 'miniglobals' as its globals, and return the resulting function
            # renamed to memo_<name>_<firstarg>
            header = "def f(%s):" % (', '.join(argnames[firstarg:],))
            source = py.code.Source(stmt)
            source = source.putaround(header)
            exec source.compile() in miniglobals
            f = miniglobals['f']
            return func_with_new_name(f, 'memo_%s_%d' % (name, firstarg))

        def make_constant_subhelper(firstarg, result):
            # make a function that just returns the constant answer 'result'
            f = make_helper(firstarg, 'return result', {'result': result})
            # tag the function so that callers can detect the constant case
            f.constant_result = result
            return f

        def make_subhelper(args_so_far=()):
            # recursively build the helper that dispatches on the argument at
            # position len(args_so_far), given the values chosen so far
            firstarg = len(args_so_far)
            if firstarg == nbargs:
                # no argument left, return the known result
                # (or a dummy value if none corresponds exactly)
                result = self.table.get(args_so_far, example_value)
                return make_constant_subhelper(firstarg, result)
            else:
                nextargvalues = list(sets[len(args_so_far)])
                # normalize the order so that the bool test below matches
                if nextargvalues == [True, False]:
                    nextargvalues = [False, True]
                nextfns = [make_subhelper(args_so_far + (arg,))
                           for arg in nextargvalues]
                # do all graphs return a constant?
                try:
                    constants = [fn.constant_result for fn in nextfns]
                except AttributeError:
                    constants = None    # one of the 'fn' has no constant_result
                restargs = ', '.join(argnames[firstarg+1:])

                # is there actually only one possible value for the current arg?
                if len(nextargvalues) == 1:
                    if constants:   # is the result a constant?
                        result = constants[0]
                        return make_constant_subhelper(firstarg, result)
                    else:
                        # ignore the first argument and just call the subhelper
                        stmt = 'return subhelper(%s)' % restargs
                        return make_helper(firstarg, stmt,
                                           {'subhelper': nextfns[0]})

                # is the arg a bool?
                elif nextargvalues == [False, True]:
                    # NOTE(review): fieldname0/fieldname1 are not used below;
                    # the two calls still advance the shared name counter
                    fieldname0 = self.getuniquefieldname()
                    fieldname1 = self.getuniquefieldname()
                    # nextfns is parallel to [False, True], so indexing it
                    # with a bool selects the helper of that branch
                    stmt = ['if %s:' % argnames[firstarg]]
                    if hasattr(nextfns[True], 'constant_result'):
                        # the True branch has a constant result
                        case1 = nextfns[True].constant_result
                        stmt.append('    return case1')
                    else:
                        # must call the subhelper
                        case1 = nextfns[True]
                        stmt.append('    return case1(%s)' % restargs)
                    stmt.append('else:')
                    if hasattr(nextfns[False], 'constant_result'):
                        # the False branch has a constant result
                        case0 = nextfns[False].constant_result
                        stmt.append('    return case0')
                    else:
                        # must call the subhelper
                        case0 = nextfns[False]
                        stmt.append('    return case0(%s)' % restargs)

                    return make_helper(firstarg, '\n'.join(stmt),
                                       {'case0': case0,
                                        'case1': case1})

                # the arg is a set of PBCs
                else:
                    descs = [bookkeeper.getdesc(pbc) for pbc in nextargvalues]
                    fieldname = self.getuniquefieldname()
                    # the helper reads the memo field from the argument itself
                    stmt = 'return getattr(%s, %r)' % (argnames[firstarg],
                                                       fieldname)
                    if constants:
                        # instead of calling these subhelpers indirectly,
                        # we store what they would return directly in the
                        # pbc memo fields
                        store = constants
                    else:
                        store = nextfns
                        # call the result of the getattr()
                        stmt += '(%s)' % restargs

                    # store the memo field values
                    for desc, value_to_store in zip(descs, store):
                        desc.create_new_attribute(fieldname, value_to_store)

                    return make_helper(firstarg, stmt, {})

        entrypoint = make_subhelper(args_so_far = ())
        self.graph = annotator.translator.buildflowgraph(entrypoint)
        self.graph.defaults = self.funcdesc.defaults

        # schedule this new graph for being annotated
        args_s = []
        for arg_types in sets:
            # one annotation per argument: the union of all its known values
            values_s = [bookkeeper.immutablevalue(x) for x in arg_types]
            args_s.append(unionof(*values_s))
        annotator.addpendinggraph(self.graph, args_s)


def memo(funcdesc, arglist_s):
    """Specializer for memo functions.

    Enumerates every combination of possible argument values (constants,
    the Python objects behind PBCs, or both booleans), runs the real
    function on combinations not seen before and records the results in
    MemoTables.  Returns the MemoTable's graph if it has been built already,
    otherwise the union of the annotations of all results known so far.
    A SomeImpossibleValue argument is returned as is.
    """
    from pypy.annotation.model import SomePBC, SomeImpossibleValue, SomeBool
    from pypy.annotation.model import unionof
    # call the function now, and collect possible results
    argvalues = []
    for s_arg in arglist_s:
        if s_arg.is_constant():
            candidates = [s_arg.const]
        elif isinstance(s_arg, SomePBC):
            assert not s_arg.can_be_None, "memo call: cannot mix None and PBCs"
            candidates = []
            for desc in s_arg.descriptions:
                if desc.pyobj is None:
                    raise Exception("memo call with a class or PBC that has no "
                                   "corresponding Python object (%r)" % (desc,))
                candidates.append(desc.pyobj)
        elif isinstance(s_arg, SomeImpossibleValue):
            return s_arg    # we will probably get more possible args later
        elif isinstance(s_arg, SomeBool):
            candidates = [False, True]
        else:
            raise Exception("memo call: argument must be a class or a frozen "
                            "PBC, got %r" % (s_arg,))
        argvalues.append(candidates)
    # the list of all possible tuples of arguments to give to the memo function
    possiblevalues = cartesian_product(argvalues)

    # a MemoTable factory -- one MemoTable per family of arguments that can
    # be called together, merged via a UnionFind.
    bookkeeper = funcdesc.bookkeeper
    try:
        memotables = bookkeeper.all_specializations[funcdesc]
    except KeyError:
        func = funcdesc.pyobj
        if func is None:
            raise Exception("memo call: no Python function object to call "
                            "(%r)" % (funcdesc,))

        def compute_one_result(args):
            # really call the function and wrap the answer in a new table
            table = MemoTable(funcdesc, args, func(*args))
            table.register_finish()
            return table

        memotables = UnionFind(compute_one_result)
        bookkeeper.all_specializations[funcdesc] = memotables

    # merge the MemoTables for the individual argument combinations
    anchor = possiblevalues.next()
    _, _, memotable = memotables.find(anchor)
    for combination in possiblevalues:
        _, _, memotable = memotables.union(anchor, combination)

    if memotable.graph is None:
        # not computed yet: for now, return the union of each possible result
        results_s = [bookkeeper.immutablevalue(v)
                     for v in memotable.table.values()]
        return unionof(*results_s)
    return memotable.graph   # already computed

def cartesian_product(lstlst):
    """Generate every tuple made of one item from each list of 'lstlst'.

    The item taken from the first list varies fastest.  An empty 'lstlst'
    produces the single empty tuple.
    """
    if lstlst:
        head, rest = lstlst[0], lstlst[1:]
        for tail in cartesian_product(rest):
            for item in head:
                yield (item,) + tail
    else:
        yield ()


def make_constgraphbuilder(n, v=None, factory=None, srcmodule=None):
    def constgraphbuilder(translator, ignore):
        args = ','.join(["arg%d" % i for i in range(n)])
        if factory is not None:
            computed_v = factory()
        else:
            computed_v = v
        miniglobals = {'v': computed_v, '__name__': srcmodule}
        exec py.code.Source("constf = lambda %s: v" % args).compile() in miniglobals
        return translator.buildflowgraph(miniglobals['constf'])
    return constgraphbuilder

def maybe_star_args(funcdesc, key, args_s):
    """Return the cached graph for 'key', also handling *arg functions.

    'key' is the specialization key chosen by the caller.  Callers in this
    module pass a tuple, None, a type or a flow operation.  If the function
    takes a *arg, the ('star', n) key from flatten_star_args is combined
    with it:

    * key is None: the star key is used alone;
    * key is a tuple: the star key is appended to it;
    * any other key: it is wrapped in a 1-tuple first.

    The last two cases previously evaluated 'key + key1' directly, which
    raises TypeError for any key that is not a tuple.
    """
    args_s, key1, builder = flatten_star_args(funcdesc, args_s)
    if key1 is not None:
        if key is None:
            key = key1
        elif isinstance(key, tuple):
            key = key + key1
        else:
            key = (key,) + key1
    return funcdesc.cachedgraph(key, builder=builder)
 
def specialize_argvalue(funcdesc, args_s, *argindices):
    """Specialize on the values of the arguments at 'argindices'.

    Each selected argument must be a constant, or a SomePBC with exactly
    one description (in which case the description is used in the key).
    Raises Exception otherwise.
    """
    parts = []
    for index in argindices:
        s_arg = args_s[index]
        if s_arg.is_constant():
            parts.append(s_arg.const)
            continue
        if (isinstance(s_arg, annmodel.SomePBC) and
                len(s_arg.descriptions) == 1):
            # for test_specialize_arg_bound_method
            desc, = s_arg.descriptions
            parts.append(desc)
            continue
        raise Exception("specialize:arg(%d): argument not constant: %r"
                        % (index, s_arg))
    return maybe_star_args(funcdesc, tuple(parts), args_s)

def specialize_arg_or_var(funcdesc, args_s, *argindices):
    """Specialize by value if all arguments at 'argindices' are constant;
    otherwise fall back to the graph with key None."""
    for argno in argindices:
        if not args_s[argno].is_constant():
            # some not constant
            return maybe_star_args(funcdesc, None, args_s)
    # all constant
    return specialize_argvalue(funcdesc, args_s, *argindices)

def specialize_argtype(funcdesc, args_s, *argindices):
    """Specialize on the known types of the arguments at 'argindices'.

    Asserts that none of these types is a class whose own __dict__ has
    the '_must_specialize_' tag; types without a classdesc.pyobj are
    skipped by that check.
    """
    key = tuple(args_s[i].knowntype for i in argindices)
    for cls in key:
        try:
            namespace = cls.classdesc.pyobj.__dict__
            assert '_must_specialize_' not in namespace, (
                "%s has the tag _must_specialize_" % (cls,))
        except AttributeError:
            # no class description / Python class behind this type
            continue
    return maybe_star_args(funcdesc, key, args_s)

def specialize_arglistitemtype(funcdesc, args_s, i):
    """Specialize on the known item type of the list argument at index
    'i'; the key is None when that argument is not known to be a list."""
    s_arg = args_s[i]
    if s_arg.knowntype is list:
        key = s_arg.listdef.listitem.s_value.knowntype
    else:
        key = None
    return maybe_star_args(funcdesc, key, args_s)

def specialize_call_location(funcdesc, args_s, op):
    """Specialize per call site: the operation 'op' itself is the key,
    so it must not be None."""
    assert op is not None
    key = op
    return maybe_star_args(funcdesc, key, args_s)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.