Source

hack2 / pypyembed / embed.py


from ctypes import py_object, addressof, CDLL, c_long, c_char_p, c_int,\
     POINTER, cast, c_char
import inspect, numpy, os, sys

# Locate the embedded PyPy shared library: an explicit path may be given as
# the first command-line argument, otherwise libpypy-c.so next to this file.
dirpath = os.path.abspath(os.path.dirname(__file__))
so_path = (sys.argv[1] if len(sys.argv) > 1
           else os.path.join(dirpath, 'libpypy-c.so'))

pypy = CDLL(so_path)

# ctypes signature: c_int pypy_main_startup(c_int argc, (c_char_p * 3) argv)
pypy_main_startup = pypy.pypy_main_startup
pypy_main_startup.restype = c_int
pypy_main_startup.argtypes = [c_int, c_char_p * 3]

# ctypes signature: None pypy_prepare_function(c_char_p source)
pypy_prepare_function = pypy.pypy_prepare_function
pypy_prepare_function.restype = None
pypy_prepare_function.argtypes = [c_char_p]

# ctypes signature:
#   py_object pypy_call_function(c_char_p name, c_long argc,
#                                POINTER(py_object) argv)
pypy_call_function = pypy.pypy_call_function
pypy_call_function.restype = py_object
pypy_call_function.argtypes = [c_char_p, c_long, POINTER(py_object)]

# Start the embedded interpreter with the argv ``pypy-c -c pass``.
# NOTE(review): the c_int result is discarded, so a failed start-up is not
# detected here -- confirm what a non-zero return means.
pypy_main_startup(3, (c_char_p * 3)("pypy-c", "-c", "pass"))

def addr(obj):
    """Return the address of a newly created ``py_object`` cell holding *obj*.

    NOTE(review): the cell is a local temporary; nothing keeps it alive once
    this function returns, so the returned address may dangle.  If the
    intent is the address of *obj* itself, ``id(obj)`` (CPython) may be what
    callers want -- confirm before relying on this.
    """
    cell = py_object(obj)
    return addressof(cell)

# Send inner.py to the embedded PyPy, prefixed with assignments of the host
# paths (__dir__ and so_path) so that code can refer to them.
# %r emits a properly quoted and escaped string literal, so paths containing
# quotes, backslashes or non-ASCII bytes no longer corrupt the generated
# source.  The file is closed deterministically via ``with``.
with open(os.path.join(dirpath, 'inner.py')) as _inner_file:
    pypy_prepare_function('__dir__ = %r\n' % dirpath +
                          'so_path = %r\n' % so_path +
                          _inner_file.read())
del _inner_file

def wrap(arg):
    """Convert one call argument into a value that is handed to PyPy.

    Non-array values are returned unchanged.  A numpy array is replaced by
    the tuple ``(data_address, dtype_name, shape)``, presumably so that the
    PyPy side can access the same memory without copying.

    Raises:
        ValueError: if *arg* is a numpy array that is not C-contiguous.  The
            descriptor carries no strides, so the receiver can only interpret
            the memory as a dense C-ordered block; a transposed or
            step-sliced array would otherwise be silently misread.
    """
    if not isinstance(arg, numpy.ndarray):
        return arg
    if not arg.flags['C_CONTIGUOUS']:
        # Copying (numpy.ascontiguousarray) is not an option here: the tuple
        # only records the address, so nothing would keep the copy alive.
        raise ValueError("only C-contiguous arrays can be passed to PyPy; "
                         "use numpy.ascontiguousarray() first")
    return (arg.__array_interface__['data'][0], str(arg.dtype), arg.shape)

def product(t):
    """Multiply all items of *t* together.

    An empty iterable yields 1.  Used below to turn an array shape tuple
    into its element count.
    """
    acc = 1
    for factor in t:
        acc *= factor
    return acc

def _unwrap(res):
    """Convert a value returned from PyPy back into a host value (inverse of wrap).

    A tuple ``(data_address, dtype_name, shape)`` becomes a numpy array that
    views the memory at ``data_address`` (no copy is made).  Any other value
    is returned unchanged.
    """
    # NOTE(review): every tuple result is treated as an array descriptor, so
    # an exported function apparently cannot return an ordinary tuple --
    # confirm against the PyPy side (inner.py).
    if not isinstance(res, tuple):
        return res
    address, dtype_name, shape = res[0], res[1], res[2]
    dtype = numpy.dtype(dtype_name)
    # One flat byte buffer covering exactly the array's data; numpy then
    # reinterprets it with the requested dtype and shape.  The memory is
    # presumably owned by PyPy, so the array is only valid while PyPy keeps
    # it alive -- TODO confirm ownership.
    nbytes = product(shape) * dtype.itemsize
    buffer = (c_char * nbytes).from_address(address)
    return numpy.ndarray(shape, dtype, buffer=buffer)


def export_function(func):
    """Compile *func* inside the embedded PyPy and return a host-side proxy.

    The source of *func* is sent to PyPy through pypy_prepare_function, with
    its leading decorator line (normally ``@export_function``) replaced by
    ``@cross_call`` (presumably defined by inner.py).  Calling the returned
    proxy converts each argument with ``wrap``, invokes the PyPy function of
    the same name via pypy_call_function, and converts a result tuple back
    into a numpy array.

    Args:
        func: a plain Python function whose source is available to
            ``inspect.getsource``.

    Returns:
        A callable taking positional arguments only, with *func*'s name and
        docstring.
    """
    import functools
    import textwrap

    source = inspect.getsource(func)
    if source[:1].isspace():
        # Nested functions/methods come back indented; strip the common
        # indentation so the text compiles as top-level code.
        source = textwrap.dedent(source)
    lines = source.splitlines()
    # Drop the host-side decorator line.  When export_function(func) is
    # called directly, line 0 is the ``def`` itself and must be kept.
    if lines and lines[0].lstrip().startswith('@'):
        lines = lines[1:]
    pypy_prepare_function("\n".join(['@cross_call\n'] + lines))
    name = func.__name__

    @functools.wraps(func)
    def f(*args):
        count = len(args)
        # The ctypes array holds references to the wrapped values for the
        # duration of the call.
        c_args = (py_object * count)(*[wrap(arg) for arg in args])
        res = pypy_call_function(name, count,
                                 cast(c_args, POINTER(py_object)))
        return _unwrap(res)
    return f