Commits

Ilya Osadchiy committed b2dc68e

numpy: something on multidimensions

Comments (0)

Files changed (2)

pypy/module/micronumpy/interp_numarray.py

     def descr_len(self, space):
         return self.get_concrete().descr_len(space)
 
+    def subscript_to_index(subscript, shape):
+        # TODO: is it better to store cumulative multiply of shape and then index = reduce("add", map("mul", subscript, cummult_shape)) ?
+        index = 0
+        stride = 1
+        for ind, size in zip(subscript, shape):
+            index += ind * stride
+            stride *= size
+
     def descr_getitem(self, space, w_idx):
-        # TODO: indexing by tuples
-        start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size())
-        if step == 0:
-            # Single index
-            return space.wrap(self.get_concrete().getitem(start))
+        if space.is_true(space.isinstance(w_idx, space.w_tuple)):
+            # TODO: slices inside tuples, incomplete ind etc
+            subscript = space.unpacktuple(w_idx)
+            shape = self.find_shape()
+            if len(subscript) == len(shape):
+                # Fully qualified index
+                idx = subscript_to_index(subscript, shape)
+                is_single_elem = True
+        else:
+            start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size())
+            idx = start
+            is_single_elem = (step == 0)
+            
+        if is_single_elem:
+            # Single element
+            return space.wrap(self.get_concrete().getitem(idx))
         else:
             # Slice
             res = SingleDimSlice(start, stop, step, slice_length, self, self.signature.transition(SingleDimSlice.static_signature))
         BaseArray.__init__(self)
         self.float_value = float_value
 
+    def find_shape(self):
+        raise ValueError
+
     def find_size(self):
         raise ValueError
 
     """
     Class for representing virtual arrays, such as binary ops or ufuncs
     """
+    _immutable_fields_ = ["shape"]
     def __init__(self, signature):
         BaseArray.__init__(self)
         self.forced_result = None
         i = 0
         signature = self.signature
         result_size = self.find_size()
-        result = SingleDimArray(result_size)
+        result_shape = self.find_shape()
+        if len(result_shape) == 1:
+            result = SingleDimArray(result_size)
+        else:
+            result = MultiDimArray(result_size)
         while i < result_size:
             numpy_driver.jit_merge_point(signature=signature,
                                          result_size=result_size, i=i,
             return self.forced_result.eval(i)
         return self._eval(i)
 
+    def find_shape(self):
+        if self.forced_result is not None:
+            # The result has been computed and sources may be unavailable
+            return self.forced_result.find_shape()
+        return self._find_shape()
+        
     def find_size(self):
         if self.forced_result is not None:
             # The result has been computed and sources may be unavailable
             return self.forced_result.find_size()
         return self._find_size()
 
-
 class Call1(VirtualArray):
     _immutable_fields_ = ["function", "values"]
 
     def _del_sources(self):
         self.values = None
 
+    def _find_shape(self):
+        return self.values.find_shape()
+
     def _find_size(self):
         return self.values.find_size()
 
         self.left = None
         self.right = None
 
+    def _find_shape(self):
+        try:
+            return self.left.find_shape()
+        except ValueError:
+            pass
+        return self.right.find_shape()
+
     def _find_size(self):
         try:
             return self.left.find_size()
         self.step = step
         self.size = slice_length
 
+    def find_shape(self):
+        return (self.size,)
+
     def find_size(self):
         return self.size
 
         return (self.start + item * self.step)
 
 
-class SingleDimArray(BaseArray):
+class ConcreteArray(BaseArray):
+    """
+    Class for array arrays that actually store data
+    """
     signature = Signature()
 
     def __init__(self, size):
     def eval(self, i):
         return self.storage[i]
 
+    def getitem(self, item):
+        return self.storage[item]
+
+    def __del__(self):
+        lltype.free(self.storage, flavor='raw')
+
+class SingleDimArray(ConcreteArray):
+    def __init__(self, size):
+        ConcreteArray.__init__(self, size)
+
+    def find_shape(self):
+        return (self.size,)
+
     def getindex(self, space, item):
         if item >= self.size:
             raise operationerrfmt(space.w_IndexError,
     def descr_len(self, space):
         return space.wrap(self.size)
 
-    def getitem(self, item):
-        return self.storage[item]
-
     @unwrap_spec(item=int, value=float)
     def descr_setitem(self, space, item, value):
         item = self.getindex(space, item)
         self.invalidated()
         self.storage[item] = value
 
-    def __del__(self):
-        lltype.free(self.storage, flavor='raw')
+class MultiDimArray(ConcreteArray):
+    _immutable_fields_ = ["shape"]
+    def __init__(self, size, shape):
+        ConcreteArray.__init__(self, size)
+        self.shape = shape
+
+    def find_shape(self):
+        return self.shape
+
+    def descr_len(self, space):
+        return space.wrap(self.shape(0))
+
+    def descr_setitem(self, space, w_subscript, w_value):
+        item = self.getindex(space, item)
+        self.invalidated()
+        self.storage[item] = value
 
 def descr_new_numarray(space, w_type, w_size_or_iterable):
     l = space.listview(w_size_or_iterable)
         i += 1
     return space.wrap(arr)
 
-@unwrap_spec(ObjSpace, int)
-def zeros(space, size):
-    return space.wrap(SingleDimArray(size))
-
+#@unwrap_spec(ObjSpace, int)
+def zeros(space, w_size):
+    if space.is_true(space.isinstance(w_size, space.w_tuple)):
+        shape = tuple(space.unpackiterable(w_size))
+        size = reduce(lambda x, y: x*y, shape)
+        return space.wrap(MultiDimArray(size, shape))
+    elif space.is_true(space.isinstance(w_size, space.w_int)):
+        return space.wrap(SingleDimArray(space.int_w(w_size)))
+    else:
+        raise OperationError(space.w_TypeError, space.wrap("expected sequence object with len >= 0"))
 
 BaseArray.typedef = TypeDef(
     'numarray',

pypy/module/micronumpy/test/test_numarray.py

         a[2] = 20
         assert s[2] == 20
 
-
     def test_slice_invaidate(self):
         # check that slice shares invalidation list with 
         from numpy import array
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.