Commits

Yichao Yu committed 2aabeb7

make numpy scalar non-iterable

  • Participants
  • Parent commits f20ac16

Comments (0)

Files changed (2)

File pypy/module/micronumpy/boxes.py

         raise OperationError(space.w_IndexError, space.wrap(
             "invalid index to scalar variable"))
 
+    def descr_iter(self, space):
+        # Making numpy scalar non-iterable with a valid __getitem__ method
+        raise oefmt(space.w_TypeError,
+                    "'%T' object is not iterable", self)
+
     def descr_str(self, space):
         return space.wrap(self.get_dtype(space).itemtype.str_format(self))
 
     __new__ = interp2app(W_GenericBox.descr__new__.im_func),
 
     __getitem__ = interp2app(W_GenericBox.descr_getitem),
+    __iter__ = interp2app(W_GenericBox.descr_iter),
     __str__ = interp2app(W_GenericBox.descr_str),
     __repr__ = interp2app(W_GenericBox.descr_str),
     __format__ = interp2app(W_GenericBox.descr_format),

File pypy/module/test_lib_pypy/numpypy/core/test_numeric.py

         assert d.dtype == dtype('int32')
         assert (d == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]).all()
 
-
+    def test_scalar_iter(self):
+        from numpypy import int8, int16, int32, int64, float32, float64
+        for t in int8, int16, int32, int64, float32, float64:
+            try:
+                iter(t(17))
+            except TypeError:
+                pass
+            else:
+                assert False, "%s object should not be iterable." % t