Commits

Armin Rigo committed c3cfc0d

Test and painful fix.

  • Participants
  • Parent commits e8ea8ae

Comments (0)

Files changed (2)

-import sys, os, md5, imp, shutil
+import sys, os, hashlib, imp, shutil
 from . import model, ffiplatform
 from . import __version__
 
         self.preamble = preamble
         self.kwds = kwds
         #
-        m = md5.md5('\x00'.join([sys.version[:3], __version__, preamble] +
-                                ffi._cdefsources))
+        m = hashlib.md5('\x00'.join([sys.version[:3], __version__, preamble] +
+                                    ffi._cdefsources))
         modulename = '_cffi_%s' % m.hexdigest()
         suffix = self._get_so_suffix()
         self.modulefilename = os.path.join('__pycache__', modulename + suffix)
-        self.sourcefilename = os.path.join('__pycache__', m.hexdigest() + '.c')
+        self.sourcefilename = os.path.join('__pycache__', modulename + '.c')
         self._status = 'init'
 
     def write_source(self, file=None):
         self._generate("collecttype")
 
     def _do_collect_type(self, tp):
-        if (isinstance(tp, (model.PointerType,
-                            model.StructOrUnion,
-                            model.ArrayType,
-                            model.FunctionPtrType)) and
-                (tp not in self._typesdict)):
+        if (not isinstance(tp, model.PrimitiveType) and
+                tp not in self._typesdict):
             num = len(self._typesdict)
             self._typesdict[tp] = num
             if isinstance(tp, model.StructOrUnion):
 
     def _write_source_to_f(self):
         self._collect_types()
+        #
         # The new module will have a _cffi_setup() function that receives
         # objects from the ffi world, and that calls some setup code in
         # the module.  This setup code is split in several independent
         # functions, e.g. one per constant.  The functions are "chained"
-        # by ending in a tail call to each other.  The following
-        # 'chained_list_constants' attribute contains the head of this
-        # chained list, as a string that gives the call to do, if any.
-        self._chained_list_constants = '0'
+        # by ending in a tail call to each other.
+        #
+        # This is further split in two chained lists, depending on if we
+        # can do it at import-time or if we must wait for _cffi_setup() to
+        # provide us with the <ctype> objects.  This is needed because we
+        # need the values of the enum constants in order to build the
+        # <ctype 'enum'> that we may have to pass to _cffi_setup().
+        #
+        # The following two 'chained_list_constants' items contains
+        # the head of these two chained lists, as a string that gives the
+        # call to do, if any.
+        self._chained_list_constants = ['0', '0']
         #
         prnt = self._prnt
         # first paste some standard set of lines that are mostly '#define'
         prnt('PyMODINIT_FUNC')
         prnt('init%s(void)' % modname)
         prnt('{')
-        prnt('  Py_InitModule("%s", _cffi_methods);' % modname)
+        prnt('  PyObject *lib;')
+        prnt('  lib = Py_InitModule("%s", _cffi_methods);' % modname)
+        prnt('  if (lib == NULL || %s < 0)' % (
+            self._chained_list_constants[False],))
+        prnt('    return;')
         prnt('  _cffi_init();')
         prnt('}')
 
                 extraarg = ', _cffi_type(%d)' % self._gettypenum(tp)
             errvalue = 'NULL'
         #
-        elif isinstance(tp, model.StructOrUnion):
+        elif isinstance(tp, (model.StructOrUnion, model.EnumType)):
             # a struct (not a struct pointer) as a function argument
-            self._prnt('  if (_cffi_to_c((char*)&%s, _cffi_type(%d), %s) < 0)'
+            self._prnt('  if (_cffi_to_c((char *)&%s, _cffi_type(%d), %s) < 0)'
                       % (tovar, self._gettypenum(tp), fromvar))
             self._prnt('    %s;' % errcode)
             return
             extraarg = ', _cffi_type(%d)' % self._gettypenum(tp)
             errvalue = 'NULL'
         #
-        elif isinstance(tp, model.EnumType):
-            converter = '_cffi_to_c_int'
-            errvalue = '-1'
-        #
         else:
             raise NotImplementedError(tp)
         #
         elif isinstance(tp, model.StructType):
             return '_cffi_from_c_struct((char *)&%s, _cffi_type(%d))' % (
                 var, self._gettypenum(tp))
+        elif isinstance(tp, model.EnumType):
+            return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % (
+                var, self._gettypenum(tp))
         else:
             raise NotImplementedError(tp)
 
     # constants, likely declared with '#define'
 
     def _generate_cpy_const(self, is_int, name, tp=None, category='const',
-                            vartp=None):
+                            vartp=None, delayed=True):
         prnt = self._prnt
         funcname = '_cffi_%s_%s' % (category, name)
         prnt('static int %s(PyObject *lib)' % funcname)
                 realexpr = name
             prnt('  i = (%s);' % (realexpr,))
             prnt('  o = %s;' % (self._convert_expr_from_c(tp, 'i'),))
+            assert delayed
         else:
             prnt('  if (LONG_MIN <= (%s) && (%s) <= LONG_MAX)' % (name, name))
             prnt('    o = PyInt_FromLong((long)(%s));' % (name,))
         prnt('  Py_DECREF(o);')
         prnt('  if (res < 0)')
         prnt('    return -1;')
-        prnt('  return %s;' % self._chained_list_constants)
-        self._chained_list_constants = funcname + '(lib)'
+        prnt('  return %s;' % self._chained_list_constants[delayed])
+        self._chained_list_constants[delayed] = funcname + '(lib)'
         prnt('}')
         prnt()
 
     def _generate_cpy_enum_decl(self, tp, name):
         if tp.partial:
             for enumerator in tp.enumerators:
-                self._generate_cpy_const(True, enumerator)
+                self._generate_cpy_const(True, enumerator, delayed=False)
             return
         #
         funcname = '_cffi_enum_%s' % name
                 name, enumerator, enumerator, enumvalue))
             prnt('    return -1;')
             prnt('  }')
-        prnt('  return %s;' % self._chained_list_constants)
-        self._chained_list_constants = funcname + '(lib)'
+        prnt('  return %s;' % self._chained_list_constants[True])
+        self._chained_list_constants[True] = funcname + '(lib)'
         prnt('}')
         prnt()
 
     _generate_cpy_enum_method = _generate_nothing
     _loading_cpy_enum = _loaded_noop
 
-    def _loaded_cpy_enum(self, tp, name, module, library):
+    def _loading_cpy_enum(self, tp, name, module):
         if tp.partial:
-            enumvalues = [getattr(library, enumerator)
+            enumvalues = [getattr(module, enumerator)
                           for enumerator in tp.enumerators]
             tp.enumvalues = tuple(enumvalues)
             tp.partial = False
-        else:
-            for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
-                setattr(library, enumerator, enumvalue)
+
+    def _loaded_cpy_enum(self, tp, name, module, library):
+        for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
+            setattr(library, enumerator, enumvalue)
 
     # ----------
     # macros: for now only for integers
         prnt = self._prnt
         prnt('static PyObject *_cffi_setup_custom(PyObject *lib)')
         prnt('{')
-        prnt('  if (%s < 0)' % self._chained_list_constants)
+        prnt('  if (%s < 0)' % self._chained_list_constants[True])
         prnt('    return NULL;')
         # produce the size of the opaque structures that need it.
         # So far, limited to the structures used as function arguments

testing/test_verify.py

         int foo_func(enum foo_e e) { return e; }
     """)
     assert lib.foo_func(lib.BB) == 2
+    assert lib.foo_func("BB") == 2
+
+def test_enum_as_function_result():
+    ffi = FFI()
+    ffi.cdef("""
+        enum foo_e { AA, BB, ... };
+        enum foo_e foo_func(int x);
+    """)
+    lib = ffi.verify("""
+        enum foo_e { AA, CC, BB };
+        enum foo_e foo_func(int x) { return x; }
+    """)
+    assert lib.foo_func(lib.BB) == "BB"
 
 def test_opaque_integer_as_function_result():
     ffi = FFI()