Source

tutagx / tutagx / meta / encode.py

from collections import namedtuple
from functools import lru_cache

from tutagx.util import codegen as cg
from tutagx.meta import model, process
from tutagx.meta.model import ModelMeta
from tutagx.meta.common import wrap_model, typeident


# Per-run encoding state. ``seen`` is handed to the generated encoder as its
# second argument (see ``encode``), where it is indexed by ``id(obj)``.
EncodingContext = namedtuple('EncodingContext', ['seen'])


@lru_cache()
@process.oneshot
def make_encoder():
    """
    Set up the code-generation visitors that produce an encoder function.

    The nested ``visit_*`` generators write Python source through ``emit``.
    They reach ``cg.generate_code`` via ``locals()``, so the local names
    defined in this function are part of what is handed to the code
    generator; do not rename or add locals without checking that helper.

    The generated function refers to its arguments as ``obj`` and ``seen``
    (see ``sig``); ``seen`` is used as a mapping keyed by ``id(obj)``.

    NOTE(review): ``encode`` calls this as ``make_encoder(type(obj))`` although
    no parameter is declared here; presumably ``process.oneshot`` provides
    that calling convention -- confirm against ``tutagx.meta.process``.
    """
    # Names used literally inside the emitted source; registered with the
    # symbol table (presumably so generated symbols avoid them -- confirm).
    RESERVED = (
        'item', 'synth_id', 'oid', 'result', 'CLASSES', 'kind', 'k', 'v'
    )
    sig = cg.Signature('$obj', 'seen')
    symtab = cg.SymbolTable(RESERVED, sig)
    emit = cg.CodeEmitter()
    # Filled by visit_union and exposed to generated code as ``CLASSES``.
    classes = {}

    def conversion(expr, t):
        """
        Return a source expression that encodes ``expr`` of model type ``t``.

        This is a generator: for types without an inline conversion it yields
        ``t`` and expects to be sent back the function to call on ``expr``.
        """
        if isinstance(t, (model.Integer, model.Float, model.String)):
            # Scalars are emitted unchanged.
            return expr
        if isinstance(t, model.Ref):
            # References are encoded as the target's object id only.
            return expr + '.object_id'
        if isinstance(t, model.List):
            conv = yield from _list_conv(t, expr)
            return conv
        if isinstance(t, model.Dict):
            conv = yield from _dict_conv(t, expr)
            return conv
        if isinstance(t, model.Maybe):
            t_conv = yield from conversion(expr, t.t)
            if t_conv == expr:
                # The inner type needs no conversion, so None passes through.
                return expr
            return '(None if {0} is None else {1})'.format(expr, t_conv)
        # Anything else is delegated: yield the type and call whatever
        # function is sent back for it.
        func = yield t
        return sig.call(func, expr)

    def _list_conv(node, list_expr):
        """Return a list-comprehension expression converting each item."""
        item_expr = yield from conversion('item', node.items)
        return '[{} for item in {}]'.format(item_expr, list_expr)

    def _dict_conv(node, dict_expr):
        """Return a dict-comprehension expression converting keys and values."""
        key_expr = yield from conversion('k', node.keys)
        val_expr = yield from conversion('v', node.values)
        return '{{{}: {} for k, v in {}.items()}}'.format(
            key_expr, val_expr, dict_expr
        )

    def _names_and_conversions(node):
        """
        Return ``(member_names, conversions)`` for the struct behind ``node``.

        Both are tuples in ``struct.members`` order; each conversion is the
        source expression for ``obj.<member>``.
        """
        struct = ModelMeta.struct_for(node)
        member_names = tuple(name for name, _ in struct.members)
        convs = []
        for name, t in struct.members:
            conv = yield from conversion('obj.' + name, t)
            convs.append(conv)
        return member_names, tuple(convs)

    def visit_toplevel(cls):
        """
        Entry point: emit the encoder body for model class ``cls``.

        Value types are handed to ``visit_value``. For everything else the
        emitted code records ``id(obj) -> obj.object_id`` in ``seen`` and, on a
        repeated encounter, returns that stored id instead of a new dict.
        """
        node = wrap_model(cls)
        if isinstance(node, model.Value):
            yield from visit_value(node)
            return
        member_names, convs = yield from _names_and_conversions(node)
        emit.line('synth_id = id(obj)')
        emit.line('if synth_id in seen: return seen[synth_id]')
        emit.line('oid = obj.object_id')
        # Register before building the dict so nested encounters of the same
        # object hit the early return above.
        emit.line('seen[synth_id] = oid')
        emit.line('result = {')
        emit.indent()
        for member_name, conv in zip(member_names, convs):
            emit.linef('{!r}: {},', member_name, conv)
        emit.dedent()
        emit.line('}')
        emit.line('result["$id"] = oid')
        emit.line('return result')

    def visit_value(node):
        """
        Emit the encoder body for a value type.

        The emitted code raises ``RuntimeError`` if the same object is seen
        twice, otherwise marks it in ``seen`` and returns a dict of its
        converted members.
        """
        member_names, convs = yield from _names_and_conversions(node)
        emit.line('synth_id = id(obj)')
        with emit.block('if synth_id in seen'):
            # emit.line receives bare '{' / '}' elsewhere in this function, so
            # it does not apply str.format itself. The placeholder must be a
            # single '{}' for the generated ``.format(obj)`` call to
            # interpolate the object; '{{}}' would print a literal '{}'.
            emit.line(
                'raise RuntimeError("{} is a value type ',
                'and must not be shared".format(obj))'
            )
        emit.line('seen[synth_id] = None')
        emit.line('return {')
        emit.indent()
        for member_name, conv in zip(member_names, convs):
            emit.linef('{!r}: {},', member_name, conv)
        emit.dedent()
        emit.line('}')

    def visit_list(node):
        """
        Emit the encoder body for a list.

        The emitted code raises ``RuntimeError`` for a list already present in
        ``seen``; otherwise it marks the list and returns the converted items.
        """
        with emit.block('if id(obj) in seen'):
            emit.line('raise RuntimeError("Lists must not be shared")')
        emit.line('seen[id(obj)] = None')
        conv = yield from _list_conv(node, 'obj')
        emit.line('return ', conv)

    def visit_union(node):
        """
        Emit one ``isinstance`` guard per union alternative.

        Each alternative's model class is stored in ``classes`` under its
        ``__qualname__`` so the generated code can look it up in ``CLASSES``.
        No code is emitted for the case where no alternative matches.
        """
        for tag, t in node.alternatives:
            cls = ModelMeta.model_for(t)
            classes[cls.__qualname__] = cls
            key = repr(cls.__qualname__)
            with emit.block('if isinstance(obj, CLASSES[', key, '])'):
                conv = yield from conversion('obj', t)
                emit.line('return ', conv)

    # These node kinds get no visitor of their own; ``conversion`` handles
    # scalars, dicts, maybes and refs inline.
    visit_float = visit_integer = visit_string = NotImplemented
    visit_dict = visit_maybe = visit_struct = visit_ref = NotImplemented

    code_ns = {'CLASSES': classes}
    return cg.generate_code(
        locals(), typeident, sig, emit, symtab,
        entry_point=visit_toplevel, code_ns=code_ns
    )


def encode(obj, context):
    """
    Encode ``obj`` using the encoder built for its concrete type.

    The encoder is obtained from ``make_encoder(type(obj))`` and is called
    with the object and ``context.seen``; its result is returned as-is.
    """
    return make_encoder(type(obj))(obj, context.seen)