Source

tutagx / tutagx / meta / to_yaml.py

Full commit
from collections import namedtuple
import yaml

from tutagx.util import codegen as cg
from tutagx.meta import model, process
from tutagx.meta.model import ModelMeta
from tutagx.meta.common import wrap_model, typeident

# Container handed to the generated encoder; its single field ``seen`` is the
# dict YAMLWriter passes along as the encoder's second argument.
_WriterNamespace = namedtuple('_WriterNamespace', ['seen'])


@process.oneshot
def _encoder():
    """
    Code-generation recipe for the model-object -> YAML-ready-data encoder.

    The nested ``visit_*`` generators write, through ``emit``, the body of a
    function with the signature ``(obj, seen)``; ``seen`` is a dict keyed by
    ``id(obj)`` that the emitted code uses to detect objects it has already
    handled.  Everything is handed to ``cg.generate_code`` via ``locals()``,
    so the local names in this function are part of its contract with the
    code generator -- do not rename them.
    """
    # Names used inside the emitted code; reserved so the symbol table does
    # not hand them out for anything else (presumably -- see cg.SymbolTable).
    RESERVED = ('item', 'synth_id', 'oid', 'result', 'CLASSES', 'kind')
    sig = cg.Signature('$obj', 'seen')
    symtab = cg.SymbolTable(RESERVED, sig)
    emit = cg.CodeEmitter()
    # qualname -> class, exposed to the emitted code as ``CLASSES`` for the
    # isinstance() dispatch written by visit_union.
    classes = {}

    def conversion(expr, t):
        """
        Return source text converting the value of *expr* (type node *t*).

        Generator: for types that need their own encoder function it yields
        the type node and expects the driver to send back whatever
        ``sig.call`` needs to reference that function.
        """
        if isinstance(t, (model.Integer, model.Float, model.String)):
            # Scalars are emitted unchanged.
            return expr
        if isinstance(t, model.List):
            # The helper is ``_list_conv``; its result is the expression to
            # return.  (Previously this referenced undefined names.)
            conv = yield from _list_conv(t, expr)
            return conv
        if isinstance(t, model.Maybe):
            t_conv = yield from conversion(expr, t.t)
            if t_conv == expr:
                # Identity conversion: None passes through untouched too.
                return expr
            return '(None if {0} is None else {1})'.format(expr, t_conv)
        if isinstance(t, model.Ref):
            # References are encoded as the target's object ID only.
            return expr + '.object_id'
        func = yield t
        return sig.call(func, expr)

    def _list_conv(node, list_expr):
        """Return a list-comprehension source converting each list item."""
        item_expr = yield from conversion('item', node.items)
        return '[{} for item in {}]'.format(item_expr, list_expr)

    def _names_and_conversions(node):
        """
        Return ``(member_names, conversions)`` for the struct behind *node*,
        where each conversion is source text reading ``obj.<member>``.
        """
        struct = ModelMeta.struct_for(node)
        member_names = tuple(name for name, _ in struct.members)
        convs = []
        for name, t in struct.members:
            conv = yield from conversion('obj.' + name, t)
            convs.append(conv)
        return member_names, tuple(convs)

    def visit_toplevel(cls):
        """Emit the encoder for a top-level model class (entry point)."""
        node = wrap_model(cls)
        if isinstance(node, model.Value):
            yield from visit_value(node)
            return
        member_names, convs = yield from _names_and_conversions(node)
        emit.line('synth_id = id(obj)')
        # An object met again is encoded as its object ID instead of being
        # expanded a second time.
        emit.line('if synth_id in seen: return seen[synth_id]')
        emit.line('oid = obj.object_id')
        emit.line('seen[synth_id] = oid')
        emit.line('result = {')
        emit.indent()
        for member_name, conv in zip(member_names, convs):
            emit.linef('{!r}: {},', member_name, conv)
        emit.dedent()
        emit.line('}')
        emit.line('result["$id"] = oid')
        emit.line('return result')

    def visit_ref(node):
        """Emit the encoder for a reference: just the object ID."""
        emit.line('return obj.object_id')

    def visit_value(node):
        """Emit the encoder for a value type, which must not be shared."""
        member_names, convs = yield from _names_and_conversions(node)
        emit.line('synth_id = id(obj)')
        with emit.block('if synth_id in seen'):
            # emit.line() writes its text verbatim (see the bare '{' / '}'
            # lines elsewhere), so the placeholder must be a plain ``{}``;
            # ``{{}}`` would reach str.format() as an escaped brace pair and
            # the offending object would never appear in the message.
            emit.line(
                'raise RuntimeError("{} is a value type ',
                'and must not be shared".format(obj))'
            )
        emit.line('seen[synth_id] = None')
        emit.line('return {')
        emit.indent()
        for member_name, conv in zip(member_names, convs):
            emit.linef('{!r}: {},', member_name, conv)
        emit.dedent()
        emit.line('}')

    def visit_list(node):
        """Emit the encoder for a list, which must not be shared either."""
        with emit.block('if id(obj) in seen'):
            emit.line('raise RuntimeError("Lists must not be shared")')
        emit.line('seen[id(obj)] = None')
        conv = yield from _list_conv(node, 'obj')
        emit.line('return ', conv)

    def visit_union(node):
        """Emit an isinstance() dispatch over the union's alternatives."""
        for tag, t in node.alternatives:
            cls = ModelMeta.model_for(t)
            classes[cls.__qualname__] = cls
            key = repr(cls.__qualname__)
            with emit.block('if isinstance(obj, CLASSES[', key, '])'):
                conv = yield from conversion('obj', t)
                emit.line('return ', conv)

    # These node kinds are handled inline by conversion() (or not at all), so
    # no stand-alone encoder function is generated for them.
    visit_float = visit_integer = visit_string = NotImplemented
    visit_dict = visit_maybe = visit_struct = NotImplemented

    code_ns = {'CLASSES': classes}
    return cg.generate_code(
        locals(), typeident, sig, emit, symtab,
        entry_point=visit_toplevel, code_ns=code_ns
    )

# By default, the "scalar style" is all-or-nothing.
# Either *every single scalar* is forced into a particular style (which
# leads to an unholy mess with all the numbers, short strings, ``$id``s, etc)
# or everything's left simple and strings with newlines use ugly quotes.
# We want the "simple" style for most things, but we also want nice >-style
# formatting for longer texts -- hence, this hack.
# Based on http://stackoverflow.com/a/7445560/395760
def _custom_string_representer(dumper, data):
    """
    Represent the string *data* as a YAML ``str`` scalar.

    Strings containing a newline are emitted in folded (``>``) style; all
    other strings leave the style choice to the dumper (``style=None``).
    """
    scalar_style = '>' if '\n' in data else None
    return dumper.represent_scalar(
        'tag:yaml.org,2002:str', data, style=scalar_style
    )


# Try to get the cuter ~ for None rather than null
def _custom_none_representer(dumper, data):
    """Represent ``None`` as the YAML null scalar spelled ``~``."""
    null_tag = 'tag:yaml.org,2002:null'
    return dumper.represent_scalar(null_tag, '~')

# ``yaml.add_representer`` registers on ``yaml.Dumper`` unless a ``Dumper`` is
# given, and every dumper class keeps its own representer table.  The writer
# in this module dumps through ``yaml.safe_dump`` / ``yaml.safe_dump_all``,
# i.e. ``yaml.SafeDumper``, so the representers must be registered there as
# well or they never take effect for this module's own output.  The default
# ``yaml.Dumper`` registration is kept for existing ``yaml.dump`` users.
yaml.add_representer(str, _custom_string_representer)
yaml.add_representer(type(None), _custom_none_representer)
yaml.add_representer(str, _custom_string_representer, Dumper=yaml.SafeDumper)
yaml.add_representer(
    type(None), _custom_none_representer, Dumper=yaml.SafeDumper
)


class YAMLWriter:
    """
    Encode model objects to plain data and serialize that data as YAML.

    The per-model encoder is obtained once, in the constructor, from
    ``_encoder``; it is called as ``encoder(obj, seen)`` for every object.
    """

    def __init__(self, model):
        """Create a writer for *model* by building its encoder."""
        self._encoder = _encoder(model)

    def encode(self, obj, namespace=None):
        """Return the encoded form of the single object *obj*."""
        return self.encode_many([obj], namespace)[0]

    def encode_many(self, objs, namespace=None):
        """
        Return a list with the encoded form of every object in *objs*.

        All objects share one namespace; a fresh one is created when
        *namespace* is None.
        """
        if namespace is None:
            namespace = self.new_namespace()
        return [self._encoder(obj, namespace.seen) for obj in objs]

    def dump(self, obj, f, namespace=None):
        """Write *obj* as one YAML document to the stream *f*."""
        self.dump_many([obj], f, namespace)

    def dump_many(self, objs, f, namespace=None):
        """Write each object in *objs* as a YAML document to the stream *f*."""
        vals = self.encode_many(objs, namespace)
        yaml.safe_dump_all(
            vals, f,
            default_flow_style=False,
            default_style=None,
            indent=4
        )

    def dumps(self, obj, namespace=None):
        """Return *obj* serialized as a YAML string."""
        # The namespace belongs to encode(); the second positional argument
        # of yaml.safe_dump is the output stream, so passing the namespace
        # there both ignored it for encoding and broke the dump.
        return yaml.safe_dump(self.encode(obj, namespace))

    @classmethod
    def new_namespace(cls):
        """
        Return a fresh namespace object, which controls the scope in which
        already-encoded objects are recognised.

        Writer namespaces currently only have a single member: ``seen``, a
        dictionary the encoder fills, keyed by ``id(obj)``, to remember the
        objects it has encountered while encoding.
        Callers should not modify it.

        Namespaces are only relevant if you plan to share one namespace
        across multiple encode or dump calls.
        """
        return _WriterNamespace({})