Commits

Vinay Sajip committed edba722

Closes #12291: Fixed bug which was found when doing multiple loads from one stream.

Comments (0)

Files changed (4)

Lib/importlib/test/source/test_file_loader.py

                                                 lambda bc: bc[:8] + b'<test>',
                                                 del_source=del_source)
             file_path = mapping['_temp'] if not del_source else bytecode_path
-            with self.assertRaises(ValueError):
+            with self.assertRaises(EOFError):
                 self.import_(file_path, '_temp')
 
     def _test_bad_magic(self, test, *, del_source=False):

Lib/test/test_marshal.py

         invalid_string = b'l\x02\x00\x00\x00\x00\x00\x00\x00'
         self.assertRaises(ValueError, marshal.loads, invalid_string)
 
+    def test_multiple_dumps_and_loads(self):
+        # Issue 12291: marshal.load() should be callable multiple times
+        # with interleaved data written by non-marshal code
+        # Adapted from a patch by Engelbert Gruber.
+        data = (1, 'abc', b'def', 1.0, (2, 'a', ['b', b'c']))
+        for interleaved in (b'', b'0123'):
+            ilen = len(interleaved)
+            positions = []
+            try:
+                with open(support.TESTFN, 'wb') as f:
+                    for d in data:
+                        marshal.dump(d, f)
+                        if ilen:
+                            f.write(interleaved)
+                        positions.append(f.tell())
+                with open(support.TESTFN, 'rb') as f:
+                    for i, d in enumerate(data):
+                        self.assertEqual(d, marshal.load(f))
+                        if ilen:
+                            f.read(ilen)
+                        self.assertEqual(positions[i], f.tell())
+            finally:
+                support.unlink(support.TESTFN)
+
 
 def test_main():
     support.run_unittest(IntTestCase,
 Core and Builtins
 -----------------
 
+- Issue #12291: You can now load multiple marshalled objects from a stream,
+  with other data interleaved between marshalled objects.
+
 - Issue #12084: os.stat on Windows now works properly with relative symbolic
   links when called from any directory.
 
     int error;  /* see WFERR_* values */
     int depth;
     /* If fp == NULL, the following are valid: */
+    PyObject * readable;    /* Stream-like object being read from */
     PyObject *str;
     char *ptr;
     char *end;
 
 #define rs_byte(p) (((p)->ptr < (p)->end) ? (unsigned char)*(p)->ptr++ : EOF)
 
-#define r_byte(p) ((p)->fp ? getc((p)->fp) : rs_byte(p))
-
 static int
 r_string(char *s, int n, RFILE *p)
 {
-    if (p->fp != NULL)
-        /* The result fits into int because it must be <=n. */
-        return (int)fread(s, 1, n, p->fp);
-    if (p->end - p->ptr < n)
-        n = (int)(p->end - p->ptr);
-    memcpy(s, p->ptr, n);
-    p->ptr += n;
-    return n;
+    char * ptr;
+    int read, left;
+
+    if (!p->readable) {
+        if (p->fp != NULL)
+            /* The result fits into int because it must be <=n. */
+            read = (int) fread(s, 1, n, p->fp);
+        else {
+            left = (int)(p->end - p->ptr);
+            read = (left < n) ? left : n;
+            memcpy(s, p->ptr, read);
+            p->ptr += read;
+        }
+    }
+    else {
+        PyObject *data = PyObject_CallMethod(p->readable, "read", "i", n);
+        read = 0;
+        if (data != NULL) {
+            if (!PyBytes_Check(data)) {
+                PyErr_Format(PyExc_TypeError,
+                             "f.read() returned not bytes but %.100s",
+                             data->ob_type->tp_name);
+            }
+            else {
+                read = PyBytes_GET_SIZE(data);
+                if (read > 0) {
+                    ptr = PyBytes_AS_STRING(data);
+                    memcpy(s, ptr, read);
+                }
+            }
+            Py_DECREF(data);
+        }
+    }
+    if (!PyErr_Occurred() && (read < n)) {
+        PyErr_SetString(PyExc_EOFError, "EOF read where not expected");
+    }
+    return read;
+}
+
+
+static int
+r_byte(RFILE *p)
+{
+    int c = EOF;
+    unsigned char ch;
+    int n;
+
+    if (!p->readable)
+        c = p->fp ? getc(p->fp) : rs_byte(p);
+    else {
+        n = r_string((char *) &ch, 1, p);
+        if (n > 0)
+            c = ch;
+    }
+    return c;
 }
 
 static int
 r_short(RFILE *p)
 {
     register short x;
-    x = r_byte(p);
-    x |= r_byte(p) << 8;
+    unsigned char buffer[2];
+
+    r_string((char *) buffer, 2, p);
+    x = buffer[0];
+    x |= buffer[1] << 8;
     /* Sign-extension, in case short greater than 16 bits */
     x |= -(x & 0x8000);
     return x;
 r_long(RFILE *p)
 {
     register long x;
-    register FILE *fp = p->fp;
-    if (fp) {
-        x = getc(fp);
-        x |= (long)getc(fp) << 8;
-        x |= (long)getc(fp) << 16;
-        x |= (long)getc(fp) << 24;
-    }
-    else {
-        x = rs_byte(p);
-        x |= (long)rs_byte(p) << 8;
-        x |= (long)rs_byte(p) << 16;
-        x |= (long)rs_byte(p) << 24;
-    }
+    unsigned char buffer[4];
+
+    r_string((char *) buffer, 4, p);
+    x = buffer[0];
+    x |= (long)buffer[1] << 8;
+    x |= (long)buffer[2] << 16;
+    x |= (long)buffer[3] << 24;
 #if SIZEOF_LONG > 4
     /* Sign extension for 64-bit machines */
     x |= -(x & 0x80000000L);
 static PyObject *
 r_long64(RFILE *p)
 {
+    PyObject * result = NULL;
     long lo4 = r_long(p);
     long hi4 = r_long(p);
+
+    if (!PyErr_Occurred()) {
 #if SIZEOF_LONG > 4
-    long x = (hi4 << 32) | (lo4 & 0xFFFFFFFFL);
-    return PyLong_FromLong(x);
+        long x = (hi4 << 32) | (lo4 & 0xFFFFFFFFL);
+        result = PyLong_FromLong(x);
 #else
-    unsigned char buf[8];
-    int one = 1;
-    int is_little_endian = (int)*(char*)&one;
-    if (is_little_endian) {
-        memcpy(buf, &lo4, 4);
-        memcpy(buf+4, &hi4, 4);
+        unsigned char buf[8];
+        int one = 1;
+        int is_little_endian = (int)*(char*)&one;
+        if (is_little_endian) {
+            memcpy(buf, &lo4, 4);
+            memcpy(buf+4, &hi4, 4);
+        }
+        else {
+            memcpy(buf, &hi4, 4);
+            memcpy(buf+4, &lo4, 4);
+        }
+        result = _PyLong_FromByteArray(buf, 8, is_little_endian, 1);
+#endif
     }
-    else {
-        memcpy(buf, &hi4, 4);
-        memcpy(buf+4, &lo4, 4);
-    }
-    return _PyLong_FromByteArray(buf, 8, is_little_endian, 1);
-#endif
+    return result;
 }
 
 static PyObject *
     digit d;
 
     n = r_long(p);
+    if (PyErr_Occurred())
+        return NULL;
     if (n == 0)
         return (PyObject *)_PyLong_New(0);
     if (n < -INT_MAX || n > INT_MAX) {
         d = 0;
         for (j=0; j < PyLong_MARSHAL_RATIO; j++) {
             md = r_short(p);
+            if (PyErr_Occurred())
+                break;
             if (md < 0 || md > PyLong_MARSHAL_BASE)
                 goto bad_digit;
             d += (digit)md << j*PyLong_MARSHAL_SHIFT;
     d = 0;
     for (j=0; j < shorts_in_top_digit; j++) {
         md = r_short(p);
+        if (PyErr_Occurred())
+            break;
         if (md < 0 || md > PyLong_MARSHAL_BASE)
             goto bad_digit;
         /* topmost marshal digit should be nonzero */
         }
         d += (digit)md << j*PyLong_MARSHAL_SHIFT;
     }
+    if (PyErr_Occurred()) {
+        Py_DECREF(ob);
+        return NULL;
+    }
     /* top digit should be nonzero, else the resulting PyLong won't be
        normalized */
     ob->ob_digit[size-1] = d;
         break;
 
     case TYPE_INT:
-        retval = PyLong_FromLong(r_long(p));
+        n = r_long(p);
+        retval = PyErr_Occurred() ? NULL : PyLong_FromLong(n);
         break;
 
     case TYPE_INT64:
 
     case TYPE_STRING:
         n = r_long(p);
+        if (PyErr_Occurred()) {
+            retval = NULL;
+            break;
+        }
         if (n < 0 || n > INT_MAX) {
             PyErr_SetString(PyExc_ValueError, "bad marshal data (string size out of range)");
             retval = NULL;
         char *buffer;
 
         n = r_long(p);
+        if (PyErr_Occurred()) {
+            retval = NULL;
+            break;
+        }
         if (n < 0 || n > INT_MAX) {
             PyErr_SetString(PyExc_ValueError, "bad marshal data (unicode size out of range)");
             retval = NULL;
 
     case TYPE_TUPLE:
         n = r_long(p);
+        if (PyErr_Occurred()) {
+            retval = NULL;
+            break;
+        }
         if (n < 0 || n > INT_MAX) {
             PyErr_SetString(PyExc_ValueError, "bad marshal data (tuple size out of range)");
             retval = NULL;
 
     case TYPE_LIST:
         n = r_long(p);
+        if (PyErr_Occurred()) {
+            retval = NULL;
+            break;
+        }
         if (n < 0 || n > INT_MAX) {
             PyErr_SetString(PyExc_ValueError, "bad marshal data (list size out of range)");
             retval = NULL;
     case TYPE_SET:
     case TYPE_FROZENSET:
         n = r_long(p);
+        if (PyErr_Occurred()) {
+            retval = NULL;
+            break;
+        }
         if (n < 0 || n > INT_MAX) {
             PyErr_SetString(PyExc_ValueError, "bad marshal data (set size out of range)");
             retval = NULL;
 
             /* XXX ignore long->int overflows for now */
             argcount = (int)r_long(p);
+            if (PyErr_Occurred())
+                goto code_error;
             kwonlyargcount = (int)r_long(p);
+            if (PyErr_Occurred())
+                goto code_error;
             nlocals = (int)r_long(p);
+            if (PyErr_Occurred())
+                goto code_error;
             stacksize = (int)r_long(p);
+            if (PyErr_Occurred())
+                goto code_error;
             flags = (int)r_long(p);
+            if (PyErr_Occurred())
+                goto code_error;
             code = r_object(p);
             if (code == NULL)
                 goto code_error;
 {
     RFILE rf;
     assert(fp);
+    rf.readable = NULL;
     rf.fp = fp;
     rf.strings = NULL;
     rf.end = rf.ptr = NULL;
 {
     RFILE rf;
     rf.fp = fp;
+    rf.readable = NULL;
     rf.strings = NULL;
     rf.ptr = rf.end = NULL;
     return r_long(&rf);
     RFILE rf;
     PyObject *result;
     rf.fp = fp;
+    rf.readable = NULL;
     rf.strings = PyList_New(0);
     rf.depth = 0;
     rf.ptr = rf.end = NULL;
     RFILE rf;
     PyObject *result;
     rf.fp = NULL;
+    rf.readable = NULL;
     rf.ptr = str;
     rf.end = str + len;
     rf.strings = PyList_New(0);
     PyObject *res = NULL;
 
     wf.fp = NULL;
+    wf.readable = NULL;
     wf.str = PyBytes_FromStringAndSize((char *)NULL, 50);
     if (wf.str == NULL)
         return NULL;
 static PyObject *
 marshal_load(PyObject *self, PyObject *f)
 {
-    /* XXX Quick hack -- need to do this differently */
     PyObject *data, *result;
     RFILE rf;
-    data = PyObject_CallMethod(f, "read", "");
+    char *p;
+    int n;
+
+    /*
+     * Make a call to the read method, but read zero bytes.
+     * This is to ensure that the object passed in at least
+     * has a read method which returns bytes.
+     */
+    data = PyObject_CallMethod(f, "read", "i", 0);
     if (data == NULL)
         return NULL;
-    rf.fp = NULL;
-    if (PyBytes_Check(data)) {
-        rf.ptr = PyBytes_AS_STRING(data);
-        rf.end = rf.ptr + PyBytes_GET_SIZE(data);
-    }
-    else if (PyBytes_Check(data)) {
-        rf.ptr = PyBytes_AS_STRING(data);
-        rf.end = rf.ptr + PyBytes_GET_SIZE(data);
+    if (!PyBytes_Check(data)) {
+        PyErr_Format(PyExc_TypeError,
+                     "f.read() returned not bytes but %.100s",
+                     data->ob_type->tp_name);
+        result = NULL;
     }
     else {
-        PyErr_Format(PyExc_TypeError,
-                     "f.read() returned neither string "
-                     "nor bytes but %.100s",
-                     data->ob_type->tp_name);
-        Py_DECREF(data);
-        return NULL;
+        rf.strings = PyList_New(0);
+        rf.depth = 0;
+        rf.fp = NULL;
+        rf.readable = f;
+        result = read_object(&rf);
+        Py_DECREF(rf.strings);
     }
-    rf.strings = PyList_New(0);
-    rf.depth = 0;
-    result = read_object(&rf);
-    Py_DECREF(rf.strings);
     Py_DECREF(data);
     return result;
 }
     s = p.buf;
     n = p.len;
     rf.fp = NULL;
+    rf.readable = NULL;
     rf.ptr = s;
     rf.end = s + n;
     rf.strings = PyList_New(0);