Source

hack2 / pypyembed / embed.py

Diff from to

File pypyembed/embed.py

 
-from ctypes import py_object, addressof, CDLL, c_long, c_char_p, c_int
-import inspect
+from ctypes import py_object, addressof, CDLL, c_long, c_char_p, c_int,\
+     POINTER, cast, c_char
+import inspect, numpy
 
 pypy = CDLL('./libpypy-c.so')
 pypy_main_startup = pypy.pypy_main_startup
 pypy_prepare_function.argtypes = [c_long, c_char_p]
 pypy_prepare_function.restype = None
 pypy_call_function = pypy.pypy_call_function
-pypy_call_function.argtypes = [c_char_p, c_long, py_object * 2]
+pypy_call_function.argtypes = [c_char_p, c_long, POINTER(py_object)]
 pypy_call_function.restype = py_object
 
 pypy_main_startup(3, (c_char_p * 3)("pypy-c", "-c", "pass"))
 
 pypy_prepare_function(0, open('inner.py').read())
 
+def wrap(arg):
+    if not isinstance(arg, numpy.ndarray):
+        return arg
+    return (arg.__array_interface__['data'][0], str(arg.dtype), arg.shape)
+
+def product(t):
+    s = 1
+    for i in t:
+        s *= i
+    return s
+
 def export_function(func):
     src = ['@cross_call\n'] + inspect.getsource(func).splitlines()[1:]
     pypy_prepare_function(0, "\n".join(src))
     def f(*args):
         lgt = len(args)
-        args = (py_object * lgt)(*args)
-        return pypy_call_function(func.func_name, lgt, args)
+        args = (py_object * lgt)(*[wrap(arg) for arg in args])
+        res = pypy_call_function(func.func_name, lgt, cast(args,
+                                                           POINTER(py_object)))
+        if isinstance(res, tuple):
+            size = res[2]
+            dtype = numpy.dtype(res[1])
+            raw_size = product(size)
+            buffer = (c_char*raw_size*dtype.itemsize).from_address(res[0])
+            return numpy.ndarray(size, dtype, buffer=buffer)
+        return res
     f.func_name = func.func_name
     return f