# Source: tutagx / tutagx / meta / decode.py

import abc
from collections import namedtuple, defaultdict
from functools import lru_cache, partial

from tutagx.meta import model, process
from tutagx.meta.model import ModelMeta, Value
from tutagx.meta.common import typeident, wrap_model
import tutagx.util.codegen as cg


_DecodingContext = namedtuple('_DecodingContext', 'seen unfinished extra')


class DecodingContext(_DecodingContext):
    """Fresh, empty bookkeeping state shared by the generated decoders.

    Fields:
      seen       -- dict mapping an object ID to the object created for it.
      unfinished -- set of IDs whose object was created from a reference but
                    whose full definition has not been decoded yet.
      extra      -- dict handed to every generated decoder; not interpreted
                    by the code in this module.

    Every instance gets its own new containers.
    """

    def __new__(cls):
        empty_seen, empty_unfinished, empty_extra = {}, set(), {}
        return super().__new__(cls, empty_seen, empty_unfinished, empty_extra)


class ObjectCreationStrategy(metaclass=abc.ABCMeta):
    """Policy for emitting the code that creates and fills a decoded object.

    A strategy does not build objects itself. It writes source lines through
    the code emitter it was given, and the generated decoder executes them.
    """

    def __init__(self, emit):
        # Code emitter used by the concrete strategies to write lines.
        self._emit = emit

    @abc.abstractmethod
    def construct(self, target, cls):
        """Emit code binding a new, not-yet-populated ``cls`` to ``target``."""

    @abc.abstractmethod
    def add_attributes(self, target, params):
        """Emit code setting attributes on ``target``.

        ``params`` is an iterable of ``(attribute_name, source_expression)``
        pairs.
        """

    @abc.abstractmethod
    def finish(self, obj):
        """Emit code to run once the object named ``obj`` is fully populated."""


class PlainObjectCreation(ObjectCreationStrategy):
    """Default strategy: instantiate, assign attributes directly, notify.

    The emitted code looks the class up in the generated namespace's
    ``classes`` mapping by ``(module, name)``, instantiates it with
    ``already_complete=False``, sets each attribute with a plain assignment
    and finally calls ``_on_loaded()`` on the object.
    """

    def construct(self, target, cls):
        module_name, class_name = cls.__module__, cls.__name__
        self._emit.linef(
            '{} = classes[{!r}, {!r}](already_complete=False)',
            target, module_name, class_name,
        )

    def add_attributes(self, target, params):
        for attr_name, value_expr in params:
            self._emit.linef('{}.{} = {}', target, attr_name, value_expr)

    def finish(self, obj):
        self._emit.line(obj, '._on_loaded()')


class _Placeholder:
    """Stand-in for an object that cannot be resolved yet.

    Every place that holds the placeholder is recorded as an
    ``(obj, attr)`` pair. :meth:`replace` later points all of them at the
    real object.
    """

    def __init__(self):
        # (holder object, attribute name) pairs still to be patched.
        self._locations = []

    def add_location(self, obj, attr):
        """Remember that ``obj.<attr>`` must be patched by :meth:`replace`."""
        self._locations.append((obj, attr))

    def replace(self, actual_object):
        """Set every recorded attribute to ``actual_object``, then forget them."""
        for holder, attr_name in self._locations:
            setattr(holder, attr_name, actual_object)
        self._locations.clear()

    # N.B. A __del__ that warns about a non-empty _locations would be useful,
    # but a placeholder is always part of a reference cycle: the holders it
    # records refer back to it through the very attributes it will replace.
    # Before Python 3.4 (PEP 442) the collector never freed cycles containing
    # objects with __del__. Since 3.4 such cycles are collected, but when the
    # finalizer runs is still nondeterministic.


def _ocs(cls, emit):
    """Instantiate, bound to ``emit``, the object-creation strategy ``cls`` asks for.

    A model class selects its strategy through a ``CONSTRUCTION_STRATEGY``
    class attribute. Without one, ``PlainObjectCreation`` is used.
    """
    try:
        strategy_class = cls.CONSTRUCTION_STRATEGY
    except AttributeError:
        strategy_class = PlainObjectCreation
    return strategy_class(emit)


@lru_cache()
def make_decoder(cls):
    """Generate, compile and cache a decoder function for model class ``cls``.

    ``cg.generate_code`` receives this function's locals: the ``visit_*``
    emitters and the conversion helpers. Each ``visit_*`` emits, through
    ``emit``, the body of one generated function. The resulting callable
    takes ``(obj, seen, unfinished, extra)`` (see ``sig``). Generated code
    runs with ``classes`` and ``__Placeholder`` in its namespace.

    The result is memoized per ``cls`` by ``lru_cache``, so ``cls`` must be
    hashable.
    """
    # Names the generated code uses for its own purposes: namespace entries
    # (``classes``, ``__Placeholder``), locals of the emitted bodies
    # (``result``, ``oid``), and the comprehension variables emitted by
    # list_conv/dict_conv (``item``, ``k``, ``v``). They are handed to the
    # symbol table, presumably so generated symbols never collide with them.
    RESERVED = (
        'classes', 'result', 'oid', 'json_val', 'k', 'v', 'item',
        '__Placeholder',
    )
    emit = cg.CodeEmitter()
    sig = cg.Signature('$obj', 'seen', 'unfinished', 'extra')
    symtab = cg.SymbolTable(RESERVED, sig)
    # (module, class name) -> class; exposed to generated code as ``classes``.
    classes = {}

    def const(x):
        # Identity conversion, written as a generator so that it composes
        # with ``yield from`` like the other conversions.
        return x
        yield

    def conversion(expr, t):
        """Return a source expression converting ``expr`` according to ``t``.

        This is a generator. Primitive types leave ``expr`` unchanged and
        containers recurse. For any other node it yields ``t``, receives in
        exchange a reference to the generated decoder for that node
        (presumably its function name), and calls it through ``sig``.
        """
        convs = {
            model.Integer: partial(const, expr),
            model.Float: partial(const, expr),
            model.String: partial(const, expr),
            model.List: partial(list_conv, t, expr),
            model.Dict: partial(dict_conv, t, expr),
            model.Maybe: partial(maybe_conv, t, expr)
        }
        if type(t) in convs:
            return (yield from convs[type(t)]())
        func = yield t
        return sig.call(func, expr)

    def list_conv(node, list_expr):
        # Nested lists are safe even though every level rebinds ``item``:
        # a comprehension's first iterable is evaluated in the enclosing
        # scope, so ``list_expr`` still refers to the outer ``item``.
        item_expr = yield from conversion('item', node.items)
        return '[{} for item in {}]'.format(item_expr, list_expr)

    def dict_conv(node, dict_expr):
        # Iterate over items() and convert the loop variable ``v`` instead
        # of re-indexing ``dict_expr[k]`` inside the comprehension. With
        # nested dicts the inner comprehension rebinds ``k``, so
        # ``outer[k][k]`` would index the outer mapping with the inner key.
        key_expr = yield from conversion('k', node.keys)
        val_expr = yield from conversion('v', node.values)
        return '{{{}: {} for k, v in {}.items()}}'.format(
            key_expr, val_expr, dict_expr
        )

    def maybe_conv(node, expr):
        t_conv = yield from conversion(expr, node.t)
        # An identity conversion already passes None through unchanged.
        if t_conv == expr:
            return expr
        return '(None if {0} is None else {1})'.format(expr, t_conv)

    def visit_toplevel(node):
        # Entry point: the decoder for a complete serialized object (one
        # carrying a "$id"). Value nodes have no ID and are decoded by value.
        if isinstance(node, Value):
            yield from visit_value(node)
            return
        cls = ModelMeta.model_for(node)
        struct = ModelMeta.struct_for(node)
        classes[cls.__module__, cls.__name__] = cls
        member_names = tuple(name for name, _ in struct.members)
        ocs = _ocs(cls, emit)
        # All member conversion expressions are computed before any line of
        # this body is emitted.
        loaders = []
        for name, t in struct.members:
            conv = yield from conversion('obj[{!r}]'.format(name), t)
            loaders.append(conv)

        # An object previously created from a reference (see visit_ref) is
        # already in ``seen`` and listed in ``unfinished``. The generated code
        # reuses that instance so earlier references stay valid, and
        # otherwise creates a new one. In both cases the object is then
        # populated.
        emit.line('oid = obj["$id"]')
        emit.line('if oid in unfinished: result = seen[oid]')
        with emit.block('else'):
            ocs.construct('result', cls)
            emit.line('seen[oid] = result')
        emit.line('unfinished.discard(oid)')
        emit.line('result.object_id = oid')
        ocs.add_attributes('result', zip(member_names, loaders))
        ocs.finish('result')
        emit.line('return result')

    def visit_ref(node):
        # Decoder for a by-ID reference, where ``obj`` is the ID string. The
        # generated code returns the known object, or else creates an
        # unpopulated shell, records it in ``seen`` and marks it
        # ``unfinished`` until visit_toplevel's code decodes its definition.
        cls = ModelMeta.model_for(node)
        classes[cls.__module__, cls.__name__] = cls
        emit.line('oid = obj')
        emit.line(
            'assert isinstance(oid, str), repr(oid) + " is not an ID"'
        )
        emit.line('if oid in seen: return seen[oid]')
        ocs = _ocs(cls, emit)
        ocs.construct('obj', cls)
        emit.line('seen[oid] = obj')
        emit.line('unfinished.add(oid)')
        emit.line('return obj')

    def visit_value(node):
        # Decoder for a value type: construct, populate and finish, with no
        # ID bookkeeping.
        cls = ModelMeta.model_for(node)
        struct = ModelMeta.struct_for(node)
        classes[cls.__module__, cls.__name__] = cls
        member_names = tuple(name for name, _ in struct.members)
        loaders = []
        for name, t in struct.members:
            conv = yield from conversion('obj[{!r}]'.format(name), t)
            loaders.append(conv)
        ocs = _ocs(cls, emit)
        ocs.construct('result', cls)
        ocs.add_attributes('result', zip(member_names, loaders))
        ocs.finish('result')
        emit.line('return result')

    # These node kinds are handled inline by ``conversion`` or by the
    # visitors above, never as functions of their own.
    visit_integer = visit_string = visit_float = NotImplemented
    visit_list = visit_dict = NotImplemented
    visit_struct = visit_maybe = NotImplemented

    def visit_union(node):
        # A bare string decodes to a fresh placeholder.
        # NOTE(review): no location is registered on that placeholder here;
        # presumably it is resolved elsewhere. Confirm.
        with emit.block('if isinstance(obj, str)'):
            emit.line('return __Placeholder()')
        # Otherwise the generated code dispatches on the "$type" tag.
        # NOTE(review): if no alternative matches, the generated function
        # appears to fall through and implicitly return None, unless
        # generate_code appends something. Confirm that this is intended.
        for tag, t in node.alternatives:
            with emit.block('if obj["$type"] == ', repr(tag)):
                conv = yield from conversion('obj', t)
                emit.line('return ', conv)

    code_ns = {'classes': classes, '__Placeholder': _Placeholder}

    gen = cg.generate_code(
        locals(), typeident, sig, emit, symtab,
        entry_point=visit_toplevel, code_ns=code_ns
    )
    return gen(wrap_model(cls))


def decode(cls, obj, context):
    """
    Decode ``obj`` (already-parsed serialized data, e.g. loaded from YAML)
    as an object of type ``cls``.

    ``context`` supplies the ``seen``/``unfinished``/``extra`` bookkeeping.
    Decoding with the same context shares the ID tables, which is how
    references by object ID resolve to the same instance.
    """
    decoder = make_decoder(cls)
    seen, unfinished, extra = context.seen, context.unfinished, context.extra
    return decoder(obj, seen, unfinished, extra)