Commits

Lenard Lindstrom  committed 96f9d29

Add basic JIT interpreter hints

The JIT hints include ``promote`` calls in an attempt to improve performance.

  • Participants
  • Parent commits 6bc964b

Comments (0)

Files changed (5)

 import cffi
 
+import os
+
 ffi = cffi.FFI()
-
 ffi.cdef("""
     typedef struct ArrayView {
         char typestr[4];
     void blit(ArrayView_t *destination, ArrayView_t *source);
     void rpython_startup_code();
 """)
-_blitter = ffi.dlopen('./libtesting.so')
+try:
+    os.environ['PYTHON_ONLY']
+except KeyError:
+    try:
+        _blitter = ffi.dlopen('./libtesting.so')
+    except OSError:
+        import blitter as _blitter
+        _blitter.blit = _blitter.blit_buffer
+    else:
+        _blitter.rpython_startup_code()
+else:
+    import blitter as _blitter
+    _blitter.blit = _blitter.blit_buffer
 
-_blitter.rpython_startup_code()
 
 def blit(destination, source):
     """Copy source array to destination

File blit_interpreter.py

 
 ## RPython
 
-from rpython.rtyper.lltypesystem.rffi import r_uchar
+try:
+    from rpython.rtyper.lltypesystem.rffi import r_uchar
+    from rpython.rlib.jit import JitDriver, promote
+except ImportError:
+    # Allow Python import without rpython package
+    def promote(v):
+        return v
+    def r_uchar(c):
+        return c
+    class JitDriver(object):
+        def __init__(self, *args, **kwds):
+            pass
+        def jit_merge_point(self, *args, **kwds):
+            pass
+        def can_enter_jit(self, *args, **kwds):
+            pass
 
 # The language:
 #
     
     return code
 
+jitdriver = JitDriver(greens=['pc', 'code'],
+                      reds=['doffset', 'soffset',
+                            'dst', 'src', 'shape',
+                            'advances', 'counters'])
+
 def execute_loop(code, dst, src, shape, dst_strides, src_strides):
     pc = 0
     ndim = code[pc]
     soffset = 0
 
     while True:
-        # jit merge point
+        jitdriver.jit_merge_point(pc=pc, code=code,
+                                  dst=dst, src=src, shape=shape,
+                                  doffset=doffset, soffset=soffset,
+                                  advances=advances, counters=counters)
         c = code[pc]
         pc += 1
         if c == 1:
-            i = code[pc]
+            i = promote(code[pc])
             pc += 1
-            p = code[pc]
+            p = promote(code[pc])
             pc += 1
             if counters[i] == 0:
                 pc = p
             else:
                 counters[i] -= 1
         elif c == 2:
-            # can enter jit
-            p = code[pc]
+            p = promote(code[pc])
             pc += 1
             pc = p
+            jitdriver.can_enter_jit(pc=pc, code=code,
+                                    dst=dst, src=src, shape=shape,
+                                    doffset=doffset, soffset=soffset,
+                                    advances=advances, counters=counters)
         elif c == 3:
-            s = code[pc]
+            s = promote(code[pc])
             pc += 1
             doffset += s
         elif c == 4:
-            s = code[pc]
+            s = promote(code[pc])
             pc += 1
             soffset += s
         elif c == 5:
-            i = code[pc]
+            i = promote(code[pc])
             pc += 1
             doffset += advances[i]
         elif c == 6:
-            i = code[pc]
+            i = promote(code[pc])
             pc += 1
             soffset += advances[i]
         elif c == 7:
-            sd = code[pc]
+            sd = promote(code[pc])
             pc += 1
-            ss = code[pc]
+            ss = promote(code[pc])
             pc += 1
             dst[doffset + sd] = src[soffset + ss]
         elif c == 8:
-            sd = code[pc]
+            sd = promote(code[pc])
             pc += 1
             dst[doffset + sd] = r_uchar(0)
         elif c == 9:
-            i = code[pc]
+            i = promote(code[pc])
             pc += 1
             counters[i] = shape[i]
         elif c == 0:
 """Export blit_buffer"""
-# Not RPython
-
-from rpython.rlib.entrypoint import entrypoint_lowlevel
-from rpython.translator.tool.cbuild import ExternalCompilationInfo
-from array_view import ARRAYVIEW_P
 from blit_interpreter import compile_loop, execute_loop
 
-import os
+def add_field_getters(fields, prefix='c_'):
+    for name in fields:
+        exec """\
+def get_%s_field(rec):
+    return rec.%s%s
+""" % (name, prefix, name) in globals()
 
-@entrypoint_lowlevel('main', [ARRAYVIEW_P, ARRAYVIEW_P], c_name='blit')
+fields = ['ndim', 'typestr', 'shape', 'strides', 'data']
+
+
+## RPython
+
+code_cache = {}
+
 def blit_buffer(destination, source):
     """Copy source array onto destination
 
 
 !For now, ignore integer sign. Truncate if necessary.
 """
-    ndim = destination.c_ndim
-    dst_type = destination.c_typestr
-    dst_shape = destination.c_shape
-    dst_strides = destination.c_strides
-    dst_data = destination.c_data
-    src_type = source.c_typestr
-    src_shape = source.c_shape
-    src_strides = source.c_strides
-    src_data = source.c_data
+    ndim = get_ndim_field(destination)
+    dst_type = get_typestr_field(destination)
+    dst_shape = get_shape_field(destination)
+    dst_strides = get_strides_field(destination)
+    dst_data = get_data_field(destination)
+    src_type = get_typestr_field(source)
+    src_shape = get_shape_field(source)
+    src_strides = get_strides_field(source)
+    src_data = get_data_field(source)
     __, dst_bytesize, dst_lil_endian = decode_int_type(dst_type)
     __, src_bytesize, src_lil_endian = decode_int_type(src_type)
     if src_lil_endian:
     else:
         offset_to = max(dst_bytesize - len(src_offsets), 0) - 1
         dst_offsets = [i for i in range(dst_bytesize - 1, offset_to, -1)]
-    code = compile_loop(ndim,
-                        dst_bytesize, dst_offsets,
-                        src_bytesize, src_offsets)
+    key = chr(ndim) + chr(dst_bytesize) + chr(src_bytesize)
+    for i in dst_offsets:
+        key += chr(i)
+    for i in src_offsets:
+        key += chr(i)
+    try:
+        code = code_cache[key]
+    except KeyError:
+        code = compile_loop(ndim,
+                            dst_bytesize, dst_offsets,
+                            src_bytesize, src_offsets)
+        code_cache[key] = code
     execute_loop(code, dst_data, src_data,
                  dst_shape, dst_strides, src_strides)
 
-blit_buffer._compilation_info.includes += 'arrview.h',
-blit_buffer._compilation_info.include_dirs += os.path.abspath(os.getcwd()),
-
 def decode_int_type(typestr):
     """Return (is_signed, size, is_lil_endian)"""
     order = typestr.item0
     size = typestr.item2
     return signed == 'i', '0123456789'.find(size), order == '<'
 
+add_field_getters(fields)
+
+# not RPython
+
+import os
+
+python_only = False
+
+try:
+    os.environ['PYTHON_ONLY']
+except KeyError:
+    try:
+        import rpython
+    except ImportError:
+        python_only = True
+else:
+    python_only = True
+
+if python_only:
+    def decode_int_type(typestr):
+        """Return (is_signed, size, is_lil_endian)"""
+        order, signed, size = typestr[0:3]
+        return signed == 'i', '0123456789'.find(size), order == '<'
+
+    add_field_getters(fields, '')
+
+del fields, add_field_getters, python_only

File targetblittershared.py

-from blitter import blit_buffer, ARRAYVIEW_P
+"""Blit shared library RPython target.
+
+To build:
+
+pypy <pypy-dir-path>/rpython/bin/rpython -Ojit --gcrootfinder=shadowstack \
+     --shared targetblittershared.py
+
+"""
+
+from blitter import blit_buffer
+from array_view import ARRAYVIEW_P
+
+from rpython.rlib.entrypoint import entrypoint_lowlevel
+
+import os
+
+
+@entrypoint_lowlevel('main', [ARRAYVIEW_P, ARRAYVIEW_P], c_name='blit')
+def blit(destination, source):
+    blit_buffer(destination, source)
+blit._compilation_info.includes += 'arrview.h',
+blit._compilation_info.include_dirs += os.path.abspath(os.getcwd()),
 
 def target(driver, args):
-    driver.exe_name = 'libblitter'
-    return blit_buffer, [ARRAYVIEW_P, ARRAYVIEW_P]
+    driver.extmod_name = 'blitter'
+    return blit, [ARRAYVIEW_P, ARRAYVIEW_P]
+
+def jitpolicy(driver):
+    from rpython.jit.codewriter.policy import JitPolicy
+    return JitPolicy()

File test_blitter.py

     dst.fill(-1)
     blit(dst, src)
     assert (dst == src & 0xFFFF).all()
+
+def test_big_row():
+    """Bug check"""
+
+    src = np.arange(10000, dtype='<u4')
+    src.shape = 10, src.size // 10
+    dst = np.zeros(src.shape, dtype='<u4')
+    blit(dst, src)
+    assert (dst == src).all()