Source

pypy / pypy / jit / metainterp / test / test_warmspot.py

Full commit
import py
from pypy.jit.metainterp.warmspot import get_stats
from pypy.rlib.jit import JitDriver, set_param, unroll_safe
from pypy.jit.backend.llgraph import runner

from pypy.jit.metainterp.test.support import LLJitMixin, OOJitMixin
from pypy.jit.metainterp.optimizeopt import ALL_OPTS_NAMES


class Exit(Exception):
    def __init__(self, result):
        self.result = result


class WarmspotTests(object):

    def test_basic(self):
        mydriver = JitDriver(reds=['a'],
                             greens=['i'])
        CODE_INCREASE = 0
        CODE_JUMP = 1
        lst = [CODE_INCREASE, CODE_INCREASE, CODE_JUMP]
        def interpreter_loop(a):
            i = 0
            while True:
                mydriver.jit_merge_point(i=i, a=a)
                if i >= len(lst):
                    break
                elem = lst[i]
                if elem == CODE_INCREASE:
                    a = a + 1
                    i += 1
                elif elem == CODE_JUMP:
                    if a < 20:
                        i = 0
                        mydriver.can_enter_jit(i=i, a=a)
                    else:
                        i += 1
                else:
                    pass
            raise Exit(a)

        def main(a):
            try:
                interpreter_loop(a)
            except Exit, e:
                return e.result

        res = self.meta_interp(main, [1])
        assert res == 21

    def test_reentry(self):
        mydriver = JitDriver(reds = ['n'], greens = [])

        def f(n):
            while n > 0:
                mydriver.can_enter_jit(n=n)
                mydriver.jit_merge_point(n=n)
                if n % 20 == 0:
                    n -= 2
                n -= 1

        res = self.meta_interp(f, [60])
        assert res == f(30)

    def test_location(self):
        def get_printable_location(n):
            return 'GREEN IS %d.' % n
        myjitdriver = JitDriver(greens=['n'], reds=['m'],
                                get_printable_location=get_printable_location)
        def f(n, m):
            while m > 0:
                myjitdriver.can_enter_jit(n=n, m=m)
                myjitdriver.jit_merge_point(n=n, m=m)
                m -= 1

        self.meta_interp(f, [123, 10])
        assert len(get_stats().locations) >= 4
        for loc in get_stats().locations:
            assert loc == (0, 0, 123)

    def test_set_param_enable_opts(self):
        from pypy.rpython.annlowlevel import llstr, hlstr

        myjitdriver = JitDriver(greens = [], reds = ['n'])
        class A(object):
            def m(self, n):
                return n-1

        def g(n):
            while n > 0:
                myjitdriver.can_enter_jit(n=n)
                myjitdriver.jit_merge_point(n=n)
                n = A().m(n)
            return n
        def f(n, enable_opts):
            set_param(None, 'enable_opts', hlstr(enable_opts))
            return g(n)

        # check that the set_param will override the default
        res = self.meta_interp(f, [10, llstr('')])
        assert res == 0
        self.check_resops(new_with_vtable=1)

        res = self.meta_interp(f, [10, llstr(ALL_OPTS_NAMES)],
                               enable_opts='')
        assert res == 0
        self.check_resops(new_with_vtable=0)

    def test_unwanted_loops(self):
        mydriver = JitDriver(reds = ['n', 'total', 'm'], greens = [])

        def loop1(n):
            # the jit should not look here, as there is a loop
            res = 0
            for i in range(n):
                res += i
            return res

        @unroll_safe
        def loop2(n):
            # the jit looks here, due to the decorator
            for i in range(5):
                n += 1
            return n

        def f(m):
            total = 0
            n = 0
            while n < m:
                mydriver.can_enter_jit(n=n, total=total, m=m)
                mydriver.jit_merge_point(n=n, total=total, m=m)
                total += loop1(n)
                n = loop2(n)
            return total
        self.meta_interp(f, [50])
        self.check_enter_count_at_most(2)

    def test_wanted_unrolling_and_preinlining(self):
        mydriver = JitDriver(reds = ['n', 'm'], greens = [])

        @unroll_safe
        def loop2(n):
            # the jit looks here, due to the decorator
            for i in range(5):
                n += 1
            return n
        loop2._always_inline_ = True

        def g(n):
            return loop2(n)
        g._dont_inline_ = True

        def f(m):
            n = 0
            while n < m:
                mydriver.can_enter_jit(n=n, m=m)
                mydriver.jit_merge_point(n=n, m=m)
                n = g(n)
            return n
        self.meta_interp(f, [50], backendopt=True)
        self.check_enter_count_at_most(2)
        self.check_resops(call=0)

    def test_loop_header(self):
        # artificial test: we enter into the JIT only when can_enter_jit()
        # is seen, but we close a loop in the JIT much more quickly
        # because of loop_header().
        mydriver = JitDriver(reds = ['n', 'm'], greens = [])

        def f(m):
            n = 0
            while True:
                mydriver.jit_merge_point(n=n, m=m)
                if n > m:
                    m -= 1
                    if m < 0:
                        return n
                    n = 0
                    mydriver.can_enter_jit(n=n, m=m)
                else:
                    n += 1
                    mydriver.loop_header()
        assert f(15) == 1
        res = self.meta_interp(f, [15], backendopt=True)
        assert res == 1
        self.check_resops(int_add=2)   # I get 13 without the loop_header()

    def test_omit_can_enter_jit(self):
        # Simple test comparing the effects of always giving a can_enter_jit(),
        # or not giving any.  Mostly equivalent, except that if given, it is
        # ignored the first time, and so it ends up taking one extra loop to
        # start JITting.
        mydriver = JitDriver(greens=[], reds=['m'])
        #
        for i2 in range(10):
            def f2(m):
                while m > 0:
                    mydriver.jit_merge_point(m=m)
                    m -= 1
            self.meta_interp(f2, [i2])
            try:
                self.check_jitcell_token_count(1)
                break
            except AssertionError:
                print "f2: no loop generated for i2==%d" % i2
        else:
            raise     # re-raise the AssertionError: check_loop_count never 1
        #
        for i1 in range(10):
            def f1(m):
                while m > 0:
                    mydriver.can_enter_jit(m=m)
                    mydriver.jit_merge_point(m=m)
                    m -= 1
            self.meta_interp(f1, [i1])
            try:
                self.check_jitcell_token_count(1)
                break
            except AssertionError:
                print "f1: no loop generated for i1==%d" % i1
        else:
            raise     # re-raise the AssertionError: check_loop_count never 1
        #
        assert i1 - 1 == i2

    def test_no_loop_at_all(self):
        mydriver = JitDriver(greens=[], reds=['m'])
        def f2(m):
            mydriver.jit_merge_point(m=m)
            return m - 1
        def f1(m):
            while m > 0:
                m = f2(m)
        self.meta_interp(f1, [8])
        # it should generate one "loop" only, which ends in a FINISH
        # corresponding to the return from f2.
        self.check_trace_count(1)
        self.check_resops(jump=0)

    def test_simple_loop(self):
        mydriver = JitDriver(greens=[], reds=['m'])
        def f1(m):
            while m > 0:
                mydriver.jit_merge_point(m=m)
                m = m - 1
        self.meta_interp(f1, [8])
        self.check_trace_count(1)
        self.check_resops({'jump': 1, 'guard_true': 2, 'int_gt': 2,
                           'int_sub': 2})

    def test_void_red_variable(self):
        mydriver = JitDriver(greens=[], reds=['a', 'm'])
        def f1(m):
            a = None
            while m > 0:
                mydriver.jit_merge_point(a=a, m=m)
                m = m - 1
                if m == 10:
                    pass   # other case
        self.meta_interp(f1, [18])

    def test_bug_constant_rawptrs(self):
        py.test.skip("crashes because a is a constant")
        from pypy.rpython.lltypesystem import lltype, rffi
        mydriver = JitDriver(greens=['a'], reds=['m'])
        def f1(m):
            a = lltype.nullptr(rffi.VOIDP.TO)
            while m > 0:
                mydriver.jit_merge_point(a=a, m=m)
                m = m - 1
        self.meta_interp(f1, [18])

    def test_bug_rawptrs(self):
        from pypy.rpython.lltypesystem import lltype, rffi
        mydriver = JitDriver(greens=['a'], reds=['m'])
        def f1(m):
            a = lltype.malloc(rffi.VOIDP.TO, 5, flavor='raw')
            while m > 0:
                mydriver.jit_merge_point(a=a, m=m)
                m = m - 1
                if m == 10:
                    pass
            lltype.free(a, flavor='raw')
        self.meta_interp(f1, [18])


class TestLLWarmspot(WarmspotTests, LLJitMixin):
    CPUClass = runner.LLtypeCPU
    type_system = 'lltype'

class TestOOWarmspot(WarmspotTests, OOJitMixin):
    ##CPUClass = runner.OOtypeCPU
    type_system = 'ootype'

class TestWarmspotDirect(object):
    def setup_class(cls):
        from pypy.jit.metainterp.typesystem import llhelper
        from pypy.jit.codewriter.support import annotate
        from pypy.jit.metainterp.warmspot import WarmRunnerDesc
        from pypy.rpython.lltypesystem.rclass import OBJECT, OBJECT_VTABLE
        from pypy.rpython.lltypesystem import lltype, llmemory
        exc_vtable = lltype.malloc(OBJECT_VTABLE, immortal=True)
        cls.exc_vtable = exc_vtable

        class FakeFailDescr(object):
            def __init__(self, no):
                self.no = no
            def handle_fail(self, metainterp_sd, jitdrivers_sd):
                no = self.no
                if no == 0:
                    raise metainterp_sd.warmrunnerdesc.DoneWithThisFrameInt(3)
                if no == 1:
                    raise metainterp_sd.warmrunnerdesc.ContinueRunningNormally(
                        [0], [], [], [1], [], [])
                if no == 3:
                    exc = lltype.malloc(OBJECT)
                    exc.typeptr = exc_vtable
                    raise metainterp_sd.warmrunnerdesc.ExitFrameWithExceptionRef(
                        metainterp_sd.cpu,
                        lltype.cast_opaque_ptr(llmemory.GCREF, exc))
                assert 0

        class FakeDescr:
            def as_vtable_size_descr(self):
                return self

        class FakeCPU(object):
            supports_floats = False
            supports_longlong = False
            supports_singlefloats = False
            ts = llhelper
            translate_support_code = False
            stats = "stats"

            def get_fail_descr_number(self, d):
                return -1

            def __init__(self, *args, **kwds):
                pass

            def nodescr(self, *args, **kwds):
                return FakeDescr()
            fielddescrof = nodescr
            calldescrof  = nodescr
            sizeof       = nodescr

            def get_fail_descr_from_number(self, no):
                return FakeFailDescr(no)

            def make_execute_token(self, *ARGS):
                return "not callable"

        driver = JitDriver(reds = ['red'], greens = ['green'])

        def f(green):
            red = 0
            while red < 10:
                driver.can_enter_jit(red=red, green=green)
                driver.jit_merge_point(red=red, green=green)
                red += 1
            return red

        rtyper = annotate(f, [0])
        FakeCPU.rtyper = rtyper
        translator = rtyper.annotator.translator
        translator.config.translation.gc = 'hybrid'
        cls.desc = WarmRunnerDesc(translator, CPUClass=FakeCPU)

    def test_call_helper(self):
        from pypy.rpython.llinterp import LLException

        [jd] = self.desc.jitdrivers_sd
        assert jd._assembler_call_helper(0, 0) == 3
        assert jd._assembler_call_helper(1, 0) == 10
        try:
            jd._assembler_call_helper(3, 0)
        except LLException, lle:
            assert lle[0] == self.exc_vtable
        else:
            py.test.fail("DID NOT RAISE")