# Source
#
# tutagx / tutagx / meta / validate.py
#
# Full commit
from itertools import chain
from contextlib import contextmanager
from collections import defaultdict, namedtuple

from tutagx.meta.collect_objects import collect_instances
from tutagx.meta.common import typeident, wrap_model
from tutagx.meta import model, process
import tutagx.util.codegen as cg

_ValidationContext = namedtuple(
    '_ValidationContext',
    ['seen', 'defined', 'errors', 'access'],
)


class ValidationContext(_ValidationContext):
    """
    Mutable bookkeeping shared by all validation runs in one namespace.

    A fresh context starts out with four empty containers:

    ``seen``
        ``defaultdict(set)``; elsewhere in this module it maps a referenced
        object ID to the places that referenced it.
    ``defined``
        ``set`` of object IDs for which a definition was encountered.
    ``errors``
        ``list`` of ``(exception type, identifier, message)`` tuples.
    ``access``
        ``list`` used as the access-path stack during validation.
    """

    def __new__(cls):
        # Every instance gets its own containers; nothing is shared.
        fresh_seen = defaultdict(set)
        fresh_defined = set()
        fresh_errors = []
        fresh_access = []
        return super().__new__(
            cls, fresh_seen, fresh_defined, fresh_errors, fresh_access
        )


def undefined_ids(ctx):
    """
    Compute the object IDs that are referenced but not (yet) defined.

    An ID counts as "undefined" (or "unfinished") when it is a key of
    ``ctx.seen`` -- i.e. something referred to it -- but it is missing from
    ``ctx.defined``. The result is a ``frozenset``.

    When the data is spread over several documents, the answer only becomes
    meaningful once every document has been processed with the same context.
    """
    referenced = frozenset(ctx.seen)
    return referenced.difference(ctx.defined)


def add_errors_for_undefined(ctx):
    """
    Add error messages for ``undefined_ids(ctx)``.

    For every undefined ID one ``(KeyError, repr(id), message)`` tuple is
    appended to ``ctx.errors``; the message lists the places (taken from
    ``ctx.seen[id]``) that referenced the ID.

    Note that this may add wrong error messages if the data is split across
    multiple files and this is called before all sources are checked in
    the same namespace.
    """
    # Sets have no stable iteration order, so sort both the IDs and the
    # users to make the produced errors reproducible between runs.
    for undef_id in sorted(undefined_ids(ctx), key=repr):
        users_str = '; '.join(sorted(ctx.seen[undef_id]))
        msg = 'referenced (from {}) but not defined'.format(users_str)
        ctx.errors.append((KeyError, repr(undef_id), msg))


def add_errors_for_unreachable(cls, ctx, loaded, errors):
    """
    Report every object ID from ``unreachable_objects(cls, ctx, loaded)``.

    One ``(RuntimeError, repr(oid), message)`` tuple per unreachable ID is
    appended to the ``errors`` list passed in by the caller.
    The same caveats as for ``add_errors_for_undefined`` apply: calling this
    before all sources were checked in one namespace may yield bogus errors.
    """
    message = "not used (would be dropped during serialization)"
    errors.extend(
        (RuntimeError, repr(unused_oid), message)
        for unused_oid in unreachable_objects(cls, ctx, loaded)
    )


def unreachable_objects(cls, namespace, loaded):
    """
    Return a frozenset of object IDs which were loaded in the namespace
    but are not reachable from the object ``loaded``.

    ``cls`` is accepted for interface compatibility; it is not used here.
    ``namespace`` must provide a ``seen`` mapping keyed by object ID.
    ``loaded`` is handed to ``collect_instances``; the result is iterated
    for types and indexed by type to get that type's instances
    (assumed to be a mapping ``type -> iterable of instances`` -- TODO
    confirm against ``collect_instances``).
    """
    # NOTE(review): the IDs are taken from ``namespace.seen`` as before. For
    # a ``ValidationContext`` that is the set of *referenced* IDs, whereas
    # ``defined`` holds the IDs that were actually loaded -- confirm which
    # one is intended here.
    all_oids = frozenset(namespace.seen)
    instances_by_type = collect_instances(loaded)
    # Only reference types carry an object ID; value types are skipped.
    ref_types = [
        obj_type for obj_type in instances_by_type
        if not obj_type.is_value_type
    ]
    reachable = chain.from_iterable(
        instances_by_type[obj_type] for obj_type in ref_types
    )
    reachable_oids = frozenset(obj.object_id for obj in reachable)
    return all_oids - reachable_oids


def format_errors(errors):
    """
    Given a list of error tuples (as the ``errors`` property of validation
    namespaces), return a list of readable, ``print``-able lines describing
    the errors.

    Each tuple is ``(exception type, identifier, message)``; the exception
    type is ignored for formatting. Identifiers with a single message yield
    one ``"ident: message"`` line. Identifiers with several messages yield a
    ``"ident: "`` header line followed by one tab-indented line per message,
    in the order the messages were reported.
    """
    messages_by_ident = {}
    for _exc_type, ident, msg in errors:
        messages_by_ident.setdefault(ident, []).append(msg)
    lines = []
    # Sorting the errors by ID is not necessary, but nice to have.
    for ident in sorted(messages_by_ident):
        messages = messages_by_ident[ident]
        if len(messages) == 1:
            lines.append(ident + ': ' + messages[0])
        else:
            lines.append(ident + ': ')
            # Indent each message on its own line below the header.
            lines.extend('\t' + msg for msg in messages)
    return lines


@process.oneshot
def make_validate():
    """
    Set up the code generator that emits the validation functions.

    The nested ``visit_*`` functions each emit (via ``emit``) the body of a
    validator for one kind of model node. They do not validate anything
    themselves: every string passed to ``emit`` is *source code* of the
    generated validator. ``locals()`` is handed to ``cg.generate_code``,
    which presumably looks the ``visit_*`` callables up by name -- confirm
    against ``tutagx.util.codegen``.

    Generated functions take the parameters listed in ``sig``; errors are
    recorded by appending ``(exception type, trace string, message)`` tuples
    to ``errors`` rather than by raising.
    """
    # Parameter names of every generated function.
    # NOTE(review): the ``$`` prefix presumably marks parameters that differ
    # per call (``obj``, ``is_union``) as opposed to ones passed through
    # unchanged -- confirm against cg.Signature.
    sig = cg.Signature(
        '$obj', 'seen', 'defined', 'errors', 'stack', 'source', '$is_union'
    )
    # Names used literally in the emitted code below; handed to the symbol
    # table, presumably so that generated names never collide with them.
    RESERVED = ('error_occured', 'item', 'key', 'i', 'attr',
                '__id_undefined', '__type_missing', '__trace')
    emit = cg.CodeEmitter()
    symtab = cg.SymbolTable(RESERVED, sig)

    # NOTE(review): not referenced anywhere in this function; it is still
    # part of ``locals()`` passed to cg.generate_code.
    classes = {}
    # Emit a helper that renders the current position as
    # "<source>: <repr of frame's first entry> <remaining entries joined>".
    with emit.block('def __trace(src, access)'):
        emit.line('return "{}: {!r} {}".format(src, '
                    'access[-1][0], "".join(access[-1][1:]))')
    TRACE = '__trace(source, stack)' # shortcut

    @contextmanager
    def push(access):
        # Emit code that appends the expression ``access`` to the top stack
        # frame before the wrapped code and removes it again afterwards.
        emit.line('stack[-1].append(', access, ')')
        yield
        emit.line('stack[-1].pop()')

    @contextmanager
    def push_frame(obj_expr):
        # Emit code that opens a new stack frame for an object with an ID:
        # report a missing ``$id`` and return early, otherwise push a frame
        # starting with the ID and record the ID as defined. The frame is
        # popped after the wrapped code.
        require_key(obj_expr, '$id', error_flag='__id_undefined',
            msg='"object ID missing"'
        )
        emit.line('if __id_undefined: return')
        emit.line('stack.append([', obj_expr, '["$id"]])')
        emit.line('defined.add(', obj_expr, '["$id"])')
        yield
        emit.line('stack.pop()')

    def error(exc_type, message, *args):
        # Emit an ``errors.append(...)`` statement. ``message`` is the source
        # of a Python expression (note the nested quotes at the call sites)
        # into which ``args`` are substituted with ``str.format``.
        emit.linef('errors.append(({}, {}, {}))',
                    exc_type.__name__, TRACE, message.format(*args)
        )

    def descend(node, expr, is_union=False):
        # Hand ``node`` to whoever drives this generator and expect to be
        # sent back something callable via ``sig.call`` (presumably the
        # generated validator for ``node``); then emit a call to it on
        # ``expr``.
        func = yield node
        emit.line(sig.call(func, expr, str(is_union)))

    def require_isinstance(expr, expected, expected_name=None):
        # Generate a statement that adds a TypeError if ``expr`` does not
        # evaluate to an object of type ``expected``. (Both are strings!)
        # ``expected_name`` is the human-readable name used in the message;
        # it defaults to ``expected``. The generated function returns right
        # after recording the error.
        # NOTE(review): this message starts with a space, unlike the other
        # messages emitted in this function.
        if expected_name is None:
            expected_name = expected
        with emit.block('if not isinstance(', expr, ', ', expected, ')'):
            error(TypeError, '" should be {}, got " + type({}).__name__',
                expected_name, expr
            )
            emit.line('return')

    def require_key(dict_expr, key, error_flag=None, msg='"is missing"'):
        # If the object to which ``dict_expr`` evaluates does not have the key
        # ``key``, add an AttributeError (not KeyError!).
        # If ``error_flag`` is given, also assign true to that variable.
        # (More precisely: the flag is assigned the result of the membership
        # test, so it is false when the key is present.)
        if error_flag is not None:
            emit.linef('{} = {!r} not in {}', error_flag, key, dict_expr)
        with emit.block('if ' + repr(key) + ' not in ', dict_expr):
            error(AttributeError, msg)

    def require_between(expr, min_val, max_val):
        # Cause errors if expr evaluated to something not in
        # range(min_val, max_val) -- note that this means max_val is invalid.
        # Either value may be None to omit that part of the check.
        # The generated function returns after recording such an error.
        if min_val is not None:
            with emit.block('if ', expr, ' < ', min_val):
                error(ValueError, '"should be >= {}, is " + str({})',
                    min_val, expr
                )
                emit.line('return')
        if max_val is not None:
            with emit.block('if ', expr, ' >= ', max_val):
                error(ValueError, '"should be < {}, is " + str({})',
                    max_val, expr
                )
                emit.line('return')

    def visit_integer(node):
        # Must be an ``int`` within [node.min, node.max).
        require_isinstance('obj', 'int')
        require_between('obj', node.min, node.max)

    def visit_float(node):
        # Integers are acceptable as they can be converted to floats.
        # Strings and other stuff that may be convertable is NOT okay.
        require_isinstance('obj', '(int, float)', 'number')
        require_between('obj', node.min, node.max)

    def visit_string(node):
        # Must be a ``str``; no further constraints are checked.
        require_isinstance('obj', 'str')

    def visit_list(node):
        # Must be a ``list``; every item is validated against ``node.items``
        # with "[<index>]" appended to the access path.
        require_isinstance('obj', 'list')
        with emit.block('for i, item in enumerate(obj)'):
            with push('"["+str(i)+"]"'):
                yield from descend(node.items, 'item')

    def visit_dict(node):
        # Must be a ``dict``; keys are validated against ``node.keys`` and
        # values against ``node.values``.
        require_isinstance('obj', 'dict')
        # NOTE(review): the values received from these two yields are never
        # used; ``descend`` below yields the same nodes again. They may
        # matter to the driving protocol -- confirm before removing.
        check_key = yield node.keys
        check_val = yield node.values
        with emit.block('for key in obj'):
            with push('"<key {!r}>".format(key)'):
                yield from descend(node.keys, 'key')
            with push('"[{!r}]".format(key)'):
                yield from descend(node.values, 'obj[key]')

    def visit_maybe(node):
        # ``None`` is always valid; anything else must match ``node.t``.
        emit.line('if obj is None: return')
        yield from descend(node.t, 'obj')

    def visit_ref(node):
        # A reference is either the ID of an object (a string) or the
        # object itself, given inline as a dict carrying a ``$id``.
        struct = model.ModelMeta.struct_for(node)
        require_isinstance('obj', '(dict, str)', 'structure or ID')
        with emit.block('if isinstance(obj, str)'):
            # Log where we've seen it, for better error messages in case
            # it is not defined anywhere.
            emit.line('seen[obj].add(', TRACE, ')')
            emit.line('return')
        with push_frame('obj'):
            yield from do_struct(struct)

    def visit_value(node):
        # Value types are always given inline; no ``$id`` frame is opened.
        struct = model.ModelMeta.struct_for(node)
        require_isinstance('obj', 'dict', 'structure')
        yield from do_struct(struct)

    def do_struct(node):
        # We could simply abort after the first errors.
        # But this way, we can catch multiple independent errors per object,
        # without recursing down paths that only pile up consequential errors.
        for name, t in node.members:
            with push('".{}"'.format(name)):
                require_key('obj', name, 'error_occured')
                with emit.block('if not error_occured'):
                    yield from descend(t, 'obj[{!r}]'.format(name))
        # Any key that is neither a member nor ``$id`` is reported; ``$type``
        # is additionally tolerated when validating a union alternative.
        member_names = set(name for name, t in node.members)
        permitted = member_names | {'$id'}
        with emit.block('for attr in set(obj.keys()) - ', repr(permitted)):
            emit.line('if is_union and attr == "$type": continue')
            error(TypeError, '"unused attribute " + attr')

    def visit_union(node):
        # Like a reference, a union value is either an ID (string) or an
        # inline dict; the dict must carry a ``$type`` tag selecting one of
        # ``node.alternatives``.
        require_isinstance('obj', '(dict, str)', 'structure or ID')
        with emit.block('if isinstance(obj, str)'):
            emit.line('seen[obj].add(', TRACE, ')')
            emit.line('return')
        tags = {tag for tag, t in node.alternatives}
        require_key('obj', '$type', '__type_missing', '"type missing"')
        emit.line('if __type_missing: return')
        # NOTE(review): ``tags`` is a set, so the order of the tags in the
        # emitted message is not stable between runs.
        with emit.block('if obj["$type"] not in ', repr(tags)):
            error(TypeError, '"unknown type " + repr(obj["$type"]) + '
                '"; expected one of: {}"',
                ', '.join(tags)
            )
        # One branch per alternative; an unknown tag matches none of them.
        for tag, t in node.alternatives:
            with emit.block('if obj["$type"] == ', repr(tag)):
                yield from descend(t, 'obj', is_union=True)

    # Structs have no visitor of their own here; they are handled through
    # ``visit_ref`` / ``visit_value`` via ``do_struct``.
    # NOTE(review): how cg.generate_code treats ``NotImplemented`` is not
    # visible from this file.
    visit_struct = NotImplemented

    return cg.generate_code(
        locals(), mkname=typeident, signature=sig, emit=emit, symtab=symtab
    )


def validate(cls, obj, context, source):
    """
    Run the generated validator for ``cls`` on ``obj``.

    The validator is obtained from ``make_validate(wrap_model(cls))`` and is
    called with the containers of ``context`` (``seen``, ``defined``,
    ``errors`` and, as the stack, ``access``) plus ``source``. Nothing is
    returned; results end up in ``context``.
    """
    checker = make_validate(wrap_model(cls))
    positional = (
        obj,
        context.seen,
        context.defined,
        context.errors,
        context.access,
        source,
    )
    # The top-level object is never validated as a union alternative.
    checker(*positional, is_union=False)


class Validator:
    """
    Convenience wrapper that owns one ``ValidationContext`` and feeds
    objects through the module-level ``validate`` function.
    """

    def __init__(self):
        self.context = ValidationContext()

    def validate(self, cls, obj, source='<string>'):
        """
        Validate ``obj`` against ``cls`` and record the results in
        ``self.context``.

        A new frame is pushed onto ``self.context.access`` first; it holds
        the object's ``$id`` or, if that cannot be read, a placeholder.
        The frame is left on the stack afterwards.

        Returns ``True`` if ``self.context.errors`` is empty after the run,
        which also takes errors from earlier calls into account.
        """
        # ``obj`` itself may be invalid (not even subscriptable), so any
        # failure to read the ID falls back to the placeholder.
        try:
            frame = [obj['$id']]
        except Exception:
            frame = ['<toplevel object>']
        self.context.access.append(frame)
        validate(cls, obj, self.context, source)
        return not self.context.errors