Commits

Armin Rigo committed d44a9f5

slice assignment.

  • Participants
  • Parent commits 8e01f52
  • Branches slicing

Comments (0)

Files changed (2)

File c/_cffi_backend.c

 static PyObject *
 new_array_type(CTypeDescrObject *ctptr, PyObject *lengthobj);   /* forward */
 
-static PyObject *
-cdata_slice(CDataObject *cd, PySliceObject *slice)
+static CTypeDescrObject *
+_cdata_getslicearg(CDataObject *cd, PySliceObject *slice, Py_ssize_t bounds[])
 {
     Py_ssize_t start, stop;
-    CDataObject_own_length *scd;
     CTypeDescrObject *ct;
 
     start = PyInt_AsSsize_t(slice->start);
         return NULL;
     }
 
+    bounds[0] = start;
+    bounds[1] = stop - start;
+    return ct;
+}
+
+static PyObject *
+cdata_slice(CDataObject *cd, PySliceObject *slice)
+{
+    Py_ssize_t bounds[2];
+    CDataObject_own_length *scd;
+    CTypeDescrObject *ct = _cdata_getslicearg(cd, slice, bounds);
+    if (ct == NULL)
+        return NULL;
+
     if (ct->ct_stuff == NULL) {
         ct->ct_stuff = new_array_type(ct, Py_None);
         if (ct->ct_stuff == NULL)
         return NULL;
     Py_INCREF(ct);
     scd->head.c_type = ct;
-    scd->head.c_data = cd->c_data + ct->ct_itemdescr->ct_size * start;
+    scd->head.c_data = cd->c_data + ct->ct_itemdescr->ct_size * bounds[0];
     scd->head.c_weakreflist = NULL;
-    scd->length = stop - start;
+    scd->length = bounds[1];
     return (PyObject *)scd;
 }
 
+static int
+cdata_ass_slice(CDataObject *cd, PySliceObject *slice, PyObject *v)
+{
+    Py_ssize_t bounds[2], i, length, itemsize;
+    PyObject *it, *item;
+    PyObject *(*iternext)(PyObject *);
+    char *cdata;
+    int err;
+    CTypeDescrObject *ct = _cdata_getslicearg(cd, slice, bounds);
+    if (ct == NULL)
+        return -1;
+
+    it = PyObject_GetIter(v);
+    if (it == NULL)
+        return -1;
+    iternext = *it->ob_type->tp_iternext;
+
+    ct = ct->ct_itemdescr;
+    itemsize = ct->ct_size;
+    cdata = cd->c_data + itemsize * bounds[0];
+    length = bounds[1];
+    for (i = 0; i < length; i++) {
+        item = iternext(it);
+        if (item == NULL) {
+            if (!PyErr_Occurred())
+                PyErr_Format(PyExc_ValueError,
+                             "need %zd values to unpack, got %zd",
+                             length, i);
+            goto error;
+        }
+        err = convert_from_object(cdata, ct, item);
+        Py_DECREF(item);
+        if (err < 0)
+            goto error;
+
+        cdata += itemsize;
+    }
+    item = iternext(it);
+    if (item != NULL) {
+        Py_DECREF(item);
+        PyErr_Format(PyExc_ValueError,
+                     "got more than %zd values to unpack", length);
+    }
+ error:
+    Py_DECREF(it);
+    return PyErr_Occurred() ? -1 : 0;
+}
+
 static PyObject *
 cdataowning_subscript(CDataObject *cd, PyObject *key)
 {
 static int
 cdata_ass_sub(CDataObject *cd, PyObject *key, PyObject *v)
 {
-    char *c = _cdata_get_indexed_ptr(cd, key);
-    CTypeDescrObject *ctitem = cd->c_type->ct_itemdescr;
+    char *c;
+    CTypeDescrObject *ctitem;
+    if (PySlice_Check(key))
+        return cdata_ass_slice(cd, (PySliceObject *)key, v);
+
+    c = _cdata_get_indexed_ptr(cd, key);
+    ctitem = cd->c_type->ct_itemdescr;
     /* use 'mp_ass_subscript' instead of 'sq_ass_item' because we don't want
        negative indexes to be corrected automatically */
     if (c == NULL && PyErr_Occurred())
     assert str(e.value) == "slice start > stop"
     e = py.test.raises(IndexError, "c[6:6]")
     assert str(e.value) == "index too large (expected 6 < 5)"
+
+def test_setslice():
+    BIntP = new_pointer_type(new_primitive_type("int"))
+    BIntArray = new_array_type(BIntP, None)
+    c = newp(BIntArray, 5)
+    c[1:3] = [100, 200]
+    assert list(c) == [0, 100, 200, 0, 0]
+    cp = c + 3
+    cp[-1:1] = [300, 400]
+    assert list(c) == [0, 100, 300, 400, 0]
+    cp[-1:1] = iter([500, 600])
+    assert list(c) == [0, 100, 500, 600, 0]
+    py.test.raises(ValueError, "cp[-1:1] = [1000]")
+    assert list(c) == [0, 100, 1000, 600, 0]
+    py.test.raises(ValueError, "cp[-1:1] = (700, 800, 900)")
+    assert list(c) == [0, 100, 700, 800, 0]