Commits

Lenard Lindstrom committed 8ca9901

Generate rffi definitions automatically from cffi definitions

Add an ``tourcffi`` module, which contains a ``Parser`` class that does a
depth-first tour of a ``cffi`` ``CType`` instance. The ``rcffi`` module defines
a ``Parser.Handler`` visitor that build an rpython ``lltype`` representation
of the ``CType``.

The ``array_view`` module is modified to replace the explicit rffi C struct
definitions of ARRAYVIEW and ARRAYVIEW_P with thoses taken from ``arrblit``,
a Python level module using ``cffi``.

This initial version of ``rcffi`` does not handle C struct typedefs, so
code has been changed to use ``struct ArrayView`` in place of ``ArrayView_t``.
Also, ``rffi.CStruct`` is replaced with ``lltype.Struct``, to keep C struct
field names compatible with ``cffi``.

Comments (0)

Files changed (6)

-from rpython.rtyper.lltypesystem.rffi import (CStruct, CFixedArray,
-                                              INT, SSIZE_TP, CHAR, UCHARP)
-from rpython.rtyper.lltypesystem.lltype import Ptr
+import arrblit
+import rcffi
 
-ARRAYVIEW = CStruct('ArrayView',
-                    ('typestr', CFixedArray(CHAR, 4)),
-                    ('ndim', INT),
-                    ('shape', SSIZE_TP),
-                    ('strides', SSIZE_TP),
-                    ('data', UCHARP))
+parser = rcffi.Parser(arrblit.ffi)
+encoder = rcffi.RffiKnownFilter(rcffi.Encoder())
 
-ARRAYVIEW_P = Ptr(ARRAYVIEW)
+ARRAYVIEW = parser.visit('struct ArrayView', encoder)
+ARRAYVIEW_P = parser.visit('struct ArrayView *', encoder)
 
 ffi = cffi.FFI()
 ffi.cdef("""
-    typedef struct ArrayView {
+    struct ArrayView {
         char typestr[4];
         int  ndim;
         ssize_t *shape;
         ssize_t *strides;
         unsigned char *data;
-    } ArrayView_t;
+    };
 
-    void blit(ArrayView_t *destination, ArrayView_t *source);
+    void blit(struct ArrayView *destination, struct ArrayView *source);
     void rpython_startup_code();
 """)
 try:
                   ffi.new('ssize_t[]', dst_shape),
                   ffi.new('ssize_t[]', dst_strides),
                   ffi.cast('unsigned char *', dst_data)]
-    dst = ffi.new('ArrayView_t *', dst_values)
+    dst = ffi.new('struct ArrayView *', dst_values)
     src_values = [src_type,
                   ndim,
                   ffi.new('ssize_t[]', src_shape),
                   ffi.new('ssize_t[]', src_strides),
                   ffi.cast('unsigned char *', src_data)]
-    src = ffi.new('ArrayView_t *', src_values)
+    src = ffi.new('struct ArrayView *', src_values)
     _blitter.blit(dst, src)
 
 def get_field(array_interface, name):
 
 #include <stdio.h>
 
-typedef struct ArrayView {
+struct ArrayView {
     char typestr[4];
     int  ndim;
     ssize_t *shape;
     ssize_t *strides;
     unsigned char *data;
-} ArrayView_t;
+};
 
 #endif
 """Export blit_buffer"""
 from blit_interpreter import compile_loop, execute_loop
 
-def add_field_getters(fields, prefix='c_'):
-    for name in fields:
-        exec """\
-def get_%s_field(rec):
-    return rec.%s%s
-""" % (name, prefix, name) in globals()
-
-fields = ['ndim', 'typestr', 'shape', 'strides', 'data']
-
-
 ## RPython
 
 code_cache = {}
 
 !For now, ignore integer sign. Truncate if necessary.
 """
-    ndim = get_ndim_field(destination)
-    dst_type = get_typestr_field(destination)
-    dst_shape = get_shape_field(destination)
-    dst_strides = get_strides_field(destination)
-    dst_data = get_data_field(destination)
-    src_type = get_typestr_field(source)
-    src_shape = get_shape_field(source)
-    src_strides = get_strides_field(source)
-    src_data = get_data_field(source)
+    ndim = destination.ndim
+    dst_type = destination.typestr
+    dst_shape = destination.shape
+    dst_strides = destination.strides
+    dst_data = destination.data
+    src_type = source.typestr
+    src_shape = source.shape
+    src_strides = source.strides
+    src_data = source.data
     __, dst_bytesize, dst_lil_endian = decode_int_type(dst_type)
     __, src_bytesize, src_lil_endian = decode_int_type(src_type)
     if src_lil_endian:
     size = typestr.item2
     return signed == 'i', '0123456789'.find(size), order == '<'
 
-add_field_getters(fields)
 
 # not RPython
 
         order, signed, size = typestr[0:3]
         return signed == 'i', '0123456789'.find(size), order == '<'
 
-    add_field_getters(fields, '')
-
-del fields, add_field_getters, python_only
+del python_only
+from tourcffi import Parser
+from rpython.rtyper.lltypesystem import rffi, lltype
+
+from collections import deque
+
+def int_sized(byte_size):
+    if rffi.sizeof(rffi.SIGNEDCHAR) == byte_size:
+        return rffi.SIGNEDCHAR
+    if rffi.sizeof(rffi.SHORT) == byte_size:
+        return rffi.SHORT
+    if rffi.sizeof(rffi.INT) == byte_size:
+        return rffi.INT
+    if rffi.sizeof(rffi.LONG) == byte_size:
+        return rffi.LONG
+    if rffi.sizeof(rffi.LONGLONG) == byte_size:
+        return rffi.LONGLONG
+    raise RuntimeError("Unsupported int size {}".format(byte_size))
+
+def uint_sized(byte_size):
+    if rffi.sizeof(rffi.UCHAR) == byte_size:
+        return rffi.UCHAR
+    if rffi.sizeof(rffi.USHORT) == byte_size:
+        return rffi.USHORT
+    if rffi.sizeof(rffi.UINT) == byte_size:
+        return rffi.UINT
+    if rffi.sizeof(rffi.ULONG) == byte_size:
+        return rffi.ULONG
+    if rffi.sizeof(rffi.ULONGLONG) == byte_size:
+        return rffi.ULONGLONG
+    raise RuntimeError("Unsupported int size {}".format(byte_size))
+
+def pointer(rtype):
+    if not isinstance(rtype, lltype.ContainerType):
+        rtype = lltype.Array(rtype, hints={'nolength': True})
+    return lltype.Ptr(rtype)
+
+class KnownFilter(Parser.Handler):
+    def __init__(self, encoder, known_types=None):
+        if known_types is None:
+            known_types = {}
+        self.encoder = encoder
+        self.known_types = known_types
+
+    def _weave(attributes, locs, prefix, aspect):
+        for name in attributes:
+            if name.startswith(prefix):
+                aspect(locs, prefix, name)
+
+    def _add_visitor_callbacks(locs, prefix, meth_nm):
+        globs = globals()
+        funcs = {}
+        node_nm = meth_nm[len(prefix):]
+        format_args = dict(enter_nm=meth_nm,
+                           exit_nm="exit_{}".format(node_nm),
+                           arg_nm="{}_obj".format(node_nm))
+        exec("""def {enter_nm}(self, {arg_nm}, parser, result):
+                    try:
+                        result = self.known_types[{arg_nm}.cname]
+                    except KeyError:
+                        return self.encoder.{enter_nm}({arg_nm}, parser, result)
+                    else:
+                        parser.skip(result)
+             """.format(**format_args), globals(), locs)
+
+        exec("""def {exit_nm}(self, {arg_nm}, parser, result):
+                    result = self.encoder.{exit_nm}({arg_nm}, parser, result)
+                    self.known_types[{arg_nm}.cname] = result
+                    return result
+             """.format(**format_args), globs, funcs)
+
+        locs.update(funcs)
+
+    _weave(dir(Parser.Handler), locals(), 'enter_', _add_visitor_callbacks)
+
+    def enter_field(self, field_obj, parser, result):
+        return self.encoder.enter_field(field_obj, parser, result)
+
+    def exit_field(self, field_obj, parser, result):
+        return self.encoder.exit_field(field_obj, parser, result)
+
+    def enter_cffi(self, cffi_obj, parser, result):
+        return self.encoder.enter_cffi(cffi_obj, parser, result)
+
+    def exit_cffi(self, cffi_obj, parser, result):
+        return self.encoder.exit_cffi(cffi_obj, parser, result)
+
+    del _weave, _add_visitor_callbacks
+
+class RffiKnownFilter(KnownFilter):
+    primitives = {
+        'char': rffi.CHAR,
+        'unsigned char': rffi.UCHAR,
+        'short': rffi.SHORT,
+        'unsigned short': rffi.USHORT,
+        'int': rffi.INT,
+        'unsigned int': rffi.UINT,
+        'long': rffi.LONG,
+        'ulong': rffi.ULONG,
+        'long long': rffi.LONGLONG,
+        'unsigned long long': rffi.ULONGLONG,
+        'wchar_t': rffi.WCHAR_T,
+        'size_t': rffi.SIZE_T,
+        'ssize_t': rffi.SSIZE_T,
+        'float': rffi.FLOAT,
+        'double': rffi.DOUBLE,
+        'long double': rffi.LONGDOUBLE,
+        'void *': rffi.VOIDP,
+        'int8_t': int_sized(1),
+        'uint8_t': uint_sized(1),
+        'int16_t': int_sized(2),
+        'uint16_t': uint_sized(2),
+        'int32_t': int_sized(4),
+        'uint32_t': uint_sized(4),
+        'char *': rffi.CCHARP,
+        'unsigned char *': rffi.UCHARP
+    }
+
+    def __init__(self, encoder):
+        super(RffiKnownFilter, self).__init__(encoder, self.primitives)
+
+class Encoder(Parser.Handler):
+    def exit_struct(self, struct_obj, parser, result):
+        STRUCT = 'struct '
+        cname = struct_obj.cname
+        if cname.startswith(STRUCT):
+            name = cname[len(STRUCT):]
+        else:
+            name = cname
+        return lltype.Struct(name, *result)
+    
+    def exit_field(self, field_obj, parser, result):
+        return field_obj[0], result
+
+    def exit_pointer(self, pointer_obj, parser, result):
+        return pointer(result)
+
+    def exit_array(self, array_obj, parser, result):
+        return rffi.CFixedArray(result, array_obj.length)
+
+    def enter_primitive(self, primitive_obj, parser, result):
+        raise ValueError("Unrecognized type {}".format(primitive_obj.cname))
+
+def cffi_to_rffi(obj, parser):
+    encoder = RffiKnownFilter(Encoder())
+    return parser.visit(obj, encoder)
+from collections import deque
+
+class Parser(object):
+    class SkipNode(BaseException):
+        pass
+
+    def __init__(self, ffi):
+        self.ffi = ffi
+
+    def visit(self, cffi_obj, handler):
+        return self.visit_cffi(cffi_obj, handler, None)
+
+    def _visit_cffi(self, cffi_obj, handler, result):
+        if isinstance(cffi_obj, self.ffi.CData):
+            return self.visit_cdata(cffi_obj, handler, result)
+        elif isinstance(cffi_obj, self.ffi.CType):
+            return self.visit_ctype(cffi_obj, handler, result)
+        elif isinstance(cffi_obj, str):
+            return self.visit_ctype(self.ffi.typeof(cffi_obj), handler, result)
+        else:
+            raise ValueError("Unsupported object {}".format(cffi_obj))
+
+    def _visit_cdata(self, cdata_obj, handler, result):
+        ctype = self.ffi.typeof(cdata_obj)
+        return self.decode(ctype, cdata_obj, handler, result)
+
+    def decode(self, ctype_obj, cdata_obj, handler, result):
+        return handler.decode(ctype_obj, cdata_obj, self, result)
+
+    def _visit_ctype(self, ctype_obj, handler, result):
+        kind = ctype_obj.kind
+        if kind == 'primitive':
+            return self.visit_primitive(ctype_obj, handler, result)
+        elif kind == 'pointer':
+            return self.visit_pointer(ctype_obj, handler, result)
+        elif kind == 'struct':
+            return self.visit_struct(ctype_obj, handler, result)
+        elif kind == 'array':
+            return self.visit_array(ctype_obj, handler, result)
+        else:
+            self.visit_other(ctype_obj, handler, result)
+
+    def _visit_primitive(self, ctype, handler, result):
+        return result
+
+    def _visit_pointer(self, pointer_obj, handler, result):
+        return self.visit_ctype(pointer_obj.item, handler, result)
+
+    def _visit_array(self, array_obj, handler, result):
+        return self.visit_ctype(array_obj.item, handler, result)
+
+    def _visit_struct(self, struct_obj, handler, result):
+        results = []
+        for field in struct_obj.fields:
+            result = self.visit_field(field, handler, result)
+            results.append(result)
+        return results
+
+    def _visit_field(self, field_obj, handler, result):
+        return self.visit_ctype(field_obj[1].type, handler, result)
+
+    def _visit_other(self, other_obj, handler, result):
+        raise ValueError("Unhandled CType kind '{}'".format(other_obj.kind))
+
+    def skip(self, result=None):
+        raise self.SkipNode('', result)
+
+    def _weave(locs, prefix, aspect):
+        for name, value in locs.items():
+            if name.startswith(prefix):
+                aspect(locs, prefix, name, value)
+
+    def _add_visitor_aspects(locs, prefix, meth_nm, fn):
+        node_nm = meth_nm[len(prefix):]
+        format_args = dict(meth_nm=meth_nm[1:],
+                           node_nm=node_nm,
+                           enter_nm="enter_{}".format(node_nm),
+                           exit_nm="exit_{}".format(node_nm),
+                           arg_nm="{}_obj".format(node_nm))
+        exec("""def {meth_nm}(self, {arg_nm}, handler, result):
+                    try:
+                        result = handler.{enter_nm}({arg_nm}, self, result)
+                    except self.SkipNode as e:
+                        return e.args[1]
+                    else:
+                        result = self._{meth_nm}({arg_nm}, handler, result)
+                        return handler.{exit_nm}({arg_nm}, self, result)
+             """.format(**format_args), globals(), locs)
+    _weave(locals(), '_visit_', _add_visitor_aspects)
+
+    class Handler(object):
+        def decode(self, ctype_obj, cdata_obj, parser, result):
+            return result
+
+        @classmethod
+        def _add_visitor_callbacks(cls, locs, prefix, meth_nm, fn):
+            globs = globals()
+            node_nm = meth_nm[len(prefix):]
+            funcs = {}
+            format_args = dict(node_nm=node_nm,
+                               enter_nm="enter_{}".format(node_nm),
+                               exit_nm="exit_{}".format(node_nm),
+                               arg_nm="{}_obj".format(node_nm))
+            exec("""def {enter_nm}(self, {arg_nm}, parser, result):
+                        '''{node_nm} visitor stub'''
+                        return result
+                 """.format(**format_args), globs, funcs)
+
+            exec("""def {exit_nm}(self, {arg_nm}, parser, result):
+                        '''{node_nm} visitor stub'''
+                        return result
+                 """.format(**format_args), globs, funcs)
+
+            for name, fn in funcs.items():
+                setattr(cls, name, fn)
+    _weave(locals(), '_visit_', Handler._add_visitor_callbacks)
+
+    del _weave, _add_visitor_aspects, Handler._add_visitor_callbacks
+
+class UniqueHandler(Parser.Handler):
+    def __init__(self, *args, **kwds):
+        self.structs_seen = set()
+
+    def enter_struct(self, struct_obj, parser, result):
+        cname = struct_obj.cname
+        if cname not in self.structs_seen:
+            self.structs_seen.add(cname)
+        else:
+            parser.skip(self.handle_seen_struct(struct_obj, parser, result))
+        return result
+
+    def handle_seen_struct(self, struct_obj, parser, result):
+        return result
+
+class TypedefHandler(UniqueHandler):
+    def __init__(self):
+        UniqueHandler.__init__(self)
+        self.typedefs = deque()
+
+    def enter_cffi(self, cffi_obj, parser, result):
+        return ()
+
+    def exit_cffi(self, cffi_obj, parser, result):
+        super(TypedefHandler, self).exit_cffi(cffi_obj, parser, result)
+        return list(self.typedefs)
+
+    def enter_field(self, field_obj, parser, result):
+        return ()
+
+    def exit_field(self, field_obj, parser, result):
+        return result + (field_obj[0], ';')
+
+    def exit_struct(self, struct_obj, parser, result):
+        cname = struct_obj.cname
+        if cname.startswith('struct '):
+            prefix = 'struct', cname[7:], '{'
+            suffix = '}', ';'
+            field = 'struct', cname[7:]
+        else:
+            prefix = 'typedef', 'struct', '{'
+            suffix = '}', cname, ';'
+            field = cname,
+        self.typedefs.extend(prefix + concat(result) + suffix)
+        return field
+
+    def handler_seen_struct(self, struct_obj, parser, result):
+        return result + (struct_obj.cname,)
+
+    def exit_pointer(self, pointer_obj, parser, result):
+        return result + ('*',)
+
+    def exit_array(self, array_obj, parser, result):
+        return result + ('[', str(array_obj.length), ']')
+
+    def exit_primitive(self, primative_obj, parser, result):
+        return result + (primative_obj.cname,)
+
+def pprint(cffi_obj, parser):
+    from sys import stdout
+
+    tokens = parser.visit(cffi_obj, TypedefHandler())
+    indent = 0
+    line_start = True
+    for token in tokens:
+        if token == '{':
+            indent += 4
+            stdout.write(" {\n")
+            line_start = True
+        elif token == '}':
+            line_start = False
+            indent -= 4
+            stdout.write("{0}}}".format(' ' * indent))
+        elif token == ';':
+            stdout.write(";\n")
+            line_start = True
+        elif line_start:
+            stdout.write("{0}{1}".format(' ' * indent, token))
+            line_start = False
+        else:
+            stdout.write(" {0}".format(token))
+    stdout.write("\n")
+
+def concat(seqs):
+    result = ()
+    for seq in seqs:
+        result += seq
+    return result