Source

tutagx / tutagx / meta / from_yaml.py

Full commit
import abc
from collections import namedtuple
import itertools
import yaml

from tutagx.meta import model, validate, process
from tutagx.meta.model import ModelMeta, Value
from tutagx.meta.common import typeident, wrap_model
import tutagx.util.codegen as cg

# Decoding scope handed to the generated decoder: ``seen`` and ``unfinished``
# track object IDs (see YAMLReader.new_namespace), ``extra`` is a tuple of
# additional positional arguments appended to every decoder call.
_ReaderNamespace = namedtuple(
    '_ReaderNamespace', ['seen', 'unfinished', 'extra']
)


class _Placeholder:
    """Records which object attributes currently refer to this placeholder.

    Each ``(object, attribute name)`` pair registered with
    :meth:`add_location` is pointed at the actual object once
    :meth:`replace` is called.
    """

    def __init__(self):
        self._locations = []

    def add_location(self, obj, attr):
        """Register ``obj.<attr>`` as a location to patch in :meth:`replace`."""
        location = (obj, attr)
        self._locations.append(location)

    def replace(self, actual_object):
        """Set every registered attribute to *actual_object*, then forget them."""
        for owner, attr_name in self._locations:
            setattr(owner, attr_name, actual_object)
        # Empty the list in place so the same list object stays attached.
        self._locations.clear()

    # N.B. A __del__ method that warns on non-empty _locations would be useful
    # if not for the fact that cycles with __del__ methods are not collected
    # (Since the whole point of _Placeholder is replacing object attributes
    # which refer to the placeholder, cycles would be inevitable.)
    # NOTE(review): that restriction applies to Python versions before 3.4;
    # PEP 442 made such cycles collectable.


class ObjectCreationStrategy(metaclass=abc.ABCMeta):
    """Interface for emitting the code that builds one decoded object.

    A strategy is instantiated with the code emitter of the decoder being
    generated and writes its statements through it.
    """

    def __init__(self, emit):
        # Code emitter that concrete strategies write generated lines to.
        self._emit = emit

    @property
    def extra_args(self):
        """Names of additional decoder parameters required by this strategy.

        They are appended to the signature of the generated decoder; the
        base implementation needs none.
        """
        return tuple()

    @abc.abstractmethod
    def construct(self, target, cls):
        """Emit code creating an instance of *cls* bound to the name *target*."""

    @abc.abstractmethod
    def add_attributes(self, target, params):
        """Emit code storing *params* on the object named *target*.

        *params* is an iterable of ``(attribute name, expression)`` pairs.
        """

    @abc.abstractmethod
    def finish(self, obj):
        """Emit code run once all attributes of the object named *obj* are set."""


class PlainObjectCreation(ObjectCreationStrategy):
    """Default strategy: instantiate, assign attributes, call ``_on_loaded()``.

    All methods only *emit* source lines; the names passed in (``target``,
    ``obj``) are identifiers in the generated code, not live objects.
    """

    def construct(self, target, cls):
        """Emit ``<target> = classes[<module>, <name>](already_complete=False)``."""
        module_name = cls.__module__
        class_name = cls.__name__
        self._emit.linef(
            '{} = classes[{!r}, {!r}](already_complete=False)',
            target, module_name, class_name
        )

    def add_attributes(self, target, params):
        """Emit one ``<target>.<name> = <expr>`` line per pair in *params*."""
        for attr_name, value_expr in params:
            self._emit.linef('{}.{} = {}', target, attr_name, value_expr)

    def finish(self, obj):
        """Emit the ``<obj>._on_loaded()`` call that completes the object."""
        hook_call = '._on_loaded()'
        self._emit.line(obj, hook_call)


def _decoder(cls):
    """Generate and return the decoder function for the model *cls*.

    The nested visitor generators emit Python source through a
    ``cg.CodeEmitter``.  They are handed to ``cg.generate_code`` by way of
    ``locals()`` (presumably it dispatches on the ``visit_<kind>`` names --
    TODO confirm against tutagx.util.codegen), so be careful when adding or
    renaming local names in this function.

    A model can customise how objects are built by providing a
    ``CONSTRUCTION_STRATEGY`` attribute (an ObjectCreationStrategy subclass);
    otherwise PlainObjectCreation is used.

    Judging from how YAMLReader invokes the result, the returned callable
    takes ``(raw_value, seen, unfinished, *extra_args)``.
    """
    # Names used by the generated code itself; handed to the symbol table
    # (presumably so that generated identifiers never clash with them).
    # 'item' and 'v' are bound by the list/dict comprehensions built below.
    RESERVED = (
        'classes', 'result', 'oid', 'json_val', 'k', 'v', 'item',
        '__Placeholder',
    )
    emit = cg.CodeEmitter()
    ocs_class = getattr(cls, 'CONSTRUCTION_STRATEGY', PlainObjectCreation)
    ocs = ocs_class(emit)
    sig = cg.Signature('$obj', 'seen', 'unfinished', *ocs.extra_args)
    symtab = cg.SymbolTable(RESERVED, sig)

    # (module, class name) -> model class; exposed to the generated code as
    # ``classes`` through code_ns.
    classes = {}

    def conversion(expr, t):
        """Return source for an expression converting *expr* according to *t*.

        Generator: for types that need their own decoder function the type
        node is yielded to the driver, which sends back the callee used in
        the generated call.
        """
        if isinstance(t, (model.Integer, model.Float, model.String)):
            # Scalars are used as loaded, no conversion needed.
            return expr
        if isinstance(t, model.List):
            conv = yield from list_conv(t, expr)
            return conv
        if isinstance(t, model.Dict):
            conv = yield from dict_conv(t, expr)
            return conv
        if isinstance(t, model.Maybe):
            t_conv = yield from conversion(expr, t.t)
            if t_conv == expr:
                # Identity conversion: None passes through unchanged anyway.
                return expr
            return '(None if {0} is None else ({1}))'.format(expr, t_conv)
        else:
            func = yield t
            return sig.call(func, expr)

    def list_conv(node, list_expr):
        """Return a list comprehension converting every item of *list_expr*."""
        item_expr = yield from conversion('item', node.items)
        return '[{} for item in {}]'.format(item_expr, list_expr)

    def dict_conv(node, dict_expr):
        """Return a dict comprehension converting keys and values of *dict_expr*."""
        # Iterate over items() and bind the value to its own name instead of
        # indexing with ``<dict_expr>[k]``: for a dict nested inside a dict
        # the inner comprehension rebinds ``k``, so an index expression built
        # for the outer level would be evaluated with the inner key.
        # The iterable of a comprehension is evaluated in the enclosing
        # scope, hence reusing ``k``/``v`` on every nesting level is safe.
        key_expr = yield from conversion('k', node.keys)
        val_expr = yield from conversion('v', node.values)
        return '{{{}: {} for k, v in {}.items()}}'.format(
            key_expr, val_expr, dict_expr
        )

    def visit_toplevel(node):
        """Emit the decoder body for an object definition carrying a ``$id``."""
        if isinstance(node, Value):
            yield from visit_value(node)
            return
        cls = ModelMeta.model_for(node)
        struct = ModelMeta.struct_for(node)
        classes[cls.__module__, cls.__name__] = cls
        member_names = tuple(name for name, _ in struct.members)
        loaders = []
        for name, t in struct.members:
            conv = yield from conversion('obj[{!r}]'.format(name), t)
            loaders.append(conv)

        emit.line('oid = obj["$id"]')
        # An earlier reference may already have created the (still empty)
        # object; reuse it so that all references share one instance.
        emit.line('if oid in unfinished: result = seen[oid]')
        with emit.block('else'):
            ocs.construct('result', cls)
            emit.line('seen[oid] = result')
        emit.line('unfinished.discard(oid)')
        emit.line('result.object_id = oid')
        ocs.add_attributes('result', zip(member_names, loaders))
        ocs.finish('result')
        emit.line('return result')

    def visit_ref(node):
        """Emit the decoder body for a reference, i.e. a bare object ID."""
        cls = ModelMeta.model_for(node)
        # NOTE(review): struct_node is not used below; kept as-is in case
        # struct_for() is relied upon for its checks -- verify.
        struct_node = ModelMeta.struct_for(node)
        classes[cls.__module__, cls.__name__] = cls
        emit.line('oid = obj')
        emit.line(
            'assert isinstance(oid, str), repr(oid) + " is not an ID"'
        )
        emit.line('if oid in seen: return seen[oid]')
        # Not seen yet: create the object now and mark its ID as unfinished
        # until the actual definition is decoded.
        ocs.construct('obj', cls)
        emit.line('seen[oid] = obj')
        emit.line('unfinished.add(oid)')
        emit.line('return obj')

    def visit_value(node):
        """Emit the decoder body for a Value, which has no ``$id`` bookkeeping."""
        cls = ModelMeta.model_for(node)
        struct = ModelMeta.struct_for(node)
        classes[cls.__module__, cls.__name__] = cls
        member_names = tuple(name for name, _ in struct.members)
        loaders = []
        for name, t in struct.members:
            conv = yield from conversion('obj[{!r}]'.format(name), t)
            loaders.append(conv)

        ocs.construct('result', cls)
        ocs.add_attributes('result', zip(member_names, loaders))
        ocs.finish('result')
        emit.line('return result')

    # These kinds are converted inline by conversion() and never get a
    # decoder function of their own.
    visit_integer = visit_string = visit_float = NotImplemented
    visit_list = visit_dict = NotImplemented
    visit_struct = visit_maybe = NotImplemented

    def visit_union(node):
        """Emit the decoder body for a union, dispatching on ``obj["$type"]``."""
        # NOTE(review): for a plain string a fresh placeholder is returned,
        # but nothing visible here registers locations on it or replaces it
        # later -- confirm this is handled elsewhere.
        with emit.block('if isinstance(obj, str)'):
            emit.line('return __Placeholder()')
        for tag, t in node.alternatives:
            with emit.block('if obj["$type"] == ', repr(tag)):
                conv = yield from conversion('obj', t)
                emit.line('return ', conv)

    # Globals made available to the generated code.
    code_ns = {'classes': classes, '__Placeholder': _Placeholder}

    gen = cg.generate_code(
        locals(), typeident, sig, emit, symtab,
        entry_point=visit_toplevel, code_ns=code_ns
    )
    return gen(wrap_model(cls))


#TODO replace this API with something better
class YAMLReader:
    """Validate and decode YAML data into instances of a model.

    Every entry point returns a ``(result, errors)`` pair.  If any validation
    error was recorded, the result is ``None`` and no partial data is
    returned.
    """

    def __init__(self, model):
        """Prepare a reader for *model* by generating its decoder function."""
        self._mcls = type(model)
        self._cls = model
        self._decoder = _decoder(model)

    def decode(self, val, namespace=None, validator_ns=None, source=None):
        """Validate and decode the single, already parsed value *val*.

        The arguments have the same meaning as for :meth:`decode_many`.
        Returns ``(object, errors)``, or ``(None, errors)`` on errors.
        """
        results, errors = self.decode_many(
            [val], namespace, validator_ns, source
        )
        if errors:
            return None, errors
        return results[0], errors

    def decode_many(self, raw_values,
                    namespace=None, validator_ns=None,
                    source=None):
        """Validate and decode every value of the iterable *raw_values*.

        *namespace* defaults to a fresh :meth:`new_namespace`, *validator_ns*
        to a fresh ``validate.Validator.new_namespace()``.  *source* is
        passed on to the validator (presumably to label errors).

        Returns ``(list of objects, errors)``, or ``(None, errors)`` if the
        validator namespace contains any error.
        """
        if namespace is None:
            namespace = type(self).new_namespace()
        if validator_ns is None:
            validator_ns = validate.Validator.new_namespace()
        args = (namespace.seen, namespace.unfinished) + namespace.extra
        validator = validate.Validator(self._cls)
        results = []
        for raw_value in raw_values:
            validator.validate(raw_value, validator_ns, source)
            # Note that we continue validating, but stop decoding,
            # as soon as ANY error is encountered.
            # This allows producing additional useful errors without
            # trying to decode possibly invalid data.
            if validator_ns.errors:
                continue
            # If decoding raises an exception, it's a bug in the
            # validation and SHOULD lead to a crash!
            results.append(self._decoder(raw_value, *args))
        # If any error occurred, the whole data set is likely invalid.
        # Thus, we do not attempt to return partial data.
        if validator_ns.errors:
            return None, validator_ns.errors
        return results, validator_ns.errors

    def load(self, f, namespace=None, validator_ns=None):
        """Load the stream *f* and return its first document decoded.

        Returns ``(object, errors)``, or ``(None, errors)`` on errors.
        A stream without any document and without errors has no first
        result, so indexing raises IndexError.
        """
        results, errors = self.load_many(f, namespace, validator_ns)
        if errors:
            return None, errors
        return results[0], errors

    def load_many(self, f, namespace=None, validator_ns=None):
        """Load and decode all YAML documents of the stream *f*.

        ``f.name`` is used as error source if present, ``'<unknown file>'``
        otherwise.  Returns the same pair as :meth:`decode_many`.
        """
        try:
            source = f.name
        except AttributeError:
            source = '<unknown file>'
        # safe_load_all is lazy; decode_many consumes it document by document.
        raw_data = yaml.safe_load_all(f)
        return self.decode_many(raw_data, namespace, validator_ns, source)

    def loads(self, s, namespace=None, validator_ns=None):
        """Parse the YAML string *s* and decode the resulting value.

        *validator_ns* is accepted like in the other entry points so callers
        can share one validator namespace; it defaults to a fresh one.
        Returns ``(object, errors)``, or ``(None, errors)`` on errors.
        """
        return self.decode(yaml.safe_load(s), namespace, validator_ns)

    @classmethod
    def new_namespace(cls):
        """
        Return a namespace object, which controls the scope of object IDs
        encountered in YAML documents.
        Reader namespaces have the following members:

        * ``seen``, a dictionary with object IDs encountered while decoding as
          keys and the objects instantiated for these IDs as values.
        * ``unfinished``, a set of object IDs that some loaded objects referred
          to for which no object has been loaded (yet).
          If this set is non-empty after all data was loaded, there were
          bogus IDs in the input.
        * ``extra``, a tuple of additional positional arguments passed to the
          decoder (empty unless a subclass overrides
          ``_extra_decoder_args``).

        Namespaces are only relevant if you plan on using one namespace across
        multiple documents, or need access to the information exposed.
        """
        return _ReaderNamespace({}, set(), cls._extra_decoder_args())

    @classmethod
    def _extra_decoder_args(cls):
        """Return the extra decoder arguments stored in new namespaces."""
        # Dummy implementation, this is only useful for subclasses
        return ()