Commits

Brian Kearns committed 4d472c7

improve the numpypy concatenate test

Comments (0)

Files changed (1)

pypy/module/micronumpy/test/test_numarray.py

-
-import py, sys
+import py
+import sys
 
 from pypy.conftest import option
 from pypy.module.micronumpy.appbridge import get_appbridge_cache
 from pypy.module.micronumpy.interp_numarray import W_NDimArray
 from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
 
+
 class MockDtype(object):
     class itemtype(object):
         @staticmethod
 def create_slice(a, chunks):
     return Chunks(chunks).apply(W_NDimArray(a)).implementation
 
+
 def create_array(*args, **kwargs):
     return W_NDimArray.from_shape(*args, **kwargs).implementation
 
+
 class TestNumArrayDirect(object):
     def newslice(self, *args):
         return self.space.newslice(*[self.space.wrap(arg) for arg in args])
         assert d.shape == (3, 3)
         assert d.dtype == dtype('int32')
         assert (d == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]).all()
-   
+
     def test_eye(self):
         from _numpypy import eye
         from _numpypy import int32, dtype
         assert g.shape == (3, 4)
         assert (g == [[0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0]]).all()
 
-
-
-
     def test_prod(self):
         from _numpypy import array
         a = array(range(1, 6))
         assert len(a) == 6
         assert (a == [0,1,2,3,4,5]).all()
         assert a.dtype is dtype(int)
+        a = concatenate((a1, a2), axis=1)
+        assert (a == [0,1,2,3,4,5]).all()
+        a = concatenate((a1, a2), axis=-1)
+        assert (a == [0,1,2,3,4,5]).all()
+
         b1 = array([[1, 2], [3, 4]])
         b2 = array([[5, 6]])
         b = concatenate((b1, b2), axis=0)
         f = concatenate((f1, [2], f1, [7]))
         assert (f == [0,1,2,0,1,7]).all()
 
-        bad_axis = raises(IndexError, concatenate, (a1,a2), axis=1)
-        assert str(bad_axis.value) == "axis 1 out of bounds [0, 1)"
+        g1 = array([[0,1,2]])
+        g2 = array([[3,4,5]])
+        g = concatenate((g1, g2), axis=-2)
+        assert (g == [[0,1,2],[3,4,5]]).all()
+        exc = raises(IndexError, concatenate, (g1, g2), axis=2)
+        assert str(exc.value) == "axis 2 out of bounds [0, 2)"
+        exc = raises(IndexError, concatenate, (g1, g2), axis=-3)
+        assert str(exc.value) == "axis -3 out of bounds [0, 2)"
 
-        concat_zero = raises(ValueError, concatenate, ())
-        assert str(concat_zero.value) == \
-            "need at least one array to concatenate"
+        exc = raises(ValueError, concatenate, ())
+        assert str(exc.value) == \
+                "need at least one array to concatenate"
 
-        dims_disagree = raises(ValueError, concatenate, (a1, b1), axis=0)
-        assert str(dims_disagree.value) == \
-            "all the input arrays must have same number of dimensions"
+        exc = raises(ValueError, concatenate, (a1, b1), axis=0)
+        assert str(exc.value) == \
+                "all the input arrays must have same number of dimensions"
+
+        g1 = array([0,1,2])
+        g2 = array([[3,4,5]])
+        exc = raises(ValueError, concatenate, (g1, g2), axis=2)
+        assert str(exc.value) == \
+                "all the input arrays must have same number of dimensions"
+
         a = array([1, 2, 3, 4, 5, 6])
         a = (a + a)[::2]
         b = concatenate((a[:3], a[-3:]))
                                            [[3, 9], [6, 12]]])).all() 
         assert (x.swapaxes(1, 2) == array([[[1, 4], [2, 5], [3, 6]], 
                                            [[7, 10], [8, 11],[9, 12]]])).all() 
-        
+
         # test slice
         assert (x[0:1,0:2].swapaxes(0,2) == array([[[1], [4]], [[2], [5]], 
                                                    [[3], [6]]])).all()
                             False, False, True, False, False, False]).all()
         assert ((b > range(12)) == [False, True, True,False, True, True,
                             False, False, True, False, False, False]).all()
+
     def test_flatiter_view(self):
         from _numpypy import arange
         a = arange(10).reshape(5, 2)
         a = arange(12).reshape(2, 3, 2)
         assert (a.diagonal(0, 0, 1) == [[0, 8], [1, 9]]).all()
         assert a.diagonal(3, 0, 1).shape == (2, 0)
-        assert (a.diagonal(1, 0, 1) == [[2, 10], [3, 11]]).all()         
-        assert (a.diagonal(0, 2, 1) == [[0, 3], [6, 9]]).all()        
-        assert (a.diagonal(2, 2, 1) == [[4], [10]]).all()        
-        assert (a.diagonal(1, 2, 1) == [[2, 5], [8, 11]]).all()        
+        assert (a.diagonal(1, 0, 1) == [[2, 10], [3, 11]]).all()
+        assert (a.diagonal(0, 2, 1) == [[0, 3], [6, 9]]).all()
+        assert (a.diagonal(2, 2, 1) == [[4], [10]]).all()
+        assert (a.diagonal(1, 2, 1) == [[2, 5], [8, 11]]).all()
 
     def test_diagonal_axis_neg_ofs(self):
         from _numpypy import arange
         assert (a.diagonal(-1, 0, 1) == [[6], [7]]).all()
         assert a.diagonal(-2, 0, 1).shape == (2, 0)
 
+
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
         import struct
         assert (a.argsort(axis=0) == [[1, 0, 0], [0, 1, 1]]).all()
         assert (a.argsort(axis=1) == [[2, 1, 0], [0, 1, 2]]).all()
 
+
 class AppTestRanges(BaseNumpyAppTest):
     def test_arange(self):
         from _numpypy import arange, dtype
         cache.w_array_repr = cls.old_array_repr
         cache.w_array_str = cls.old_array_str
 
+
 class AppTestRecordDtype(BaseNumpyAppTest):
     def test_zeros(self):
         from _numpypy import zeros, integer
         assert x.__pypy_data__ is obj
         del x.__pypy_data__
         assert x.__pypy_data__ is None
-