pypy / pypy / module / micronumpy / interp_iter.py

from pypy.rlib import jit
from pypy.rlib.objectmodel import instantiate
from pypy.module.micronumpy.strides import calculate_broadcast_strides,\
     calculate_slice_strides

# structures to describe slicing

class BaseChunk(object):
    """Abstract base for the per-dimension slicing descriptions below."""

class Chunk(BaseChunk):
    """Plain slice along one dimension.

    ``start``, ``stop`` and ``step`` describe the slice; ``lgt`` is the
    extent this chunk contributes to the result shape.  A ``step`` of 0
    contributes no dimension at all (presumably an integer index that
    removes the dimension -- confirm against the slicing code).
    """

    def __init__(self, start, stop, step, lgt):
        self.start = start
        self.stop = stop
        self.step = step
        self.lgt = lgt

    def extend_shape(self, shape):
        """Append this chunk's extent to the list *shape* (if any)."""
        if self.step != 0:
            shape.append(self.lgt)

    def get_iter(self):
        # Not implemented yet.  Raise an explicit NotImplementedError (as
        # BaseIterator does) instead of failing on the undefined name `xxx`.
        raise NotImplementedError

class IntArrayChunk(BaseChunk):
    """Chunk whose positions come from an integer index array.

    ``get_index`` reads the element of the index array at the iterator's
    current offset and converts it to a plain int.
    """

    def __init__(self, arr):
        concrete = arr.get_concrete()
        self.arr = concrete

    def extend_shape(self, shape):
        # every dimension of the index array is appended to the result shape
        for extent in self.arr.shape:
            shape.append(extent)

    def get_iter(self):
        index_array = self.arr
        return index_array.create_iter()

    def get_index(self, iter):
        box = self.arr.getitem(iter.offset)
        return box.convert_to_int()

class BoolArrayChunk(BaseChunk):
    """Chunk described by a boolean array (presumably a mask).

    Only construction is implemented so far.  The remaining methods raise
    NotImplementedError explicitly instead of failing on the undefined
    placeholder name `xxx`.
    """

    def __init__(self, arr):
        self.arr = arr.get_concrete()

    def extend_shape(self, shape):
        raise NotImplementedError

    def get_iter(self):
        raise NotImplementedError

class BaseTransform(object):
    """Base class for transformations that iterators apply to themselves."""

class ViewTransform(BaseTransform):
    """Transformation restricting iteration to a slice of the array."""

    def __init__(self, chunks):
        # Sequence of per-dimension slicing descriptions, forwarded to
        # calculate_slice_strides.  Presumably Chunk objects carrying
        # (start, stop, step, lgt) -- confirm against the callers.
        self.chunks = chunks

class BroadcastTransform(BaseTransform):
    """Transformation that broadcasts iteration to a result shape."""

    def __init__(self, res_shape):
        # target shape handed to calculate_broadcast_strides
        self.res_shape = res_shape

class BaseIterator(object):
    """Common interface of the iterators used by the computation frame.

    Subclasses override ``next``, ``done`` and ``transform``; the versions
    here only raise NotImplementedError.
    """

    def next(self, shapelen):
        raise NotImplementedError

    def done(self):
        raise NotImplementedError

    def transform(self, arr, t):
        raise NotImplementedError

    def apply_transformations(self, arr, transformations):
        """Fold ``transform`` over *transformations*, left to right.

        Returns the final iterator, or ``self`` when the list is empty.
        """
        current = self
        for step in transformations:
            current = current.transform(arr, step)
        return current

class ArrayIterator(BaseIterator):
    """Iterator walking offsets 0, 1, ..., size - 1 one step at a time.

    Advancing never modifies this object; a fresh iterator is returned.
    Any transformation first converts it into an equivalent ViewIterator.
    """

    def __init__(self, size):
        self.size = size
        self.offset = 0

    def next(self, shapelen):
        return self._next(1)

    def next_no_increase(self, shapelen):
        # a hack to make JIT believe this is always virtual
        return self._next(0)

    def _next(self, ofs):
        # copy of this iterator shifted by `ofs`, built without __init__
        it = instantiate(ArrayIterator)
        it.offset = self.offset + ofs
        it.size = self.size
        return it

    def done(self):
        return not self.offset < self.size

    def transform(self, arr, t):
        view = ViewIterator(arr.start, arr.strides, arr.backstrides,
                            arr.shape)
        return view.transform(arr, t)

class OneDimIterator(BaseIterator):
    """Iterator over a single strided dimension.

    start -- initial storage offset
    step  -- offset increment per element
    stop  -- number of elements (callers in this module pass a dimension
             length)

    ``size`` holds the end offset (one step past the last element), and
    ``done`` compares the current offset to it for exact equality.
    NOTE: with step == 0 the end offset equals ``start``, so ``done`` is
    already true before any element has been produced.
    """

    def __init__(self, start, step, stop):
        self.step = step
        self.offset = start
        self.size = start + step * stop

    def next(self, shapelen):
        following = instantiate(OneDimIterator)
        following.step = self.step
        following.size = self.size
        following.offset = self.step + self.offset
        return following

    def done(self):
        return self.size == self.offset

class ViewIterator(BaseIterator):
    """Iterator over a strided n-dimensional view.

    ``indices`` holds the current position (one entry per dimension of
    ``res_shape``) and ``offset`` the matching storage offset.  When
    dimension ``i`` is incremented, ``strides[i]`` is added to the offset.
    When it wraps back to 0, ``backstrides[i]`` is subtracted.  ``next``
    returns a new iterator rather than mutating this one.
    """

    def __init__(self, start, strides, backstrides, shape):
        self.offset = start
        self._done = False
        self.strides = strides
        self.backstrides = backstrides
        self.res_shape = shape
        self.indices = [0] * len(self.res_shape)

    def transform(self, arr, t):
        """Return a new ViewIterator with transformation *t* applied.

        A BroadcastTransform recomputes strides/backstrides for
        ``t.res_shape``; a ViewTransform applies ``t.chunks`` through
        calculate_slice_strides, whose result is unpacked as
        (shape, start, strides, backstrides).
        """
        if isinstance(t, BroadcastTransform):
            r = calculate_broadcast_strides(self.strides, self.backstrides,
                                            self.res_shape, t.res_shape)
            return ViewIterator(self.offset, r[0], r[1], t.res_shape)
        elif isinstance(t, ViewTransform):
            r = calculate_slice_strides(self.res_shape, self.offset,
                                        self.strides,
                                        self.backstrides, t.chunks)
            return ViewIterator(r[1], r[2], r[3], r[0])
        # NOTE(review): any other transform type falls through and returns
        # None.

    @jit.unroll_safe
    def next(self, shapelen):
        # The caller's shapelen is ignored.  The rank is recomputed from
        # res_shape and promoted to a JIT constant (the loops are marked
        # unroll_safe).
        shapelen = jit.promote(len(self.res_shape))
        offset = self.offset
        indices = [0] * shapelen
        for i in range(shapelen):
            indices[i] = self.indices[i]
        done = False
        # Advance the last dimension first; dimensions that are already at
        # their maximum wrap to 0 and carry into the next-outer dimension.
        for i in range(shapelen - 1, -1, -1):
            if indices[i] < self.res_shape[i] - 1:
                indices[i] += 1
                offset += self.strides[i]
                break
            else:
                indices[i] = 0
                offset -= self.backstrides[i]
        else:
            # every dimension wrapped: iteration is complete
            done = True
        res = instantiate(ViewIterator)
        res.offset = offset
        res.indices = indices
        res.strides = self.strides
        res.backstrides = self.backstrides
        res.res_shape = self.res_shape
        res._done = done
        return res

    def apply_transformations(self, arr, transformations):
        """Apply *transformations* in order, then pick the cheapest iterator.

        If both ``arr`` and the transformed iteration space are
        one-dimensional, return an equivalent OneDimIterator built from the
        *transformed* iterator's offset, stride and length.  Otherwise return
        the transformed ViewIterator itself.  Using the untransformed values
        here would silently discard any broadcasting or slicing.

        A zero stride (which broadcasting can presumably produce) is kept on
        the ViewIterator path.  OneDimIterator detects completion by reaching
        an end offset, and a zero step never advances toward it.
        """
        v = BaseIterator.apply_transformations(self, arr, transformations)
        if (len(arr.shape) == 1 and len(v.res_shape) == 1 and
                v.strides[0] != 0):
            return OneDimIterator(v.offset, v.strides[0], v.res_shape[0])
        return v

    def done(self):
        return self._done

class ConstantIterator(BaseIterator):
    """Iterator that never moves (presumably used for scalar operands).

    Advancing and transforming both yield this same iterator.
    """

    def next(self, shapelen):
        return self

    def transform(self, arr, t):
        # BaseIterator.apply_transformations chains `v = v.transform(...)`.
        # Returning None here would turn the iterator into None after the
        # first transformation.  Slicing or broadcasting a constant leaves
        # it unchanged, so return the iterator itself.
        return self

class AxisIterator(BaseIterator):
    """Iterator that walks `shape` but ignores movement along axis `dim`.

    A 0 is inserted into `strides` and `backstrides` at position `dim`, so
    advancing or wrapping along that axis leaves `offset` unchanged.
    NOTE(review): presumably `strides`/`backstrides` belong to an array that
    lacks axis `dim` (e.g. the result of a reduction along `dim`) while
    `shape` is the full shape being iterated -- confirm against callers.

    `first_line` is True exactly when the current index along `dim` is 0.
    """
    def __init__(self, start, dim, shape, strides, backstrides):
        self.res_shape = shape[:]
        # zero entry at `dim`: moving along that axis does not move offset
        self.strides = strides[:dim] + [0] + strides[dim:]
        self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
        self.first_line = True
        self.indices = [0] * len(shape)
        self._done = False
        self.offset = start
        self.dim = dim

    @jit.unroll_safe
    def next(self, shapelen):
        # Return a new AxisIterator advanced by one position (last axis
        # fastest); self is not modified.
        # NOTE(review): unlike ViewIterator.next, this trusts the caller's
        # shapelen, which must equal len(self.res_shape).
        offset = self.offset
        first_line = self.first_line
        indices = [0] * shapelen
        for i in range(shapelen):
            indices[i] = self.indices[i]
        done = False
        for i in range(shapelen - 1, -1, -1):
            if indices[i] < self.res_shape[i] - 1:
                # leaving index 0 along `dim` ends the first line
                if i == self.dim:
                    first_line = False
                indices[i] += 1
                offset += self.strides[i]
                break
            else:
                # axis i wraps to 0; wrapping `dim` starts a new first line
                if i == self.dim:
                    first_line = True
                indices[i] = 0
                offset -= self.backstrides[i]
        else:
            # every axis wrapped: iteration is complete
            done = True
        res = instantiate(AxisIterator)
        res.offset = offset
        res.indices = indices
        res.strides = self.strides
        res.backstrides = self.backstrides
        res.res_shape = self.res_shape
        res._done = done
        res.first_line = first_line
        res.dim = self.dim
        return res        

    def done(self):
        return self._done

# ------ other iterators that are not part of the computation frame ----------

class ChunkIterator(object):
    """Walk the index space of a chunked (fancy) indexing operation.

    ``indices`` is the current position within ``shape`` and
    ``chunk_iters[i]`` is the iterator obtained from ``chunks[i]``, kept in
    step with ``indices[i]``.  Unlike the computation-frame iterators above,
    this one is mutated in place and ``next`` returns ``self``.
    """

    def __init__(self, shape, chunks):
        self.chunks = chunks
        self.indices = [0] * len(shape)
        self.shape = shape
        self.chunk_iters = [chunk.get_iter() for chunk in self.chunks]

    def next(self, shapelen):
        """Advance one position in row-major order and return ``self``.

        When dimension ``i`` wraps back to 0, its chunk iterator is
        restarted too, so it keeps pointing at the element matching
        ``indices[i]`` instead of running past the end of its data.
        NOTE(review): after the very last position every dimension wraps
        and the iterator silently starts over; there is no done flag.
        """
        for i in range(shapelen - 1, -1, -1):
            if self.indices[i] < self.shape[i] - 1:
                self.indices[i] += 1
                self.chunk_iters[i] = self.chunk_iters[i].next(shapelen)
                break
            else:
                self.indices[i] = 0
                self.chunk_iters[i] = self.chunks[i].get_iter()
        return self

    def get_index(self, shapelen):
        """Return the per-dimension indices at the current position."""
        return [self.chunks[i].get_index(self.chunk_iters[i])
                for i in range(shapelen)]
    
class SkipLastAxisIterator(object):
    """Visit every combination of indices except along ``arr``'s last axis.

    ``offset`` is the storage offset of the element at
    (indices..., 0), starting from ``arr.start``.  ``done`` becomes True
    once every combination of the leading indices has been visited.  The
    iterator is mutated in place by ``next``.
    """

    def __init__(self, arr):
        self.arr = arr
        self.offset = arr.start
        self.done = False
        self.indices = [0] * (len(arr.shape) - 1)

    def next(self):
        shape = self.arr.shape
        dim = len(shape) - 2
        while dim >= 0:
            if self.indices[dim] < shape[dim] - 1:
                self.indices[dim] += 1
                self.offset += self.arr.strides[dim]
                return
            # this axis is exhausted: rewind it and carry into the outer one
            self.indices[dim] = 0
            self.offset -= self.arr.backstrides[dim]
            dim -= 1
        self.done = True
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.