# Source: blitter/blit_interpreter.py

"""Array blit interpreter

exports compile_loop, execute_loop"""

## RPython

try:
    from rpython.rtyper.lltypesystem.rffi import r_uchar
    from rpython.rlib.jit import JitDriver
except ImportError:
    # Without the rpython package, fall back to minimal stand-ins so that the
    # module can still be imported and run untranslated on plain Python.

    def r_uchar(c):
        """Untranslated stand-in: hand back *c* unchanged."""
        return c

    class JitDriver(object):
        """Untranslated stand-in; construction and both JIT hooks do nothing."""

        def __init__(self, *args, **kwds):
            """Accept and ignore any configuration arguments."""

        def jit_merge_point(self, *args, **kwds):
            """No-op hook; returns None."""

        def can_enter_jit(self, *args, **kwds):
            """No-op hook; returns None."""

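# The language:
#
# A program is a str produced by compile_loop.  code[0] holds ndim (the number
# of loop levels) and is not an instruction.  Every opcode and every argument
# occupies exactly one character: compile_loop encodes values with chr() and
# execute_loop decodes them with ord().
#
# Registers:
#   constant:
#     source            : Source buffer, read as source[soffset + n]
#     destination       : Destination buffer, written as destination[doffset + n]
#     shape[ndim]       : Iteration count of each loop level
#     advances[ndim*2]  : Offset steps per level; advances[d] steps doffset
#                         (using dst_strides) and advances[ndim + d] steps
#                         soffset (using src_strides).  For d < ndim - 1 the
#                         step is stride[d] - shape[d + 1] * stride[d + 1],
#                         which also undoes the steps taken by the inner
#                         level; the innermost level steps by stride[ndim - 1].
#   variable:
#     program_counter   : Code index
#     dim               : Current loop level (-1 before the first CSET)
#     counter           : Remaining iterations of the current level
#     counters[ndim]    : Saved counters of the enclosing levels
#     dadvance          : advances[dim]
#     sadvance          : advances[ndim + dim]
#     doffset           : Memory offset into destination
#     soffset           : Memory offset into source
#
# Instruction set (opcode characters are assigned by Operations below):
#   CSET               : dim += 1; counters[dim] = counter;
#                        counter = shape[dim];
#                        dadvance = advances[dim];
#                        sadvance = advances[ndim + dim]
#   LOOP position      : if counter == 0:
#                            if dim == 0: stop executing
#                            counter = counters[dim]; dim -= 1;
#                            dadvance = advances[dim];
#                            sadvance = advances[ndim + dim];
#                            program_counter = position
#                        else:
#                            counter -= 1
#   NEXT position      : program_counter = position
#   DADV               : doffset += dadvance
#   SADV               : soffset += sadvance
#   MOVE dbyte sbyte   : destination[doffset + dbyte] =
#                            source[soffset + sbyte]
#   DZERO dbyte        : destination[doffset + dbyte] = 0
#
# compile_loop emits the following parts in order:
#   1. For each level, outermost first, a CSET followed by a LOOP.  The LOOP's
#      position is the code index just past that level's NEXT.
#   2. The per-element MOVE/DZERO body.
#   3. For each level, innermost first, DADV SADV NEXT, where NEXT jumps back
#      to that level's LOOP.
#
#   All instruction arguments are integer constants.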
#
def set_globals(ops):
    """Expose each opcode in *ops* as a module global named after its mnemonic.

    Every entry ``name -> (opcode, nargs)`` of ``ops.opcodes`` results in the
    module-level name ``name`` being bound to ``opcode``.
    """
    namespace = globals()
    for opname, entry in ops.opcodes.items():
        opcode, _nargs = entry
        namespace[opname] = opcode

## RPython

class Operations(object):
    """Bidirectional table of interpreter opcodes.

    Each keyword argument names an instruction, and its value is the number of
    one-character arguments that follow the opcode in a code string.  Opcode
    characters chr(1), chr(2), ... are handed out in the iteration order of
    the keyword dict.

    Attributes:
        opcodes: mnemonic -> (opcode character, argument count)
        names:   opcode character -> (mnemonic, argument count)
    """

    def __init__(self, **kwds):
        self.opcodes = {}
        self.names = {}
        for number, (name, nargs) in enumerate(kwds.items(), 1):
            opchar = chr(number)
            self.opcodes[name] = opchar, nargs
            self.names[opchar] = name, nargs

    def get_name(self, c):
        """Return the (mnemonic, argument count) pair for opcode character *c*."""
        return self.names[c]

def disassemble_op(pc, code):
    """Render the instruction at *pc* as ``"pc:NAME [arg [arg]]"``.

    Arguments are printed as their ord() values.  Instructions with two or
    more arguments show only the first two.  This function is the JIT
    driver's get_printable_location callback.
    """
    name, nargs = ops.get_name(code[pc])
    text = "%d:%s" % (pc, name)
    if nargs == 0:
        return text
    first = ord(code[pc + 1])
    if nargs == 1:
        return "%s %d" % (text, first)
    return "%s %d %d" % (text, first, ord(code[pc + 2]))

# Instruction table: keyword = mnemonic, value = number of one-character
# arguments that follow the opcode (this count is what disassemble_op reads).
# Operations hands out opcode characters chr(1), chr(2), ... in the iteration
# order of its keyword dict.  On Pythons that preserve keyword order (3.6+),
# this keyword order therefore defines the opcode numbering.
ops = Operations(LOOP=1,   # LOOP exit_position
                 NEXT=1,   # NEXT loop_position
                 DADV=0,
                 SADV=0,
                 MOVE=2,   # MOVE dst_byte src_byte
                 DZERO=1,  # DZERO dst_byte
                 CSET=0)

# Publish LOOP, NEXT, DADV, ... as module-level constants; compile_loop and
# execute_loop refer to them by these names.
set_globals(ops)


def compile_loop(ndim, dst_bytesize, dst_offsets, src_bytesize, src_offsets):
    """Compile a blit loop for execute_loop.

    The program has three parts:
      * For each of the ndim loop levels, outermost first, a CSET/LOOP pair.
      * A per-element body:
          - one MOVE per copied byte, where min(dst_bytesize, src_bytesize)
            bytes are copied;
          - one DZERO for every destination byte in range(dst_bytesize)
            that no MOVE writes.
      * For each level, innermost first, a DADV/SADV/NEXT footer.  Each
        LOOP's exit position is patched to point just past its footer.

    Args:
        ndim: Number of loop levels; stored (masked to 8 bits) in code[0].
        dst_bytesize: Bytes per destination element.
        dst_offsets: Destination byte offset for each copied byte.
        src_bytesize: Bytes per source element.
        src_offsets: Source byte offset for each copied byte.

    Returns:
        The program as a str, one character per opcode or argument.

    NOTE(review): byte offsets and jump positions are encoded with chr(), which
    on Python 2 / RPython accepts only 0..255, so very long programs cannot be
    encoded there.
    """
    ncopy = min(dst_bytesize, src_bytesize)
    # code[0] is the dimension count, not an instruction.
    program = [chr(ndim & 0xFF)]
    loop_heads = [0] * ndim

    # Headers, outermost level first; each LOOP gets a placeholder exit.
    for level in range(ndim):
        program.append(CSET)
        loop_heads[level] = len(program)
        program.extend([LOOP, '\x00'])

    # Body: copy the bytes both element layouts share.
    for k in range(ncopy):
        program.extend([MOVE, chr(dst_offsets[k]), chr(src_offsets[k])])

    # Body: clear destination bytes that no MOVE writes.
    if dst_bytesize > ncopy:
        copied = dst_offsets[:ncopy]
        for byte in range(dst_bytesize):
            if byte not in copied:
                program.extend([DZERO, chr(byte)])

    # Footers, innermost level first; then patch that level's LOOP exit.
    level = ndim - 1
    while level >= 0:
        head = loop_heads[level]
        program.extend([DADV, SADV, NEXT, chr(head)])
        program[head + 1] = chr(len(program))
        level -= 1

    return ''.join(program)

# JIT hints for execute_loop's dispatch loop.  In RPython JIT terminology,
# `pc` and `code` are declared green and the rest of the interpreter state is
# declared red.  These names must match the keyword arguments passed to
# jitdriver.jit_merge_point / jitdriver.can_enter_jit in execute_loop.
# disassemble_op takes (pc, code), i.e. the green values, and is used to print
# a readable location.
jitdriver = JitDriver(greens=['pc', 'code'],
                      reds=['ndim', 'dim', 'counter',
                            'dadvance', 'sadvance',
                            'doffset', 'soffset',
                            'dst', 'src', 'shape',
                            'advances', 'counters'],
                      get_printable_location=disassemble_op)

def execute_loop(code, dst, src, shape, dst_strides, src_strides):
    """Run a blit program produced by compile_loop.

    The interpreter walks the index space given by *shape*.  For each element
    it executes the program's MOVE/DZERO body, writing into *dst* in place.

    Args:
        code: Program string from compile_loop; ord(code[0]) is ndim.
        dst: Destination buffer, written as dst[doffset + byte].
        src: Source buffer, read as src[soffset + byte].
        shape: Iteration count per loop level (at least ndim entries).
        dst_strides: Per-level offset step into dst (at least ndim entries).
        src_strides: Per-level offset step into src (at least ndim entries).

    Returns:
        None.

    Raises:
        RuntimeError: if an unknown opcode is encountered.
    """
    assert isinstance(code, str)
    pc = 0
    ndim = ord(code[pc])
    if ndim < 1:
        # NOTE(review): compile_loop still emits a MOVE/DZERO body for
        # ndim == 0, but it is never executed here, so a 0-d blit copies
        # nothing.  Confirm that callers handle scalars separately.
        return
    pc += 1

    # advances[0:ndim] step doffset; advances[ndim:2*ndim] step soffset.
    # For every level but the innermost, the step also subtracts the
    # shape[dim + 1] inner strides that the inner level has just added.
    advances = [0] * (2 * ndim)
    for dim in range(0, ndim - 1):
        dincr = dim + 1
        sz = shape[dincr]
        advances[dim] = dst_strides[dim] - sz * dst_strides[dincr]
        advances[ndim + dim] = src_strides[dim] - sz * src_strides[dincr]
    advances[ndim - 1] = dst_strides[ndim - 1]
    advances[2 * ndim - 1] = src_strides[ndim - 1]

    counters = [0] * ndim   # saved counters of the enclosing loop levels
    doffset = 0
    soffset = 0
    dim = -1                # no level entered yet; the first CSET makes it 0
    counter = 0             # remaining iterations of the current level
    dadvance = 0
    sadvance = 0

    while True:
        jitdriver.jit_merge_point(pc=pc, code=code, ndim=ndim,
                                  dst=dst, src=src, shape=shape,
                                  doffset=doffset, soffset=soffset,
                                  dim=dim, counter=counter,
                                  dadvance=dadvance, sadvance=sadvance,
                                  advances=advances, counters=counters)
        op = code[pc]
        pc += 1
        if op == MOVE:
            sd = ord(code[pc])
            pc += 1
            ss = ord(code[pc])
            pc += 1
            dst[doffset + sd] = src[soffset + ss]
        elif op == DZERO:
            sd = ord(code[pc])
            pc += 1
            dst[doffset + sd] = r_uchar(0)
        elif op == LOOP:
            p = ord(code[pc])
            pc += 1
            if counter == 0:
                # Current level exhausted: finish if outermost, otherwise
                # restore the enclosing level and jump past this level's NEXT.
                if dim == 0:
                    break
                counter = counters[dim]
                dim -= 1
                dadvance = advances[dim]
                sadvance = advances[ndim + dim]
                pc = p
            else:
                counter -= 1
        elif op == NEXT:
            # Backward jump to this level's LOOP; hint a possible JIT entry.
            pc = ord(code[pc])
            jitdriver.can_enter_jit(pc=pc, code=code, ndim=ndim,
                                    dst=dst, src=src, shape=shape,
                                    doffset=doffset, soffset=soffset,
                                    dim=dim, counter=counter,
                                    dadvance=dadvance, sadvance=sadvance,
                                    advances=advances, counters=counters)
        elif op == DADV:
            doffset += dadvance
        elif op == SADV:
            soffset += sadvance
        elif op == CSET:
            # Enter the next inner level, saving the enclosing counter.
            dim += 1
            counters[dim] = counter
            counter = shape[dim]
            dadvance = advances[dim]
            sadvance = advances[ndim + dim]
        else:
            raise RuntimeError("Unknown bytecode %d at %d" % (ord(op), pc - 1))