Source

hotpy_2 / HotPy / gen_format_code.py

#! ./python -S
import collections

def get_formats_from_file(file):
    """Return the set of format strings listed in *file*.

    Each line that does not start with ``#`` is split on whitespace.  The
    tokens after the first are concatenated to form that line's format.
    Lines yielding an empty format are ignored.
    """
    result = set()
    for line in file.readlines():
        if line[0] == '#':
            continue
        result.add(''.join(line.split()[1:]))
    result.discard('')
    return result

def verify_format(fmt):
    """Raise ``Exception`` unless *fmt* is a legal instruction format.

    A legal format has its letters in descending character order, so all
    lowercase letters precede uppercase ones.  It also has at most one exit
    letter (``e`` or ``E``).
    """
    canonical = ''.join(sorted(fmt, reverse=True))
    if canonical != fmt:
        raise Exception("Format %s is illegal; should be %s" % (fmt, canonical))
    exits = sum(1 for ch in fmt if ch in 'eE')
    if exits > 1:
        raise Exception("Format %s is illegal; too many exits" % fmt)

def write_word(layout, parameters, outfile):
    """Print a C call ``write_word_<layout>(x, <parameters...>);`` to *outfile*."""
    call = '    write_word_{}(x, {});'.format(layout, ', '.join(parameters))
    print(call, file=outfile)

def write_function(fmt, outfile):
    """Emit a C ``write_<fmt>`` function for instruction format *fmt*.

    The generated function takes ``HotPyOptimiser *x`` and ``int op``.  It
    also takes one ``int k<N>`` per ``k``/``K`` letter and an ``int exit``
    for an ``e``/``E`` letter.  Its body:

    * pops (``r``) and peeks (``p``) input registers;
    * declares and allocates output registers (``o``) via ``choose_register``;
    * flushes the stack when the format has an exit;
    * packs opcode and operands into 4-byte words through a series of
      ``write_word_<layout>`` calls (see ``write_word``);
    * pushes the output registers.

    Lowercase letters occupy one byte of a word and uppercase letters two
    bytes.  A two-byte field always starts at an even byte offset.

    :param fmt: format string.  This function does not call
        ``verify_format``; presumably callers pass validated formats.
    :param outfile: text stream the C source is printed to.
    """
    # Build the C parameter list: one int per constant, plus 'exit'.
    args = 'HotPyOptimiser *x, int op'
    consts = 0
    for letter in fmt:
        if letter in 'kK':
            args += ', int k%d' % consts
            consts += 1
        if letter in 'eE':
            args += ', int exit'
    print ('static void\nwrite_%s(%s)\n{' % (fmt, args), file = outfile)
    inputs = fmt.count('r')
    outputs = fmt.count('o')
    peek = 'p' in fmt
    # Pop inputs highest-numbered first.  'peek' is a bool used as 0/1: when
    # a peeked register is present it takes r0 and popped ones start at r1.
    for i in reversed(range(inputs)):
        print ('    int r%d = pop_as_register(x, %d);' % (i+peek, i+peek), file = outfile)
    if peek:
        assert fmt.count('p') == 1
        print ('    int r0 = peek_as_register(x);', file = outfile)
    for i in range(outputs):
        print ('    int out%d;' % i, file = outfile)
    consts = 0
    for letter in fmt:
        if letter in 'kK':
            print ('    assert(k%d >= 0);' % consts, file = outfile)
            consts += 1
        if letter in 'eE':
            print ('    assert(exit >= 0);', file = outfile)
    if 'e' in fmt or 'E' in fmt:
        print ('    flush_stack(x);', file = outfile)
    for i in range(outputs):
        print ('    out%d = choose_register(x);' % i, file = outfile)
    print ('    assert(consistent_format(op, "%s"));' % fmt, file = outfile)
    # Packing state for the current word.
    # 'size' = bytes used so far; byte 0 of the first word is the opcode.
    # 'layout' = digit string of field widths, used to pick write_word_<layout>.
    # 'parameters' = the C expressions stored in those fields.
    size = 1
    layout = '1'
    parameters = [ 'op' ]
    regs = 0
    consts = 0
    outs = 0
    for letter in fmt:
        # Two-byte (uppercase) fields must start at an even byte offset;
        # pad with a one-byte zero field.
        if letter < 'a' and (size & 1):
            size += 1
            layout += '1'
            parameters.append('0')
        # Current word is full: emit it and start a new, empty word.
        if size == 4:
            write_word(layout, parameters, outfile)
            layout = ''
            parameters = []
            size = 0
        if letter < 'a':
            layout += '2'
            size += 2
        else:
            layout += '1'
            size += 1
        if letter in 'rp':
            parameters.append('r%d' % regs)
            regs += 1
        elif letter == 'o':
            parameters.append('out%d' % outs)
            outs += 1
        elif letter in 'kK':
            parameters.append('k%d' % consts)
            consts += 1
        elif letter in 'eE':
            parameters.append('exit')
    assert size > 0
    assert sum(int(l) for l in layout) == size
    # Complete the final word to 4 bytes.  A lone two-byte field is padded
    # with a two-byte zero ('22'), not '211'.  Presumably only the layouts
    # generated here have write_word_* implementations; confirm in the C side.
    if layout == '2':
        layout = '22'
        parameters.append('0')
    else:
        while size < 4:
            size += 1
            parameters.append('0')
            layout += '1'
    write_word(layout, parameters, outfile)
    for i in range(outputs):
        print ('    push_register(x, out%d);' % i, file = outfile)
    print ('}\n', file = outfile)

def defuses_function(fmt, outfile):
    """Emit C ``defuses_for_<fmt>``, which records register defs and uses.

    The generated function decodes one instruction with ``FORMAT_<fmt>``.
    It appends each output (``o``) register to ``defs`` and each input
    (``r``/``p``) register to ``uses``, terminates both lists with ``-1``,
    and returns the pointer past the instruction.
    """
    lines = [
        'static uint32_t *\ndefuses_for_%s(uint32_t '
        '*next_instr, int *defs, int *uses)\n{' % fmt,
        '    uint32_t instruction_word = *next_instr++;',
        '    FORMAT_%s;' % fmt,
    ]
    reg_index = 0
    const_index = 0
    for letter in fmt:
        if letter == 'o':
            lines.append('    *defs++ = r%d;' % reg_index)
            reg_index += 1
        elif letter in 'rp':
            lines.append('    *uses++ = r%d;' % reg_index)
            reg_index += 1
        elif letter in 'kK':
            # Cast to void so the C compiler does not warn about an unused local.
            lines.append('    (void)k%d;' % const_index)
            const_index += 1
        elif letter in 'eE':
            lines.append('    (void)e0;')
    lines += [
        '    *defs++ = -1;',
        '    *uses++ = -1;',
        '    return next_instr;',
        '}\n',
    ]
    for text in lines:
        print(text, file=outfile)

def get_exit_function(fmt, outfile):
    """Emit C ``get_exit_context_<fmt>`` for instruction format *fmt*.

    For a format containing an exit (``e`` one-byte or ``E`` two-byte), the
    generated function decodes the exit index ``e0`` from the instruction at
    ``instr``.  It then returns the ``exit_context`` of item ``e0`` of the
    ``exits`` tuple.  For formats without an exit it returns ``NULL``.
    """
    def emit(x):
        print(x, file = outfile)
    emit('static HotPyContext *')
    emit('get_exit_context_%s(uint32_t *instr, PyObject *exits)\n{' % fmt)
    if 'e' in fmt or 'E' in fmt:
        emit('    uint32_t instruction_word = *instr++;')
        emit('    int e0;')
        emit('    PyObject *exit;')
        # Byte 0 of the first word holds the opcode, so operands start at 1.
        offset = 1
        for letter in fmt:
            # Two-byte (uppercase) fields start at an even byte offset.
            if letter < 'a' and (offset & 1):
                offset += 1
            # Current word exhausted: move on to the next one.
            if offset == 4:
                emit('    instruction_word = *instr++;')
                offset = 0
            if letter == 'E':
                print_read_short('e0', offset, outfile)
            elif letter == 'e':
                print_read_byte('e0', offset, outfile)
            offset += 2 if letter < 'a' else 1
        # NOTE(review): PyTuple_GetItem returns NULL for an out-of-range e0;
        # the generated code does not check for that.
        emit('    exit = PyTuple_GetItem(exits, e0);')
        emit('    return ((HotPyExitObject *)exit)->exit_context;')
    else:
        # Four-space indent, matching every other emitted statement.
        emit('    return NULL;')
    emit('}\n')

def _relabel_function(kind, letters, fmt, outfile):
    """Emit C ``relabel_<kind>_<fmt>``, which rewrites registers in place.

    The generated function walks the words of one instruction of format
    *fmt* starting at ``next_instr``.  Every one-byte field whose format
    letter is in *letters* is replaced by ``relabel_table[old_value]``.
    Each word is written back before the next is read, and the function
    returns the pointer past the instruction.

    :param kind: ``'uses'`` or ``'defs'``; forms part of the C function name.
    :param letters: format letters whose fields are relabelled.
    """
    def emit(x):
        print(x, file = outfile)
    emit('static uint32_t *\nrelabel_%s_%s(uint32_t '
         '*next_instr, unsigned char *relabel_table)\n{' % (kind, fmt))
    if any(letter in letters for letter in fmt):
        emit('    int reg;')
    emit('    uint32_t instruction_word = *next_instr++;')
    emit('    (void)instruction_word; /* Stop compiler complaining */')
    # Byte 0 of the first word holds the opcode, so operands start at 1.
    offset = 1
    for letter in fmt:
        # Two-byte (uppercase) fields start at an even byte offset.
        if letter < 'a' and (offset & 1):
            offset += 1
        # Current word exhausted: store it back and load the next one.
        if offset == 4:
            emit('    next_instr[-1] = instruction_word;')
            emit('    instruction_word = *next_instr++;')
            offset = 0
        if letter in letters:
            print_read_byte('reg', offset, outfile)
            print_write_byte('relabel_table[reg]', offset, outfile)
        offset += 2 if letter < 'a' else 1
    emit('    next_instr[-1] = instruction_word;')
    emit('    return next_instr;')
    emit('}\n')

def relabel_uses_function(fmt, outfile):
    """Emit C ``relabel_uses_<fmt>``, relabelling input (``r``/``p``) registers."""
    _relabel_function('uses', 'rp', fmt, outfile)

def relabel_defs_function(fmt, outfile):
    """Emit C ``relabel_defs_<fmt>``, relabelling output (``o``) registers."""
    _relabel_function('defs', 'o', fmt, outfile)

def print_read_byte(var, offset, outfile):
    """Print a C statement reading byte *offset* of ``instruction_word`` into *var*.

    The line ends with `` \\`` so it can continue a multi-line macro.
    """
    shift = 8 * offset
    if offset == 0:
        expr = 'instruction_word & 255'
    elif offset == 3:
        # For the 32-bit instruction_word the top byte needs no mask after the shift.
        expr = '(instruction_word >> %d)' % shift
    else:
        expr = '(instruction_word >> %d) & 255' % shift
    print('    %s = %s; \\' % (var, expr), file=outfile)

def print_write_byte(var, offset, outfile):
    """Print a C statement storing *var* into byte *offset* of ``instruction_word``.

    The other three bytes of the word are preserved via a mask.  Shifted
    values are cast to ``uint32_t`` first.  The callers in this file pass
    ``relabel_table[reg]``, an ``unsigned char``, which C promotes to
    ``int``.  Shifting such a value >= 128 left by 24 bits would overflow
    ``int``, which is undefined behaviour.

    :param offset: byte position, 0 (least significant) to 3.
    """
    assert 0 <= offset <= 3
    mask = '0x' + 'ff' * (3-offset) + '00' + 'ff' * offset
    if offset == 0:
        print('    instruction_word = (instruction_word & %s) | %s;' %
              (mask, var), file = outfile)
    else:
        print('    instruction_word = (instruction_word & %s) | ((uint32_t)%s << %d);' %
              (mask, var, offset * 8), file = outfile)

def print_read_short(var, offset, outfile):
    """Print a C statement reading a 16-bit half of ``instruction_word`` into *var*.

    *offset* must be even.  Offset 0 selects the low half and any other
    even offset selects the high half.  The line ends with `` \\`` so it can
    continue a multi-line macro.
    """
    assert not offset & 1
    if offset:
        expr = 'instruction_word >> 16'
    else:
        expr = 'instruction_word & ((1 << 16)-1)'
    print('    %s = %s; \\' % (var, expr), file=outfile)

def read_instruction(fmt, outfile):
    """Emit a C ``FORMAT_<fmt>`` macro that decodes an instruction of format *fmt*.

    The macro declares ``int`` locals: ``r<N>`` for each ``r``/``o``/``p``,
    ``k<N>`` for each ``k``/``K`` and ``e0`` for an ``e``/``E``.  It fills them
    from ``instruction_word``, loading further words from ``next_instr`` as
    needed.  The expanding code must already have loaded the first word into
    ``instruction_word`` and advanced ``next_instr`` past it.  Decoding starts
    at byte 1 because byte 0 is the opcode.  The macro ends with an ``assert``
    that the opcode in that first word is consistent with *fmt*.

    :raises Exception: if *fmt* fails ``verify_format`` or contains a letter
        outside ``ropkKeE``.
    """
    regs = 0
    consts = 0
    verify_format(fmt)
    print('#define FORMAT_%s \\' % fmt, file = outfile)
    # Declaration line: the first name is preceded by "    int", later ones by ","
    # (print's default separator adds the space), e.g. "    int r0, r1, k0; \".
    prefix = '    int'
    for letter in fmt:
        if letter in 'rop':
            print (prefix, 'r%d' % regs, end = '', file = outfile)
            regs += 1
        elif letter in 'kK':
            print (prefix, 'k%d' % consts, end = '', file = outfile)
            consts += 1
        elif letter in 'eE':
            print (prefix, 'e0', end = '', file = outfile)
        else:
            raise Exception("Illegal format letter " + letter)
        prefix = ','
    print ('; \\', file = outfile)
    # Byte offset within the current word; byte 0 of the first word is the opcode.
    offset = 1
    regs = 0
    consts = 0
    # Number of words consumed so far, used to locate the first word at the end.
    instruction_len = 1
    for letter in fmt:
        # Two-byte (uppercase) fields start at an even byte offset.
        if letter < 'a' and (offset & 1):
            offset += 1
        # Current word exhausted: load the next one.
        if offset == 4:
            print('    instruction_word = *next_instr++; \\', file = outfile)
            offset = 0
            instruction_len += 1
        if letter in 'rop':
            print_read_byte('r%d' % regs, offset, outfile)
            regs += 1
        elif letter == 'e':
            print_read_byte('e0', offset, outfile)
        elif letter == 'E':
            print_read_short('e0', offset, outfile)
        elif letter == 'k':
            print_read_byte('k%d' % consts, offset, outfile)
            consts += 1
        elif letter == 'K':
            print_read_short('k%d' % consts, offset, outfile)
            consts += 1
        offset += 2 if letter < 'a' else 1
    # Final macro line: no trailing backslash, and no trailing ';' so users
    # write "FORMAT_xxx;".
    print('    assert(consistent_format(next_instr[-%d] & 255, "%s"))\n' %
          (instruction_len, fmt), file = outfile)

def print_function(fmt, outfile):
    """Emit a C ``print_<fmt>`` function that disassembles one instruction.

    The generated function decodes the instruction at ``*instr_ptr`` with
    ``FORMAT_<fmt>`` and advances ``*instr_ptr`` past it.  It then writes
    the opcode name followed by one ``<letter><value>`` per operand (and
    ``exit<value>`` for an exit) to ``out``.
    """
    header = [
        'void\nprint_%s(FILE* out, uint32_t **instr_ptr)\n{' % fmt,
        '    uint32_t *next_instr = *instr_ptr;',
        '    uint32_t instruction_word = *next_instr++;',
        '    int op = instruction_word & 255;',
        '    FORMAT_%s;' % fmt,
        '    *instr_ptr = next_instr;',
    ]
    for text in header:
        print(text, file=outfile)
    specs = ['%s']
    values = ['_HotPy_Instruction_Names[op]']
    n_regs = n_consts = 0
    for letter in fmt:
        if letter in 'rop':
            specs.append(letter + '%d')
            values.append('r%d' % n_regs)
            n_regs += 1
        elif letter in 'kK':
            specs.append(letter + '%d')
            values.append('k%d' % n_consts)
            n_consts += 1
        elif letter in 'eE':
            specs.append('exit%d')
            values.append('e0')
    print('    fprintf(out, "%s\\n", %s);' % (' '.join(specs), ', '.join(values)),
          file=outfile)
    print('}\n', file=outfile)


# Every instruction format the generator emits code for.  Each letter is one
# operand: r = popped input register, p = peeked input register, o = output
# register, k/K = constant, e/E = exit index.  Lowercase letters occupy one
# byte and uppercase letters two.  Letters must be in descending order (see
# verify_format).  main() sorts this list, so FORMAT_ID_* numbers follow
# sorted order rather than the order written here.
formats = '''
    K kkK r rr rrr E rrK
    o oo ro rro rrro oK rK rrk
    rKE rKKE roK rroK rok rrkkk
    rE rrE rroE rrroE rokk rokE
    rokkkE rrok p oKKK rrrE
    rrooE rroooE rroo roo
    oKE roKE rrKE rroKE rrrKE rrrokE
'''.split()

def defines(formats, outfile):
    """Print ``#define FORMAT_ID_<fmt> <index>`` for each format, then a blank line."""
    lines = ['#define FORMAT_ID_{} {}'.format(fmt, index)
             for index, fmt in enumerate(formats)]
    lines.append('')
    for line in lines:
        print(line, file=outfile)

def main(argv=None):
    """Command-line entry point: ``gen_format_code.py (d|f) outfile``.

    Mode ``d`` writes the ``FORMAT_ID_*`` constants, then one ``FORMAT_*``
    decoding macro per format.  Mode ``f`` writes the write, print,
    defuses, relabel-uses, relabel-defs and get-exit C functions.  All
    formats are emitted for one kind before moving on to the next kind.

    Arguments are validated before the output file is opened, so an
    invalid mode never truncates an existing ``outfile``.  Usage errors are
    reported on stderr with exit status 1.

    :param argv: argument vector; defaults to ``sys.argv``.
    """
    import sys
    if argv is None:
        argv = sys.argv
    prog = argv[0] if argv else 'gen_format_code.py'
    if len(argv) != 3 or argv[1] not in ('d', 'f'):
        print("Usage: %s (d|f) outfile" % prog, file=sys.stderr)
        sys.exit(1)
    mode, path = argv[1], argv[2]
    ordered = sorted(formats)
    with open(path, 'w') as outfile:
        if mode == 'd':
            defines(ordered, outfile)
            for fmt in ordered:
                read_instruction(fmt, outfile)
        else:
            generators = (write_function, print_function, defuses_function,
                          relabel_uses_function, relabel_defs_function,
                          get_exit_function)
            for generate in generators:
                for fmt in ordered:
                    generate(fmt, outfile)

# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()