Commits

rkruppe committed 7e9da72

Port validation to new codgen

  • Participants
  • Parent commits 8de6021
  • Branches experiments

Comments (0)

Files changed (2)

tutagx/meta/common.py

     that nevertheless leads to somewhat readable output.
     """
     def visit_float(node):
-        return 'float'
+        return 'float_num'
 
     def visit_integer(node):
-        return 'int'
+        return 'integer'
 
     def visit_string(node):
-        return 'str'
+        return 'string'
 
     def visit_list(node):
         return 'list_' + visit(node.items)

tutagx/meta/validate.py

-"""
-Provide various means.
-"""
 from itertools import chain
 from contextlib import contextmanager
 from collections import defaultdict, namedtuple
 from tutagx.meta.collect_objects import collect_instances
-from tutagx.meta.common import FunctionCodeGen
-from tutagx.meta import model
-
+from tutagx.meta.common import typeident, wrap_model
+from tutagx.meta import model, process
+import tutagx.util.codegen as cg
 
 _ValidatorNamespace = namedtuple(
     '_ValidatorNamespace',
     return lines
 
 
-class _Validator(FunctionCodeGen):
-    _TARGET_FUNC_ATTR = '_check'
-    _ARGS = (
-        'obj', 'seen', 'defined', 'errors', 'stack', 'source', 'is_union'
+@process.oneshot
+def _validate():
+    sig = cg.Signature(
+        '$obj', 'seen', 'defined', 'errors', 'stack', 'source', '$is_union'
     )
-    _HINT_PREFIX = 'check_'
-    # __trace generated an "identifier" for the currently processed object
-    # (according to stack).
-    # Said identifier is the object ID for reference types, or the
-    # object ID of the owning object with additional attribute access to
-    # the value object in question.
-    _PRELUDE = '''
-def __trace(src, access):
-    return '{}: {!r} {}'.format(src, access[-1][0], ''.join(access[-1][1:]))
-    '''.strip()
-    # Just a shortcut, as it is needed in several places.
-    __TRACE = '__trace(source, stack)'
+    RESERVED = ('error_occured', 'item', 'key', 'i', 'attr',
+                '__id_undefined', '__type_missing', '__trace')
+    emit = cg.CodeEmitter()
+    symtab = cg.SymbolTable(RESERVED, sig)
 
-    def __init__(self, *args):
-        # Stack of stacks of attribute names, describing how one got there
-        # but leaving reference types aside
-        self._classes = {}
-        super().__init__(*args, reserved=(
-            'error_occured', 'item', 'key', 'i', 'attr',
-            '__id_undefined', '__type_missing', '__trace'
-        ))
+    classes = {}
+    with emit.block('def __trace(src, access)'):
+        emit.line('return "{}: {!r} {}".format(src, '
+                    'access[-1][0], "".join(access[-1][1:]))')
+    TRACE = '__trace(source, stack)' # shortcut
 
     @contextmanager
-    def _access(self, access):
-        self.line('stack[-1].append({})', access)
+    def push(access):
+        emit.line('stack[-1].append(', access, ')')
         yield
-        self.line('stack[-1].pop()')
+        emit.line('stack[-1].pop()')
 
     @contextmanager
-    def _access_frame(self, obj_expr):
-        self.line('__id_undefined = False')
-        self._require_key(
-            obj_expr, '$id',
-            error_flag='__id_undefined',
+    def push_frame(obj_expr):
+        require_key(obj_expr, '$id', error_flag='__id_undefined',
             msg='"object ID missing"'
         )
-        self.line('if __id_undefined: return')
-        self.line('stack.append([{}["$id"]])', obj_expr)
-        self.line('defined.add({}["$id"])', obj_expr)
+        emit.line('if __id_undefined: return')
+        emit.line('stack.append([', obj_expr, '["$id"]])')
+        emit.line('defined.add(', obj_expr, '["$id"])')
         yield
-        self.line('stack.pop()')
+        emit.line('stack.pop()')
 
-    def _error(self, exc_type, message):
-        # Emit an error with the given ``exc_type`` and ``message``.
-        self.line(
-            'errors.append(({}, {}, {}))',
-            exc_type.__name__, self.__TRACE, message
+    def error(exc_type, message, *args):
+        emit.linef('errors.append(({}, {}, {}))',
+                    exc_type.__name__, TRACE, message.format(*args)
         )
 
-    def _descend(self, node, expr, is_union=False):
-        # Call the function for ``node``, with ``expr`` as argument.
-        self.line(
-            '{}({}, seen, defined, errors, stack, source, {})',
-            self.genfunc(node), expr, is_union
-        )
+    def descend(node, expr, is_union=False):
+        func = yield node
+        emit.line(sig.call(func, expr, str(is_union)))
 
-    def _require_isinstance(self, expr, expected, expected_name=None):
+    def require_isinstance(expr, expected, expected_name=None):
         # Generate a statement that adds a TypeError if ``expr`` does not
         # evaluate to an object of type ``expected``. (Both are strings!)
         if expected_name is None:
             expected_name = expected
-        with self.block('if not isinstance({}, {}):', expr, expected):
-            self._error(
-                TypeError,
-                '" should be {}, got " + type({}).__name__'
-                .format(expected_name, expr)
+        with emit.block('if not isinstance(', expr, ', ', expected, ')'):
+            error(TypeError, '" should be {}, got " + type({}).__name__',
+                expected_name, expr
             )
-            self.line('return')
+            emit.line('return')
 
-    def _require_key(self, dict_expr, key,
-                     error_flag=None, msg='"is missing"'):
+    def require_key(dict_expr, key, error_flag=None, msg='"is missing"'):
         # If the object to which ``dict_expr`` evaluates does not have the key
         # ``key``, add an AttributeError (not KeyError!).
         # If ``error_flag`` is given, also assign true to that variable.
-        with self.block('if {!r} not in {}:', key, dict_expr):
-            self._error(AttributeError, msg)
-            if error_flag is not None:
-                self.line('{} = True', error_flag)
+        if error_flag is not None:
+            emit.linef('{} = {!r} not in {}', error_flag, key, dict_expr)
+        with emit.block('if ' + repr(key) + ' not in ', dict_expr):
+            error(AttributeError, msg)
 
-    def _require_between(self, expr, min_val, max_val):
+    def require_between(expr, min_val, max_val):
         # Cause errors if expr evaluated to something not in
         # range(min_val, max_val) -- note that this means max_val is invalid.
         # Either value may be None to omit that part of the check.
         if min_val is not None:
-            with self.block('if {} < {}:', expr, min_val):
-                self._error(ValueError,
-                    '"should be >= {}, is " + str({})'.format(min_val, expr)
+            with emit.block('if ', expr, ' < ', min_val):
+                error(ValueError, '"should be >= {}, is " + str({})',
+                    min_val, expr
                 )
-                self.line('return')
+                emit.line('return')
         if max_val is not None:
-            with self.block('if obj >= {}:', max_val):
-                self._error(ValueError,
-                    '"should be < {}, is " + str({})'.format(max_val, expr)
+            with emit.block('if ', expr, ' >= ', max_val):
+                error(ValueError, '"should be < {}, is " + str({})',
+                    max_val, expr
                 )
-                self.line('return')
+                emit.line('return')
 
-    def visit_integer(self, node):
-        self._require_isinstance('obj', 'int')
-        self._require_between('obj', node.min, node.max)
+    def visit_integer(node):
+        require_isinstance('obj', 'int')
+        require_between('obj', node.min, node.max)
 
-    def visit_float(self, node):
+    def visit_float(node):
         # Integers are acceptable as they can be converted to floats.
         # Strings and other stuff that may be convertable is NOT okay.
-        self._require_isinstance('obj', '(int, float)', 'number')
-        self._require_between('obj', node.min, node.max)
+        require_isinstance('obj', '(int, float)', 'number')
+        require_between('obj', node.min, node.max)
 
-    def visit_string(self, node):
-        self._require_isinstance('obj', 'str')
+    def visit_string(node):
+        require_isinstance('obj', 'str')
 
-    def visit_list(self, node):
-        self._require_isinstance('obj', 'list')
-        with self.block('for i, item in enumerate(obj):'):
-            with self._access('"[{}]".format(i)'):
-                self._descend(node.items, 'item')
+    def visit_list(node):
+        require_isinstance('obj', 'list')
+        with emit.block('for i, item in enumerate(obj)'):
+            with push('"["+str(i)+"]"'):
+                yield from descend(node.items, 'item')
 
-    def visit_dict(self, node):
-        self._require_isinstance('obj', 'dict')
-        check_key = self.genfunc(node.keys)
-        check_val = self.genfunc(node.values)
-        with self.block('for key in obj:'):
-            with self._access('"<key {!r}>".format(key)'):
-                self._descend(node.keys, 'key')
-            with self._access('"[{!r}]".format(key)'):
-                self._descend(node.values, 'obj[key]')
+    def visit_dict(node):
+        require_isinstance('obj', 'dict')
+        check_key = yield node.keys
+        check_val = yield node.values
+        with emit.block('for key in obj'):
+            with push('"<key {!r}>".format(key)'):
+                yield from descend(node.keys, 'key')
+            with push('"[{!r}]".format(key)'):
+                yield from descend(node.values, 'obj[key]')
 
-    def visit_maybe(self, node):
-        self.line('if obj is None: return')
-        self._descend(node.t, 'obj')
+    def visit_maybe(node):
+        emit.line('if obj is None: return')
+        yield from descend(node.t, 'obj')
 
-    def visit_ref(self, node):
+    def visit_ref(node):
         struct = model.ModelMeta.struct_for(node)
-        self._require_isinstance('obj', '(dict, str)', 'structure or ID')
-        with self.block('if isinstance(obj, str):'):
+        require_isinstance('obj', '(dict, str)', 'structure or ID')
+        with emit.block('if isinstance(obj, str)'):
             # Log where we've seen it, for better error messages in case
             # it is not defined anywhere.
-            self.line('seen[obj].add({})', self.__TRACE)
-            self.line('return')
-        with self._access_frame('obj'):
-            self._struct(struct)
+            emit.line('seen[obj].add(', TRACE, ')')
+            emit.line('return')
+        with push_frame('obj'):
+            yield from do_struct(struct)
 
-    def visit_value(self, node):
+    def visit_value(node):
         struct = model.ModelMeta.struct_for(node)
-        self._require_isinstance('obj', 'dict', 'structure')
-        self._struct(struct)
+        require_isinstance('obj', 'dict', 'structure')
+        yield from do_struct(struct)
 
-    def _struct(self, node):
+    def do_struct(node):
         # We could simply abort after the first errors.
         # But this way, we can catch multiple independent errors per object,
         # without recursing down paths that only pile up consequential errors.
-        self.line('error_occured = False')
         for name, t in node.members:
-            with self._access('".{}"'.format(name)):
-                self._require_key('obj', name, 'error_occured')
-                with self.block('if not error_occured:'):
-                    self._descend(t, 'obj[{!r}]'.format(name))
+            with push('".{}"'.format(name)):
+                require_key('obj', name, 'error_occured')
+                with emit.block('if not error_occured'):
+                    yield from descend(t, 'obj[{!r}]'.format(name))
         member_names = set(name for name, t in node.members)
         permitted = member_names | {'$id'}
-        with self.block('for attr in set(obj.keys()) - {!r}:', permitted):
-            self.line('if is_union and attr == "$type": continue')
-            self._error(TypeError, '"unused attribute " + attr')
+        with emit.block('for attr in set(obj.keys()) - ', repr(permitted)):
+            emit.line('if is_union and attr == "$type": continue')
+            error(TypeError, '"unused attribute " + attr')
 
-    def visit_union(self, node):
-        self._require_isinstance('obj', '(dict, str)', 'structure or ID')
-        with self.block('if isinstance(obj, str):'):
-            self.line('seen[obj].add({})', self.__TRACE)
-            self.line('return')
+    def visit_union(node):
+        require_isinstance('obj', '(dict, str)', 'structure or ID')
+        with emit.block('if isinstance(obj, str)'):
+            emit.line('seen[obj].add(', TRACE, ')')
+            emit.line('return')
         tags = {tag for tag, t in node.alternatives}
-        self.line('__type_missing = False')
-        self._require_key('obj', '$type', '__type_missing', '"type missing"')
-        self.line('if __type_missing: return')
-        with self.block('if obj["$type"] not in {!r}:', tags):
-            self._error(TypeError,
-                '"unknown type " + repr(obj["$type"]) + '
-                '"; expected one of: {}"'.format(', '.join(tags))
+        require_key('obj', '$type', '__type_missing', '"type missing"')
+        emit.line('if __type_missing: return')
+        with emit.block('if obj["$type"] not in ', repr(tags)):
+            error(TypeError, '"unknown type " + repr(obj["$type"]) + '
+                '"; expected one of: {}"',
+                ', '.join(tags)
             )
         for tag, t in node.alternatives:
-            with self.block('if obj["$type"] == {!r}:', tag):
-                self._descend(t, 'obj', is_union=True)
+            with emit.block('if obj["$type"] == ', repr(tag)):
+                yield from descend(t, 'obj', is_union=True)
 
-    def make_namespace(self):
-        return {}
+    visit_struct = NotImplemented
+
+    return cg.generate_code(
+        locals(), mkname=typeident, signature=sig, emit=emit, symtab=symtab
+    )
 
 
 class Validator:
     def __init__(self, cls):
-        self._check = _Validator(cls)._check
+        self._node = wrap_model(cls)
 
     def validate(self, obj, ns=None, source='<string>'):
         if ns is None:
             ns.access.append(['<toplevel object>'])
         else:
             ns.access.append([oid])
-        self._check(
+        _validate(self._node)(
             obj, ns.seen, ns.defined, ns.errors, ns.access, source,
             is_union=False
         )