Commits

Anonymous committed 8d2e1c3

carraywrap works

  • Participants
  • Parent commits 99c47d5

Comments (0)

Files changed (2)

carraywrap/carraywrapmodule.c

 #include "structmember.h"
 #include <numpy/arrayobject.h>
 
+#define CARRAYWRAP_MAXDIM 3
+
 typedef struct __CArrayWrap__{
   PyObject_HEAD
   char *data, *array;
-  int ndim, *dims, *ptrsize;
+  int ndim, dims[CARRAYWRAP_MAXDIM];
   PyObject *nparray;
 } CArrayWrap;
 
 CArrayWrap_dealloc(CArrayWrap* self)
 {
   Py_XDECREF(self->nparray);
-  free(self->dims);
-  free(self->ptrsize);
   free(self->data);
+  switch(self->ndim){
+  case 2:
+    free(self->array);
+    break;
+  case 3:
+    free(((char **)self->array)[0]);
+    free(self->array);
+    break;
+  }
   self->ob_type->tp_free((PyObject*)self);
 }
 
   CArrayWrap *self;
   self = (CArrayWrap *)type->tp_alloc(type, 0);
   self->ndim = 0;
-  self->ptrsize = NULL;
-  self->dims = NULL;
   self->nparray = NULL;
   return (PyObject *)self;
 }
 CArrayWrap_init(CArrayWrap *self, PyObject *args, PyObject *kwds)
 {
   PyObject *obj;
-  int i, size;
-  char *ptr;
+  int i, elemnum;
 
   if (!PyTuple_Check(args)){
     PyErr_SetString( PyExc_TypeError, "bad arguments");
     goto fail;
   }
-  /* alloc self->dims */
+  /* check ndims */
   self->ndim = PyTuple_Size(args);
-  if ((self->dims = malloc(self->ndim*sizeof(int))) == NULL){
-    PyErr_SetString( PyExc_RuntimeError, "malloc for dims fails");
-    goto fail;
-  }
-  if ((self->ptrsize = malloc(self->ndim*sizeof(int))) == NULL){
-    PyErr_SetString( PyExc_RuntimeError, "malloc for ptrsize fails");
+  if (self->ndim > CARRAYWRAP_MAXDIM){
+    PyErr_SetString( PyExc_ValueError, "too many arguments (ndim)");
     goto fail;
   }
   /* set self->dims */
-  size = 1;
+  elemnum = 1;
   for (i = 0; i < self->ndim; ++i){
     obj = PyTuple_GetItem(args, i);
     if (!PyInt_Check(obj)){
 		    "cannot get value(%d) or the value is not positive", i);
       goto fail;
     }
-    if (i == 0){
-      self->ptrsize[i] = self->dims[i];
-    }else{
-      self->ptrsize[i] = self->dims[i] * self->ptrsize[i-1];
-    }
+    elemnum *= self->dims[i];
   }
   /* alloc self->data */
-  self->data = (char *)malloc(self->ptrsize[self->ndim-1]*sizeof(int));
+  self->data = (char *)malloc(elemnum*sizeof(int));
   if (self->data == NULL){
     PyErr_SetString( PyExc_RuntimeError, "malloc for data fails");
     goto fail;
   }
-  ptr = (char *)&self->array;
-  for (i = 0; i < self->ndim-1; ++i){
-    *((char **)ptr) = (char *)malloc(self->ptrsize[i]*sizeof(char *));
-    ptr = *((char **)ptr);
-    /* if (ptr == NULL){ */
-    /* } */
+  /* set pointer */
+  switch(self->ndim){
+  case 1:
+    self->array = self->data;
+    break;
+  case 2:
+    {
+      int **a2d;
+      a2d = (int**)malloc(self->dims[0]*sizeof(int*));
+      a2d[0] = (int*)self->data;
+      for (i = 1; i < self->dims[0]; i++){
+	a2d[i] = a2d[0] + i * self->dims[1];
+      }
+      self->array = (char*)a2d;
+    }
+    break;
+  case 3:
+    {
+      int ***a3d;
+      int j;
+      a3d = (int***)malloc(self->dims[0]*sizeof(int**));
+      a3d[0] = (int**)malloc(self->dims[0]*self->dims[1]*sizeof(int*));
+      a3d[0][0] = (int*)self->data;
+      for (i = 0; i < self->dims[0]; i++) {
+	a3d[i] = a3d[0] + i * self->dims[1];
+	for (j = 0; j < self->dims[1]; j++){
+	  a3d[i][j] = a3d[0][0]
+	    + i * self->dims[1]*self->dims[2] + j * self->dims[2];
+	}
+      }
+      self->array = (char*)a3d;
+    }
+    break;
   }
-  *(char **)ptr = self->data;
 
   self->nparray =
     PyArray_SimpleNewFromData( self->ndim, self->dims, NPY_LONG,
   return 0;
  fail:
   free(self->data);
-  free(self->dims);
-  free(self->ptrsize);
   self->ndim = 0;
   return -1;
 }

carraywrap/use_caw.py

+import numpy
 from carraywrap import *
 
-caw = CArrayWrap(10,10)
+def print_2way(caw):
+    print caw.nparray
+    caw.print_array()
+
+def set_range(caw):
+    caw.nparray[:] = numpy.arange(
+        numpy.prod(caw.nparray.shape)).reshape(caw.nparray.shape)
+
+caw = CArrayWrap(3,3,3)
+
 caw.nparray.fill(0)
-print caw.nparray
-caw.print_array()
+print_2way(caw)
 
-caw.nparray.fill(1)
-print caw.nparray
-caw.print_array()
+set_range(caw)
+print_2way(caw)
+
+caw = CArrayWrap(3,3)
+
+caw.nparray.fill(0)
+print_2way(caw)
+
+set_range(caw)
+print_2way(caw)
+
+caw = CArrayWrap(3)
+
+caw.nparray.fill(0)
+print_2way(caw)
+
+set_range(caw)
+print_2way(caw)