Commits

Brian Kearns  committed 9f3d775

fix validation of missing r/w op_flag to nditer

  • Participants
  • Parent commits 17a4ad7

Comments (0)

Files changed (2)

File pypy/module/micronumpy/nditer.py

 
 class OpFlag(object):
     def __init__(self):
-        self.rw = 'r'
+        self.rw = ''
         self.broadcast = True
         self.force_contig = False
         self.force_align = False
         else:
             raise OperationError(space.w_ValueError, space.wrap(
                 'op_flags must be a tuple or array of per-op flag-tuples'))
-        if op_flag.rw == 'r':
+        if op_flag.rw == '':
+            raise oefmt(space.w_ValueError,
+                        "None of the iterator flags READWRITE, READONLY, or "
+                        "WRITEONLY were specified for an operand")
+        elif op_flag.rw == 'r':
             op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
         elif op_flag.rw == 'rw':
             op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)

File pypy/module/micronumpy/test/test_nditer.py

         a = arange(6).reshape(2,3) - 3
         exc = raises(TypeError, nditer, a, op_dtypes=['complex'])
         assert str(exc.value).startswith("Iterator operand required copying or buffering")
+        exc = raises(ValueError, nditer, a, op_flags=['copy'], op_dtypes=['complex128'])
+        assert str(exc.value) == "None of the iterator flags READWRITE, READONLY, or WRITEONLY were specified for an operand"
         r = []
         for x in nditer(a, op_flags=['readonly','copy'],
                         op_dtypes=['complex128']):
             r.append(sqrt(x))
         assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,
-                1+0j, 1.41421356237+0j]).sum()) < 1e-5
-        r = []
-        for x in nditer(a, op_flags=['copy'],
-                        op_dtypes=['complex128']):
-            r.append(sqrt(x))
-        assert abs((array(r) - [1.73205080757j, 1.41421356237j, 1j, 0j,
-                            1+0j, 1.41421356237+0j]).sum()) < 1e-5
+                                1+0j, 1.41421356237+0j]).sum()) < 1e-5
         multi = nditer([None, array([2, 3], dtype='int64'), array(2., dtype='double')],
-                       op_dtypes = ['int64', 'int64', 'float64'],
-                       op_flags = [['writeonly', 'allocate'], ['readonly'], ['readonly']])
+                       op_dtypes=['int64', 'int64', 'float64'],
+                       op_flags=[['writeonly', 'allocate'], ['readonly'], ['readonly']])
         for a, b, c in multi:
             a[...] = b * c
         assert (multi.operands[0] == [4, 6]).all()