Source

pypy / pypy / module / micronumpy / signature.py

Full commit
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
from pypy.rlib.objectmodel import r_dict, compute_identity_hash, compute_hash
from pypy.rlib.rarithmetic import intmask
from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
     ConstantIterator, AxisIterator, ViewTransform,\
     BroadcastTransform
from pypy.rlib.jit import hint, unroll_safe, promote

""" Signature specifies both the numpy expression that has been constructed
and the assembler to be compiled. This is a very important observation -
Two expressions will be using the same assembler if and only if they are
compiled to the same signature.

This is also a very convenient tool for specializations. For example
a + a and a + b (where a != b) will compile to different assembler because
we specialize on the same array access.

When evaluating, signatures will create iterators per signature node,
potentially sharing some of them. Iterators depend also on the actual
expression, they're not only dependent on the array itself. For example
a + b where a is dim 2 and b is dim 1 would create a broadcasted iterator for
the array b.

Such iterator changes are called Transformations. An actual iterator would
be a combination of array and various transformation, like view, broadcast,
dimension swapping etc.

See interp_iter for transformations
"""

def new_printable_location(driver_name):
    """Return a ``get_printable_location(shapelen, sig)`` callback.

    The callback renders the signature's ``debug_repr()`` followed by the
    number of dimensions and the given *driver_name*.
    """
    def get_printable_location(shapelen, sig):
        described = 'numpy ' + sig.debug_repr()
        return described + ' [%d dims,%s]' % (shapelen, driver_name)
    return get_printable_location

def sigeq(one, two):
    """Equality callback for signature-keyed dicts: defers to ``one.eq(two)``."""
    same = one.eq(two)
    return same

def sigeq_no_numbering(one, two):
    """Equality callback that ignores array numbers.

    Used for the iterator-numbering cache, which should not compare array
    numbers: it calls ``one.eq(two, compare_array_no=False)``.
    """
    same = one.eq(two, compare_array_no=False)
    return same

def sighash(sig):
    """Hash callback for signature-keyed dicts: defers to ``sig.hash()``."""
    value = sig.hash()
    return value

# Module-level table of signatures, keyed by the signature itself and using
# sigeq/sighash for comparison and hashing; find_sig() reads and fills it.
known_sigs = r_dict(sigeq, sighash)

def find_sig(sig, arr):
    """Return the canonical signature equal to *sig*, registering it if new.

    Array numbers are always (re)computed on *sig* from *arr* first, because
    they take part in the equality used by ``known_sigs``. Iterator numbers
    are only invented when *sig* is seen for the first time.
    """
    sig.invent_array_numbering(arr)
    try:
        canonical = known_sigs[sig]
    except KeyError:
        # first time this signature shows up: number it and remember it
        sig.invent_numbering()
        known_sigs[sig] = sig
        canonical = sig
    return canonical

class NumpyEvalFrame(object):
    """Evaluation state for one signature: its iterators and array storages.

    ``final_iter`` is the index of the first iterator that is not a
    ConstantIterator, or -1 when there is none; done() and get_final_iter()
    use the iterator at that index.
    """
    # NOTE(review): this list names 'arraylist[*]', 'value' and 'identity',
    # while __init__ below stores the arrays as ``self.arrays`` and does not
    # set ``value``/``identity`` -- confirm against the users of the frame.
    _virtualizable2_ = ['iterators[*]', 'final_iter', 'arraylist[*]',
                        'value', 'identity']

    @unroll_safe
    def __init__(self, iterators, arrays):
        """Copy both lists and locate the first non-constant iterator."""
        self = hint(self, access_directly=True, fresh_virtualizable=True)
        self.iterators = iterators[:]
        self.arrays = arrays[:]
        for i in range(len(self.iterators)):
            iter = self.iterators[i]
            if not isinstance(iter, ConstantIterator):
                self.final_iter = i
                break
        else:
            # every iterator is a ConstantIterator (or the list is empty)
            self.final_iter = -1

    def done(self):
        """Return ``done()`` of the iterator selected by ``final_iter``."""
        final_iter = promote(self.final_iter)
        if final_iter < 0:
            # no non-constant iterator was found in __init__
            assert False
        return self.iterators[final_iter].done()

    @unroll_safe
    def next(self, shapelen):
        """Replace every iterator by the result of its ``next(shapelen)``."""
        for i in range(len(self.iterators)):
            self.iterators[i] = self.iterators[i].next(shapelen)

    def get_final_iter(self):
        """Return the iterator selected by ``final_iter``."""
        final_iter = promote(self.final_iter)
        if final_iter < 0:
            # no non-constant iterator was found in __init__
            assert False
        return self.iterators[final_iter]

def _add_ptr_to_cache(ptr, cache):
    i = 0
    for p in cache:
        if ptr == p:
            return i
        i += 1
    else:
        res = len(cache)
        cache.append(ptr)
        return res

def new_cache():
    """Create an empty signature-keyed dict that ignores array numbers.

    Keys are compared with sigeq_no_numbering and hashed with sighash.
    """
    cache = r_dict(sigeq_no_numbering, sighash)
    return cache

class Signature(object):
    """Base class of the signature nodes.

    Each node carries an ``iter_no`` (index into the frame's iterator list)
    and an ``array_no`` (index into the frame's array list); both default
    to 0 and are filled in by the numbering methods.
    """
    _attrs_ = ['iter_no', 'array_no']
    _immutable_fields_ = ['iter_no', 'array_no']

    array_no = 0
    iter_no = 0

    def invent_numbering(self):
        """Assign iterator numbers, starting from an empty cache and list."""
        fresh_cache = new_cache()
        numbers = []
        self._invent_numbering(fresh_cache, numbers)

    def invent_array_numbering(self, arr):
        """Assign array numbers for *arr*, starting from an empty cache."""
        seen = []
        self._invent_array_numbering(arr, seen)

    def _invent_numbering(self, cache, allnumbers):
        """Reuse the number cached for an equal signature, else take a new one.

        A new number is ``len(allnumbers)``; it is recorded both in *cache*
        and in *allnumbers*.
        """
        try:
            number = cache[self]
        except KeyError:
            number = len(allnumbers)
            allnumbers.append(number)
            cache[self] = number
        self.iter_no = number

    def create_frame(self, arr):
        """Build a NumpyEvalFrame holding the iterators and arrays for *arr*."""
        iterators = []
        storages = []
        self._create_iter(iterators, storages, arr, [])
        return NumpyEvalFrame(iterators, storages)


class ConcreteSignature(Signature):
    """Signature leaf identified by its dtype (and optionally its array number)."""
    _immutable_fields_ = ['dtype']

    def __init__(self, dtype):
        self.dtype = dtype

    def eq(self, other, compare_array_no=True):
        """Return True when *other* has the same exact class and dtype.

        With ``compare_array_no`` true, the ``array_no`` of both signatures
        must match as well.
        """
        if type(self) is not type(other):
            return False
        assert isinstance(other, ConcreteSignature)
        if compare_array_no and self.array_no != other.array_no:
            return False
        return self.dtype is other.dtype

    def hash(self):
        """Hash on the identity of the dtype only."""
        dtype_hash = compute_identity_hash(self.dtype)
        return dtype_hash

class ArraySignature(ConcreteSignature):
    """Signature leaf for an operand backed by a ConcreteArray."""

    def debug_repr(self):
        return 'Array'

    def _invent_array_numbering(self, arr, cache):
        """Number this array by the position of its storage in *cache*."""
        from pypy.module.micronumpy.interp_numarray import ConcreteArray
        # this get_concrete never forces assembler. If we're here and array
        # is not of a concrete class it means that we have a _forced_result,
        # otherwise the signature would not match
        concrete = arr.get_concrete()
        assert isinstance(concrete, ConcreteArray)
        assert concrete.dtype is self.dtype
        self.array_no = _add_ptr_to_cache(concrete.storage, cache)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        """Append this node's iterator and storage unless already present.

        "Already present" means the list is long enough to contain the
        index given by ``iter_no`` (respectively ``array_no``).
        """
        from pypy.module.micronumpy.interp_numarray import ConcreteArray
        concrete = arr.get_concrete()
        assert isinstance(concrete, ConcreteArray)
        backing = concrete.storage
        if self.iter_no >= len(iterlist):
            new_iter = self.allocate_iter(concrete, transforms)
            iterlist.append(new_iter)
        if self.array_no >= len(arraylist):
            arraylist.append(backing)

    def allocate_iter(self, arr, transforms):
        """Create an ArrayIterator over ``arr.size`` with *transforms* applied."""
        base_iter = ArrayIterator(arr.size)
        return base_iter.apply_transformations(arr, transforms)

    def eval(self, frame, arr):
        """Read the current item from the frame's storage at the iterator offset."""
        current = frame.iterators[self.iter_no]
        backing = frame.arrays[self.array_no]
        return self.dtype.getitem(backing, current.offset)

class ScalarSignature(ConcreteSignature):
    """Signature leaf for a scalar operand; it is iterated by a ConstantIterator."""
    # NOTE(review): no eval() is defined on this class in the code visible
    # here, although Call1/Call2.eval call eval() on their children --
    # confirm where scalar evaluation is handled.

    def debug_repr(self):
        return 'Scalar'

    def _invent_array_numbering(self, arr, cache):
        # scalars do not take part in array numbering
        pass

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        """Append a ConstantIterator unless ``iter_no`` is already covered."""
        if self.iter_no < len(iterlist):
            return
        iterlist.append(ConstantIterator())

class ViewSignature(ArraySignature):
    """Array signature for a view; it never shares an iterator number."""

    def debug_repr(self):
        return 'Slice'

    def _invent_numbering(self, cache, allnumbers):
        """Always take a fresh iterator number; *cache* is not consulted."""
        fresh = len(allnumbers)
        self.iter_no = fresh
        allnumbers.append(fresh)

    def allocate_iter(self, arr, transforms):
        """Create a ViewIterator from the view's layout with *transforms* applied."""
        view_iter = ViewIterator(arr.start, arr.strides, arr.backstrides,
                                 arr.shape)
        return view_iter.apply_transformations(arr, transforms)

class VirtualSliceSignature(Signature):
    """Signature node for a VirtualSlice wrapping a single child signature.

    Array numbering, hashing, equality and evaluation are delegated to the
    child. Iterator creation additionally pushes a ViewTransform built from
    the slice's chunks onto the transformation list.
    """

    def __init__(self, child):
        self.child = child

    def debug_repr(self):
        """Return a printable description that wraps the child's description.

        Every other signature node defines debug_repr(), and both the
        enclosing Call1/Call2 nodes and the printable-location helper call
        it on arbitrary nodes, so this node has to provide it as well.
        """
        return 'VirtualSlice(%s)' % self.child.debug_repr()

    def _invent_array_numbering(self, arr, cache):
        """Number the arrays of the child against the slice's own child array."""
        from pypy.module.micronumpy.interp_numarray import VirtualSlice
        assert isinstance(arr, VirtualSlice)
        self.child._invent_array_numbering(arr.child, cache)

    def _invent_numbering(self, cache, allnumbers):
        # The child is numbered against a fresh cache, so it does not reuse
        # iterator numbers recorded in the caller's cache.
        self.child._invent_numbering(new_cache(), allnumbers)

    def hash(self):
        """Return the child's hash xor-ed with a constant."""
        return intmask(self.child.hash() ^ 1234)

    def eq(self, other, compare_array_no=True):
        """Equal when *other* has the same exact class and an equal child."""
        if type(self) is not type(other):
            return False
        assert isinstance(other, VirtualSliceSignature)
        return self.child.eq(other.child, compare_array_no)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        """Create the child's iterators with a ViewTransform appended."""
        from pypy.module.micronumpy.interp_numarray import VirtualSlice
        assert isinstance(arr, VirtualSlice)
        # build a new list so the caller's transforms are left untouched
        transforms = transforms + [ViewTransform(arr.chunks)]
        self.child._create_iter(iterlist, arraylist, arr.child, transforms)

    def eval(self, frame, arr):
        """Evaluate the child signature against the slice's child array."""
        from pypy.module.micronumpy.interp_numarray import VirtualSlice
        assert isinstance(arr, VirtualSlice)
        return self.child.eval(frame, arr.child)

class Call1(Signature):
    """Signature node for a unary function applied to one child signature."""
    _immutable_fields_ = ['unfunc', 'name', 'child']

    def __init__(self, func, name, child):
        self.unfunc = func
        self.name = name
        self.child = child

    def hash(self):
        """Combine the hash of the name with the shifted hash of the child."""
        name_part = compute_hash(self.name)
        child_part = intmask(self.child.hash() << 1)
        return name_part ^ child_part

    def eq(self, other, compare_array_no=True):
        """Equal when the class and function are the same and children are equal."""
        if type(self) is not type(other):
            return False
        assert isinstance(other, Call1)
        if self.unfunc is not other.unfunc:
            return False
        return self.child.eq(other.child, compare_array_no)

    def debug_repr(self):
        inner = self.child.debug_repr()
        return 'Call1(%s, %s)' % (self.name, inner)

    def _invent_numbering(self, cache, allnumbers):
        # a unary call has no iterator of its own; number the child only
        self.child._invent_numbering(cache, allnumbers)

    def _invent_array_numbering(self, arr, cache):
        from pypy.module.micronumpy.interp_numarray import Call1
        assert isinstance(arr, Call1)
        operand = arr.values
        self.child._invent_array_numbering(operand, cache)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import Call1
        assert isinstance(arr, Call1)
        operand = arr.values
        self.child._create_iter(iterlist, arraylist, operand, transforms)

    def eval(self, frame, arr):
        """Evaluate the child, convert to ``arr.res_dtype`` and apply the function."""
        from pypy.module.micronumpy.interp_numarray import Call1
        assert isinstance(arr, Call1)
        raw = self.child.eval(frame, arr.values)
        converted = raw.convert_to(arr.res_dtype)
        return self.unfunc(arr.res_dtype, converted)

class Call2(Signature):
    """Signature node for a binary function applied to two child signatures."""
    _immutable_fields_ = ['binfunc', 'name', 'calc_dtype', 'left', 'right']

    def __init__(self, func, name, calc_dtype, left, right):
        self.binfunc = func
        self.name = name
        self.calc_dtype = calc_dtype
        self.left = left
        self.right = right

    def hash(self):
        """Combine the name hash with differently shifted child hashes."""
        name_part = compute_hash(self.name)
        left_part = intmask(self.left.hash() << 1)
        right_part = intmask(self.right.hash() << 2)
        return name_part ^ left_part ^ right_part

    def eq(self, other, compare_array_no=True):
        """Equal when class, function and calc_dtype match and children are equal."""
        if type(self) is not type(other):
            return False
        assert isinstance(other, Call2)
        if self.binfunc is not other.binfunc:
            return False
        if self.calc_dtype is not other.calc_dtype:
            return False
        if not self.left.eq(other.left, compare_array_no):
            return False
        return self.right.eq(other.right, compare_array_no)

    def _invent_array_numbering(self, arr, cache):
        from pypy.module.micronumpy.interp_numarray import Call2
        assert isinstance(arr, Call2)
        # left operand first, then right, sharing the same pointer cache
        self.left._invent_array_numbering(arr.left, cache)
        self.right._invent_array_numbering(arr.right, cache)

    def _invent_numbering(self, cache, allnumbers):
        # both children share the caller's cache and number list
        for side in (self.left, self.right):
            side._invent_numbering(cache, allnumbers)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import Call2

        assert isinstance(arr, Call2)
        # both operands receive the transformation list unchanged
        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
        self.right._create_iter(iterlist, arraylist, arr.right, transforms)

    def eval(self, frame, arr):
        """Evaluate both children, convert to ``calc_dtype`` and apply the function."""
        from pypy.module.micronumpy.interp_numarray import Call2
        assert isinstance(arr, Call2)
        target = self.calc_dtype
        lhs = self.left.eval(frame, arr.left).convert_to(target)
        rhs = self.right.eval(frame, arr.right).convert_to(target)
        return self.binfunc(target, lhs, rhs)

    def debug_repr(self):
        left_repr = self.left.debug_repr()
        right_repr = self.right.debug_repr()
        return 'Call2(%s, %s, %s)' % (self.name, left_repr, right_repr)

class BroadcastLeft(Call2):
    """Binary call whose left operand is broadcast to the shape of the result."""

    def _invent_numbering(self, cache, allnumbers):
        # the broadcast side is numbered against its own fresh cache
        left_cache = new_cache()
        self.left._invent_numbering(left_cache, allnumbers)
        self.right._invent_numbering(cache, allnumbers)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import Call2

        assert isinstance(arr, Call2)
        broadcast = BroadcastTransform(arr.shape)
        left_transforms = transforms + [broadcast]
        self.left._create_iter(iterlist, arraylist, arr.left, left_transforms)
        self.right._create_iter(iterlist, arraylist, arr.right, transforms)

class BroadcastRight(Call2):
    """Binary call whose right operand is broadcast to the shape of the result."""

    def _invent_numbering(self, cache, allnumbers):
        self.left._invent_numbering(cache, allnumbers)
        # the broadcast side is numbered against its own fresh cache
        right_cache = new_cache()
        self.right._invent_numbering(right_cache, allnumbers)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import Call2

        assert isinstance(arr, Call2)
        broadcast = BroadcastTransform(arr.shape)
        right_transforms = transforms + [broadcast]
        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
        self.right._create_iter(iterlist, arraylist, arr.right,
                                right_transforms)

class BroadcastBoth(Call2):
    """Binary call where both operands are broadcast to the shape of the result."""

    def _invent_numbering(self, cache, allnumbers):
        # each side gets its own fresh cache; the caller's cache is unused
        left_cache = new_cache()
        self.left._invent_numbering(left_cache, allnumbers)
        right_cache = new_cache()
        self.right._invent_numbering(right_cache, allnumbers)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import Call2

        assert isinstance(arr, Call2)
        # each side receives its own BroadcastTransform instance
        right_transforms = transforms + [BroadcastTransform(arr.shape)]
        left_transforms = transforms + [BroadcastTransform(arr.shape)]
        self.left._create_iter(iterlist, arraylist, arr.left, left_transforms)
        self.right._create_iter(iterlist, arraylist, arr.right,
                                right_transforms)

class ReduceSignature(Call2):
    """Call2 variant that forwards everything to its right child.

    The array handed to each method is passed on to the right child as is.
    """

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        operand = self.right
        operand._create_iter(iterlist, arraylist, arr, transforms)

    def _invent_numbering(self, cache, allnumbers):
        operand = self.right
        operand._invent_numbering(cache, allnumbers)

    def _invent_array_numbering(self, arr, cache):
        operand = self.right
        operand._invent_array_numbering(arr, cache)

    def eval(self, frame, arr):
        operand = self.right
        return operand.eval(frame, arr)

    def debug_repr(self):
        inner = self.right.debug_repr()
        return 'ReduceSig(%s, %s)' % (self.name, inner)

class SliceloopSignature(Call2):
    """Call2 variant whose eval stores the right-hand value into ``arr.left``."""

    def eval(self, frame, arr):
        """Write the converted right-hand value at the offset of iterator 0."""
        from pypy.module.micronumpy.interp_numarray import Call2

        assert isinstance(arr, Call2)
        target_offset = frame.iterators[0].offset
        computed = self.right.eval(frame, arr.right)
        arr.left.setitem(target_offset, computed.convert_to(self.calc_dtype))

    def debug_repr(self):
        left_repr = self.left.debug_repr()
        right_repr = self.right.debug_repr()
        return 'SliceLoop(%s, %s, %s)' % (self.name, left_repr, right_repr)

class SliceloopBroadcastSignature(SliceloopSignature):
    """Slice loop whose right-hand side is broadcast to the shape of the target."""

    def _invent_numbering(self, cache, allnumbers):
        # the left side is numbered against a fresh cache, the right side
        # against the caller's cache
        left_cache = new_cache()
        self.left._invent_numbering(left_cache, allnumbers)
        self.right._invent_numbering(cache, allnumbers)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import SliceArray

        assert isinstance(arr, SliceArray)
        broadcast = BroadcastTransform(arr.shape)
        right_transforms = transforms + [broadcast]
        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
        self.right._create_iter(iterlist, arraylist, arr.right,
                                right_transforms)

class AxisReduceSignature(Call2):
    """Call2 variant for an AxisReduce: an AxisIterator plus the right child."""

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        """Append an AxisIterator over ``arr.left``, then the right child's iterators."""
        from pypy.module.micronumpy.interp_numarray import AxisReduce,\
             ConcreteArray

        assert isinstance(arr, AxisReduce)
        target = arr.left
        assert isinstance(target, ConcreteArray)
        axis_iter = AxisIterator(target.start, arr.dim, arr.shape,
                                 target.strides, target.backstrides)
        iterlist.append(axis_iter)
        self.right._create_iter(iterlist, arraylist, arr.right, transforms)

    def _invent_numbering(self, cache, allnumbers):
        # one entry is appended before the right child is numbered; it
        # matches the AxisIterator appended by _create_iter
        allnumbers.append(0)
        self.right._invent_numbering(cache, allnumbers)

    def _invent_array_numbering(self, arr, cache):
        from pypy.module.micronumpy.interp_numarray import AxisReduce

        assert isinstance(arr, AxisReduce)
        operand = arr.right
        self.right._invent_array_numbering(operand, cache)

    def eval(self, frame, arr):
        """Evaluate the right child and convert the result to ``calc_dtype``."""
        from pypy.module.micronumpy.interp_numarray import AxisReduce

        assert isinstance(arr, AxisReduce)
        computed = self.right.eval(frame, arr.right)
        return computed.convert_to(self.calc_dtype)

    def debug_repr(self):
        inner = self.right.debug_repr()
        return 'AxisReduceSig(%s, %s)' % (self.name, inner)