Commits

Nick Coghlan committed e3a7491

Return NotImplemented in more cases, test some additional builtin sequence types

  • Participants
  • Parent commits 84930cc
  • Branches respect_LHS_precedence

Comments (0)

Files changed (9)

Include/abstract.h

     This is typically a new iterator but if the argument
     is an iterator, this returns itself. */
 
+#define PyObject_IsIterable(obj) \
+    ((obj)->ob_type->tp_iter != NULL || PySequence_Check(obj))
+
 #define PyIter_Check(obj) \
     ((obj)->ob_type->tp_iternext != NULL && \
      (obj)->ob_type->tp_iternext != &_PyObject_NextNotImplemented)

Lib/collections/__init__.py

             return self.__class__(self.data + other.data)
         elif isinstance(other, type(self.data)):
             return self.__class__(self.data + other)
-        return self.__class__(self.data + list(other))
+        try:
+            other_data = list(other)
+        except TypeError:
+            return NotImplemented
+        return self.__class__(self.data + other_data)
     def __radd__(self, other):
         if isinstance(other, UserList):
             return self.__class__(other.data + self.data)
         elif isinstance(other, type(self.data)):
             return self.__class__(other + self.data)
-        return self.__class__(list(other) + self.data)
+        try:
+            other_data = list(other)
+        except TypeError:
+            return NotImplemented
+        return self.__class__(other_data + self.data)
     def __iadd__(self, other):
         if isinstance(other, UserList):
             self.data += other.data
         elif isinstance(other, type(self.data)):
             self.data += other
         else:
-            self.data += list(other)
+            try:
+                other_data = list(other)
+            except TypeError:
+                return NotImplemented
+            self.data += other_data
         return self
     def __mul__(self, n):
         return self.__class__(self.data*n)

Lib/test/list_tests.py

         u2 = u
         u += [2, 3]
         self.assertIs(u, u2)
+        
+        self.assertEqual(u.__iadd__(None), NotImplemented)
+        with self.assertRaises(TypeError):
+            u += None
 
         u = self.type2test("spam")
         u += "eggs"
         self.assertEqual(u, self.type2test("spameggs"))
 
-        self.assertRaises(TypeError, u.__iadd__, None)
-
     def test_imul(self):
         u = self.type2test([0, 1])
         u *= 3

Lib/test/test_binop.py

         # when A only implements sq_concat (but not nb_add)
         testcase = self
         class RHS:
-            def __init__(self):
-                self.allow_radd = True
-            def __iter__(self):
-                yield "Excellent!"
+            def __init__(self, value):
+                self.value = value
+                self.radd_called = False
             def __radd__(self, other):
-                if not self.allow_radd:
-                    testcase.fail("RHS __radd__ called!")
-                return other + type(other)(self)
-        lhs = []
-        rhs = RHS()
-        self.assertEqual(lhs.__add__(rhs), NotImplemented)
-        self.assertEqual(lhs + rhs, ["Excellent!"])
-        with self.assertRaises(TypeError):
-            lhs + 1
-        rhs.allow_radd = False
-        orig_lhs = lhs
-        lhs += rhs
-        self.assertIs(lhs, orig_lhs)
-        self.assertEqual(lhs, ["Excellent!"])
-        with self.assertRaises(TypeError):
-            lhs += 1
+                self.radd_called = True
+                return other + self.value
+        cases = [
+            ([], ["Excellent!"]),
+            (bytearray(), bytearray(b"Excellent!")),
+            ((), ("Excellent!",)),
+            ("", "Excellent!"),
+            (b"", b"Excellent!"),
+        ]
+        for lhs, expected in cases:
+            rhs = RHS(expected)
+            # Check A + B
+            self.assertEqual(lhs.__add__(rhs), NotImplemented)
+            self.assertEqual(lhs + rhs, expected)
+            self.assertTrue(rhs.radd_called)
+            with self.assertRaises(TypeError):
+                lhs + 1
+            # Check A += B
+            if hasattr(lhs, "__iadd__"):
+                self.assertEqual(lhs.__iadd__(rhs), NotImplemented)
+            orig_lhs = lhs
+            rhs.radd_called = False
+            lhs += rhs
+            self.assertEqual(lhs, expected)
+            self.assertTrue(rhs.radd_called)
+            self.assertIsNot(lhs, orig_lhs)
+            with self.assertRaises(TypeError):
+                lhs += 1
 
     def test_issue11477_sequence_concatenation_subclass(self):
         # Check overloading for A + B and A += B

Lib/test/test_userlist.py

         super().test_iadd()
         u = [0, 1]
         u += UserList([0, 1])
+        self.assertEqual(type(u), list)
         self.assertEqual(u, [0, 1, 0, 1])
 
+
     def test_mixedcmp(self):
         u = self.type2test([0, 1])
         self.assertEqual(u, [0, 1])

Objects/bytearrayobject.c

     return 0;
 }
 
-PyObject *
-PyByteArray_Concat(PyObject *a, PyObject *b)
+static PyObject *
+bytearray_concat(PyObject *a, PyObject *b)
 {
     Py_ssize_t size;
     Py_buffer va, vb;
         PyBuffer_Release(&va);
     if (vb.len != -1)
         PyBuffer_Release(&vb);
-    return (PyObject *)result;
+    return result;
+}
+
+PyObject *
+PyByteArray_Concat(PyObject *a, PyObject *b)
+{
+    PyObject *result = bytearray_concat(a, b);
+    if (result == Py_NotImplemented) {
+        Py_DECREF(result);
+        PyErr_Format(PyExc_TypeError, "can't concat %.100s to %.100s",
+                     Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name);
+        return NULL;
+    }
+
+    return result;
 }
 
 /* Functions stuffed into the type object */
     Py_buffer vo;
 
     if (_getbuffer(other, &vo) < 0) {
-        PyErr_Format(PyExc_TypeError, "can't concat %.100s to %.100s",
-                     Py_TYPE(other)->tp_name, Py_TYPE(self)->tp_name);
-        return NULL;
+        Py_RETURN_NOTIMPLEMENTED;
     }
 
     mysize = Py_SIZE(self);
         Py_DECREF(encoded);
         if (new == NULL)
             return -1;
+        if (new == Py_NotImplemented) {
+            Py_DECREF(new);
+            return -1;
+        }
         Py_DECREF(new);
         return 0;
     }
 
 static PySequenceMethods bytearray_as_sequence = {
     (lenfunc)bytearray_length,              /* sq_length */
-    (binaryfunc)PyByteArray_Concat,         /* sq_concat */
+    (binaryfunc)bytearray_concat,           /* sq_concat */
     (ssizeargfunc)bytearray_repeat,         /* sq_repeat */
     (ssizeargfunc)bytearray_getitem,        /* sq_item */
     0,                                      /* sq_slice */

Objects/listobject.c

     PyObject **src, **dest;
     PyListObject *np;
     if (!PyList_Check(bb)) {
-        Py_INCREF(Py_NotImplemented);
-        return Py_NotImplemented;
+        Py_RETURN_NOTIMPLEMENTED;
     }
 #define b ((PyListObject *)bb)
     size = Py_SIZE(a) + Py_SIZE(b);
 list_inplace_concat(PyListObject *self, PyObject *other)
 {
     PyObject *result;
+    
+    if (!PyObject_IsIterable(other))
+        Py_RETURN_NOTIMPLEMENTED;
 
     result = listextend(self, other);
     if (result == NULL)

Objects/tupleobject.c

     PyObject **src, **dest;
     PyTupleObject *np;
     if (!PyTuple_Check(bb)) {
-        PyErr_Format(PyExc_TypeError,
-             "can only concatenate tuple (not \"%.200s\") to tuple",
-                 Py_TYPE(bb)->tp_name);
-        return NULL;
+        Py_RETURN_NOTIMPLEMENTED;
     }
 #define b ((PyTupleObject *)bb)
     size = Py_SIZE(a) + Py_SIZE(b);

Objects/unicodeobject.c

 
 /* Concat to string or Unicode object giving a new Unicode object. */
 
+static PyObject *
+unicode_concat(PyObject *left, PyObject *right)
+{
+    /* Unlike the public C API, the slot impl may return NotImplemented */
+    if (!PyUnicode_Check(left) || !PyUnicode_Check(right)) {
+        Py_RETURN_NOTIMPLEMENTED;
+    }
+    return PyUnicode_Concat(left, right);
+}
+
 PyObject *
 PyUnicode_Concat(PyObject *left, PyObject *right)
 {
     /* Coerce the two arguments */
     u = PyUnicode_FromObject(left);
     if (u == NULL) {
-        PyObject *ni = Py_NotImplemented;
-        /* XXX: PyErr_Matches check??? */
-        PyErr_Clear();
-        Py_INCREF(ni);
-        return ni;
+        goto onError;
     }
     v = PyUnicode_FromObject(right);
     if (v == NULL) {
-        PyObject *ni = Py_NotImplemented;
-        /* XXX: PyErr_Matches check??? */
-        PyErr_Clear();
-        Py_DECREF(u);
-        Py_INCREF(ni);
-        return ni;
+        goto onError;
     }
 
     /* Shortcuts */
 
 static PySequenceMethods unicode_as_sequence = {
     (lenfunc) unicode_length,       /* sq_length */
-    PyUnicode_Concat,           /* sq_concat */
+    unicode_concat,           /* sq_concat */
     (ssizeargfunc) unicode_repeat,  /* sq_repeat */
     (ssizeargfunc) unicode_getitem,     /* sq_item */
     0,                  /* sq_slice */