Commits

Armin Rigo committed df807c9

Fix issue #45: accept unicode strings as the name of the enum constants,
as long as they can be converted to plain strings.

  • Participants
  • Parent commits beba075

Comments (0)

Files changed (2)

c/_cffi_backend.c

 # define STR_OR_BYTES "bytes"
 # define PyText_Type PyUnicode_Type
 # define PyText_Check PyUnicode_Check
+# define PyTextAny_Check PyUnicode_Check
 # define PyText_FromFormat PyUnicode_FromFormat
 # define PyText_AsUTF8 _PyUnicode_AsString   /* PyUnicode_AsUTF8 in Py3.3 */
 # define PyText_AS_UTF8 _PyUnicode_AsString
 # define STR_OR_BYTES "str"
 # define PyText_Type PyString_Type
 # define PyText_Check PyString_Check
+# define PyTextAny_Check(op) (PyString_Check(op) || PyUnicode_Check(op))
 # define PyText_FromFormat PyString_FromFormat
 # define PyText_AsUTF8 PyString_AsString
 # define PyText_AS_UTF8 PyString_AS_STRING
 static PyObject *convert_enum_string_to_int(CTypeDescrObject *ct, PyObject *ob)
 {
     PyObject *d_value;
-    char *p = PyText_AS_UTF8(ob);
+    char *p = PyText_AsUTF8(ob);
+    if (p == NULL)
+        return NULL;
 
     if (p[0] == '#') {
         char *number = p + 1;       /* strip initial '#' */
             else {
                 PyObject *ob;
                 PyErr_Clear();
-                if (!PyText_Check(init)) {
+                if (!PyTextAny_Check(init)) {
                     expected = "str or int";
                     goto cannot_convert;
                 }
                                  (CT_POINTER|CT_FUNCTIONPTR|CT_ARRAY)) {
         value = (Py_intptr_t)((CDataObject *)ob)->c_data;
     }
-    else if (PyText_Check(ob)) {
-        if (ct->ct_flags & CT_IS_ENUM) {
-            ob = convert_enum_string_to_int(ct, ob);
-            if (ob == NULL)
-                return NULL;
-            cd = cast_to_integer_or_char(ct, ob);
-            Py_DECREF(ob);
-            return cd;
+    else if ((ct->ct_flags & CT_IS_ENUM) && PyTextAny_Check(ob)) {
+        ob = convert_enum_string_to_int(ct, ob);
+        if (ob == NULL)
+            return NULL;
+        cd = cast_to_integer_or_char(ct, ob);
+        Py_DECREF(ob);
+        return cd;
+    }
+#if PY_MAJOR_VERSION < 3
+    else if (PyString_Check(ob)) {
+        if (PyString_GET_SIZE(ob) != 1) {
+            PyErr_Format(PyExc_TypeError,
+                         "cannot cast string of length %zd to ctype '%s'",
+                         PyString_GET_SIZE(ob), ct->ct_name);
+            return NULL;
         }
-        else {
-#if PY_MAJOR_VERSION < 3
-            if (PyString_GET_SIZE(ob) != 1) {
-                PyErr_Format(PyExc_TypeError,
-                      "cannot cast string of length %zd to ctype '%s'",
-                             PyString_GET_SIZE(ob), ct->ct_name);
-                return NULL;
-            }
-            value = (unsigned char)PyString_AS_STRING(ob)[0];
-#else
-            wchar_t ordinal;
-            if (_my_PyUnicode_AsSingleWideChar(ob, &ordinal) < 0) {
-                PyErr_Format(PyExc_TypeError,
-                             "cannot cast string of length %zd to ctype '%s'",
-                             PyUnicode_GET_SIZE(ob), ct->ct_name);
-                return NULL;
-            }
-            value = (long)ordinal;
+        value = (unsigned char)PyString_AS_STRING(ob)[0];
+    }
 #endif
-        }
-    }
 #ifdef HAVE_WCHAR_H
     else if (PyUnicode_Check(ob)) {
         wchar_t ordinal;
         if (_my_PyUnicode_AsSingleWideChar(ob, &ordinal) < 0) {
             PyErr_Format(PyExc_TypeError,
-                         "cannot cast unicode of length %zd to ctype '%s'",
+                      "cannot cast unicode string of length %zd to ctype '%s'",
                          PyUnicode_GET_SIZE(ob), ct->ct_name);
             return NULL;
         }
 {
     char *ename;
     PyObject *enumerators, *enumvalues;
-    PyObject *dict1 = NULL, *dict2 = NULL, *combined = NULL;
+    PyObject *dict1 = NULL, *dict2 = NULL, *combined = NULL, *tmpkey = NULL;
     ffi_type *ffitype;
     int name_size;
     CTypeDescrObject *td;
     if (dict1 == NULL)
         goto error;
     for (i=n; --i >= 0; ) {
-        PyObject *key = PyTuple_GET_ITEM(enumerators, i);
+        long lvalue;
         PyObject *value = PyTuple_GET_ITEM(enumvalues, i);
-        long lvalue;
-        if (!PyText_Check(key)) {
-            PyErr_SetString(PyExc_TypeError,
-                            "enumerators must be a list of strings");
-            goto error;
+        tmpkey = PyTuple_GET_ITEM(enumerators, i);
+        Py_INCREF(tmpkey);
+        if (!PyText_Check(tmpkey)) {
+#if PY_MAJOR_VERSION < 3
+            if (PyUnicode_Check(tmpkey)) {
+                char *text = PyText_AsUTF8(tmpkey);
+                if (text == NULL)
+                    goto error;
+                Py_DECREF(tmpkey);
+                tmpkey = PyString_FromString(text);
+                if (tmpkey == NULL)
+                    goto error;
+            }
+            else
+#endif
+            {
+                PyErr_SetString(PyExc_TypeError,
+                                "enumerators must be a list of strings");
+                goto error;
+            }
         }
         lvalue = PyLong_AsLong(value);
         if ((lvalue == -1 && PyErr_Occurred()) || lvalue != (int)lvalue) {
             PyErr_Format(PyExc_OverflowError,
                          "enum '%s' declaration for '%s' does not fit an int",
-                         ename, PyText_AS_UTF8(key));
+                         ename, PyText_AS_UTF8(tmpkey));
             goto error;
         }
-        if (PyDict_SetItem(dict1, key, value) < 0)
+        if (PyDict_SetItem(dict1, tmpkey, value) < 0)
             goto error;
+        Py_DECREF(tmpkey);
+        tmpkey = NULL;
     }
 
     dict2 = PyDict_New();
     return (PyObject *)td;
 
  error:
+    Py_XDECREF(tmpkey);
     Py_XDECREF(combined);
     Py_XDECREF(dict2);
     Py_XDECREF(dict1);
     assert p.a1 == "c"
     e = py.test.raises(TypeError, newp, BStructPtr, [None])
     assert "must be a str or int, not NoneType" in str(e.value)
+    if sys.version_info < (3,):
+        p.a1 = unicode("def")
+        assert p.a1 == "def" and type(p.a1) is str
+        py.test.raises(UnicodeEncodeError, "p.a1 = unichr(1234)")
+        BEnum2 = new_enum_type(unicode("foo"), (unicode('abc'),), (5,))
+        assert string(cast(BEnum2, unicode('abc'))) == 'abc'
 
 def test_enum_overflow():
     for ovf in (2**63, -2**63-1, 2**31, -2**31-1):