Commits

Anonymous committed 16285c1

Issue #15958: bytes.join and bytearray.join now accept arbitrary buffer objects.

Comments (0)

Files changed (5)

Lib/test/test_bytes.py

             self.assertEqual(self.type2test(b"").join(lst), b"abc")
             self.assertEqual(self.type2test(b"").join(tuple(lst)), b"abc")
             self.assertEqual(self.type2test(b"").join(iter(lst)), b"abc")
-        self.assertEqual(self.type2test(b".").join([b"ab", b"cd"]), b"ab.cd")
-        # XXX more...
+        dot_join = self.type2test(b".:").join
+        self.assertEqual(dot_join([b"ab", b"cd"]), b"ab.:cd")
+        self.assertEqual(dot_join([memoryview(b"ab"), b"cd"]), b"ab.:cd")
+        self.assertEqual(dot_join([b"ab", memoryview(b"cd")]), b"ab.:cd")
+        self.assertEqual(dot_join([bytearray(b"ab"), b"cd"]), b"ab.:cd")
+        self.assertEqual(dot_join([b"ab", bytearray(b"cd")]), b"ab.:cd")
+        # Stress it with many items
+        seq = [b"abc"] * 1000
+        expected = b"abc" + b".:abc" * 999
+        self.assertEqual(dot_join(seq), expected)
+        # Error handling and cleanup when some item in the middle of the
+        # sequence has the wrong type.
+        with self.assertRaises(TypeError):
+            dot_join([bytearray(b"ab"), "cd", b"ef"])
+        with self.assertRaises(TypeError):
+            dot_join([memoryview(b"ab"), "cd", b"ef"])
 
     def test_count(self):
         b = self.type2test(b'mississippi')
             self.assertEqual(val, newval)
             self.assertTrue(val is not newval,
                             expr+' returned val on a mutable object')
+        sep = self.marshal(b'')
+        newval = sep.join([val])
+        self.assertEqual(val, newval)
+        self.assertIsNot(val, newval)
+
 
 class FixedStringTest(test.string_tests.BaseTest):
 
 Core and Builtins
 -----------------
 
+- Issue #15958: bytes.join and bytearray.join now accept arbitrary buffer
+  objects.
+
 - Issue #14783: Improve int() docstring and switch docstrings for str(),
   range(), and slice() to use multi-line signatures.
 

Objects/bytearrayobject.c

 #define FASTSEARCH fastsearch
 #define STRINGLIB(F) stringlib_##F
 #define STRINGLIB_CHAR char
+#define STRINGLIB_SIZEOF_CHAR 1
 #define STRINGLIB_LEN PyByteArray_GET_SIZE
 #define STRINGLIB_STR PyByteArray_AS_STRING
 #define STRINGLIB_NEW PyByteArray_FromStringAndSize
 #include "stringlib/fastsearch.h"
 #include "stringlib/count.h"
 #include "stringlib/find.h"
+#include "stringlib/join.h"
 #include "stringlib/partition.h"
 #include "stringlib/split.h"
 #include "stringlib/ctype.h"
 in between each pair, and return the result as a new bytearray.");
 
 static PyObject *
-bytearray_join(PyByteArrayObject *self, PyObject *it)
+bytearray_join(PyObject *self, PyObject *iterable)
 {
-    PyObject *seq;
-    Py_ssize_t mysize = Py_SIZE(self);
-    Py_ssize_t i;
-    Py_ssize_t n;
-    PyObject **items;
-    Py_ssize_t totalsize = 0;
-    PyObject *result;
-    char *dest;
-
-    seq = PySequence_Fast(it, "can only join an iterable");
-    if (seq == NULL)
-        return NULL;
-    n = PySequence_Fast_GET_SIZE(seq);
-    items = PySequence_Fast_ITEMS(seq);
-
-    /* Compute the total size, and check that they are all bytes */
-    /* XXX Shouldn't we use _getbuffer() on these items instead? */
-    for (i = 0; i < n; i++) {
-        PyObject *obj = items[i];
-        if (!PyByteArray_Check(obj) && !PyBytes_Check(obj)) {
-            PyErr_Format(PyExc_TypeError,
-                         "can only join an iterable of bytes "
-                         "(item %ld has type '%.100s')",
-                         /* XXX %ld isn't right on Win64 */
-                         (long)i, Py_TYPE(obj)->tp_name);
-            goto error;
-        }
-        if (i > 0)
-            totalsize += mysize;
-        totalsize += Py_SIZE(obj);
-        if (totalsize < 0) {
-            PyErr_NoMemory();
-            goto error;
-        }
-    }
-
-    /* Allocate the result, and copy the bytes */
-    result = PyByteArray_FromStringAndSize(NULL, totalsize);
-    if (result == NULL)
-        goto error;
-    dest = PyByteArray_AS_STRING(result);
-    for (i = 0; i < n; i++) {
-        PyObject *obj = items[i];
-        Py_ssize_t size = Py_SIZE(obj);
-        char *buf;
-        if (PyByteArray_Check(obj))
-           buf = PyByteArray_AS_STRING(obj);
-        else
-           buf = PyBytes_AS_STRING(obj);
-        if (i) {
-            memcpy(dest, self->ob_bytes, mysize);
-            dest += mysize;
-        }
-        memcpy(dest, buf, size);
-        dest += size;
-    }
-
-    /* Done */
-    Py_DECREF(seq);
-    return result;
-
-    /* Error handling */
-  error:
-    Py_DECREF(seq);
-    return NULL;
+    return stringlib_bytes_join(self, iterable);
 }
 
 PyDoc_STRVAR(splitlines__doc__,

Objects/bytesobject.c

 static Py_ssize_t
 _getbuffer(PyObject *obj, Py_buffer *view)
 {
-    PyBufferProcs *buffer = Py_TYPE(obj)->tp_as_buffer;
+    PyBufferProcs *bufferprocs;
+    if (PyBytes_CheckExact(obj)) {
+        /* Fast path, e.g. for .join() of many bytes objects */
+        Py_INCREF(obj);
+        view->obj = obj;
+        view->buf = PyBytes_AS_STRING(obj);
+        view->len = PyBytes_GET_SIZE(obj);
+        return view->len;
+    }
 
-    if (buffer == NULL || buffer->bf_getbuffer == NULL)
+    bufferprocs = Py_TYPE(obj)->tp_as_buffer;
+    if (bufferprocs == NULL || bufferprocs->bf_getbuffer == NULL)
     {
         PyErr_Format(PyExc_TypeError,
                      "Type %.100s doesn't support the buffer API",
         return -1;
     }
 
-    if (buffer->bf_getbuffer(obj, view, PyBUF_SIMPLE) < 0)
+    if (bufferprocs->bf_getbuffer(obj, view, PyBUF_SIMPLE) < 0)
         return -1;
     return view->len;
 }
 #include "stringlib/fastsearch.h"
 #include "stringlib/count.h"
 #include "stringlib/find.h"
+#include "stringlib/join.h"
 #include "stringlib/partition.h"
 #include "stringlib/split.h"
 #include "stringlib/ctype.h"
 Example: b'.'.join([b'ab', b'pq', b'rs']) -> b'ab.pq.rs'.");
 
 static PyObject *
-bytes_join(PyObject *self, PyObject *orig)
+bytes_join(PyObject *self, PyObject *iterable)
 {
-    char *sep = PyBytes_AS_STRING(self);
-    const Py_ssize_t seplen = PyBytes_GET_SIZE(self);
-    PyObject *res = NULL;
-    char *p;
-    Py_ssize_t seqlen = 0;
-    size_t sz = 0;
-    Py_ssize_t i;
-    PyObject *seq, *item;
-
-    seq = PySequence_Fast(orig, "");
-    if (seq == NULL) {
-        return NULL;
-    }
-
-    seqlen = PySequence_Size(seq);
-    if (seqlen == 0) {
-        Py_DECREF(seq);
-        return PyBytes_FromString("");
-    }
-    if (seqlen == 1) {
-        item = PySequence_Fast_GET_ITEM(seq, 0);
-        if (PyBytes_CheckExact(item)) {
-            Py_INCREF(item);
-            Py_DECREF(seq);
-            return item;
-        }
-    }
-
-    /* There are at least two things to join, or else we have a subclass
-     * of the builtin types in the sequence.
-     * Do a pre-pass to figure out the total amount of space we'll
-     * need (sz), and see whether all argument are bytes.
-     */
-    /* XXX Shouldn't we use _getbuffer() on these items instead? */
-    for (i = 0; i < seqlen; i++) {
-        const size_t old_sz = sz;
-        item = PySequence_Fast_GET_ITEM(seq, i);
-        if (!PyBytes_Check(item) && !PyByteArray_Check(item)) {
-            PyErr_Format(PyExc_TypeError,
-                         "sequence item %zd: expected bytes,"
-                         " %.80s found",
-                         i, Py_TYPE(item)->tp_name);
-            Py_DECREF(seq);
-            return NULL;
-        }
-        sz += Py_SIZE(item);
-        if (i != 0)
-            sz += seplen;
-        if (sz < old_sz || sz > PY_SSIZE_T_MAX) {
-            PyErr_SetString(PyExc_OverflowError,
-                "join() result is too long for bytes");
-            Py_DECREF(seq);
-            return NULL;
-        }
-    }
-
-    /* Allocate result space. */
-    res = PyBytes_FromStringAndSize((char*)NULL, sz);
-    if (res == NULL) {
-        Py_DECREF(seq);
-        return NULL;
-    }
-
-    /* Catenate everything. */
-    /* I'm not worried about a PyByteArray item growing because there's
-       nowhere in this function where we release the GIL. */
-    p = PyBytes_AS_STRING(res);
-    for (i = 0; i < seqlen; ++i) {
-        size_t n;
-        char *q;
-        if (i) {
-            Py_MEMCPY(p, sep, seplen);
-            p += seplen;
-        }
-        item = PySequence_Fast_GET_ITEM(seq, i);
-        n = Py_SIZE(item);
-        if (PyBytes_Check(item))
-            q = PyBytes_AS_STRING(item);
-        else
-            q = PyByteArray_AS_STRING(item);
-        Py_MEMCPY(p, q, n);
-        p += n;
-    }
-
-    Py_DECREF(seq);
-    return res;
+    return stringlib_bytes_join(self, iterable);
 }
 
 PyObject *

Objects/stringlib/join.h

+/* stringlib: bytes joining implementation */
+
+#if STRINGLIB_SIZEOF_CHAR != 1
+#error join.h only compatible with byte-wise strings
+#endif
+
+Py_LOCAL_INLINE(PyObject *)
+STRINGLIB(bytes_join)(PyObject *sep, PyObject *iterable)
+{
+    char *sepstr = STRINGLIB_STR(sep);
+    const Py_ssize_t seplen = STRINGLIB_LEN(sep);
+    PyObject *res = NULL;
+    char *p;
+    Py_ssize_t seqlen = 0;
+    Py_ssize_t sz = 0;
+    Py_ssize_t i, nbufs;
+    PyObject *seq, *item;
+    Py_buffer *buffers = NULL;
+#define NB_STATIC_BUFFERS 10
+    Py_buffer static_buffers[NB_STATIC_BUFFERS];
+
+    seq = PySequence_Fast(iterable, "can only join an iterable");
+    if (seq == NULL) {
+        return NULL;
+    }
+
+    seqlen = PySequence_Fast_GET_SIZE(seq);
+    if (seqlen == 0) {
+        Py_DECREF(seq);
+        return STRINGLIB_NEW(NULL, 0);
+    }
+#ifndef STRINGLIB_MUTABLE
+    if (seqlen == 1) {
+        item = PySequence_Fast_GET_ITEM(seq, 0);
+        if (STRINGLIB_CHECK_EXACT(item)) {
+            Py_INCREF(item);
+            Py_DECREF(seq);
+            return item;
+        }
+    }
+#endif
+    if (seqlen > NB_STATIC_BUFFERS) {
+        buffers = PyMem_NEW(Py_buffer, seqlen);
+        if (buffers == NULL) {
+            Py_DECREF(seq);
+            return NULL;
+        }
+    }
+    else {
+        buffers = static_buffers;
+    }
+
+    /* Here is the general case.  Do a pre-pass to figure out the total
+     * amount of space we'll need (sz), and see whether all arguments are
+     * buffer-compatible.
+     */
+    for (i = 0, nbufs = 0; i < seqlen; i++) {
+        Py_ssize_t itemlen;
+        item = PySequence_Fast_GET_ITEM(seq, i);
+        if (_getbuffer(item, &buffers[i]) < 0) {
+            PyErr_Format(PyExc_TypeError,
+                         "sequence item %zd: expected bytes, bytearray, "
+                         "or an object with the buffer interface, %.80s found",
+                         i, Py_TYPE(item)->tp_name);
+            goto error;
+        }
+        nbufs = i + 1;  /* for error cleanup */
+        itemlen = buffers[i].len;
+        if (itemlen > PY_SSIZE_T_MAX - sz) {
+            PyErr_SetString(PyExc_OverflowError,
+                            "join() result is too long");
+            goto error;
+        }
+        sz += itemlen;
+        if (i != 0) {
+            if (seplen > PY_SSIZE_T_MAX - sz) {
+                PyErr_SetString(PyExc_OverflowError,
+                                "join() result is too long");
+                goto error;
+            }
+            sz += seplen;
+        }
+        if (seqlen != PySequence_Fast_GET_SIZE(seq)) {
+            PyErr_SetString(PyExc_RuntimeError,
+                            "sequence changed size during iteration");
+            goto error;
+        }
+    }
+
+    /* Allocate result space. */
+    res = STRINGLIB_NEW(NULL, sz);
+    if (res == NULL)
+        goto error;
+
+    /* Catenate everything. */
+    p = STRINGLIB_STR(res);
+    for (i = 0; i < nbufs; i++) {
+        Py_ssize_t n;
+        char *q;
+        if (i) {
+            Py_MEMCPY(p, sepstr, seplen);
+            p += seplen;
+        }
+        n = buffers[i].len;
+        q = buffers[i].buf;
+        Py_MEMCPY(p, q, n);
+        p += n;
+    }
+    goto done;
+
+error:
+    res = NULL;
+done:
+    Py_DECREF(seq);
+    for (i = 0; i < nbufs; i++)
+        PyBuffer_Release(&buffers[i]);
+    if (buffers != static_buffers)
+        PyMem_FREE(buffers);
+    return res;
+}
+
+#undef NB_STATIC_BUFFERS