Alex Gaynor avatar Alex Gaynor committed 923df65

issue #1005 -- accept 'float' as a dtype

Comments (0)

Files changed (2)

pypy/module/micronumpy/interp_dtype.py

 class W_Dtype(Wrappable):
     _immutable_fields_ = ["itemtype", "num", "kind"]
 
-    def __init__(self, itemtype, num, kind, name, char, w_box_type, alternate_constructors=[]):
+    def __init__(self, itemtype, num, kind, name, char, w_box_type, alternate_constructors=[], aliases=[]):
         self.itemtype = itemtype
         self.num = num
         self.kind = kind
         self.char = char
         self.w_box_type = w_box_type
         self.alternate_constructors = alternate_constructors
+        self.aliases = aliases
 
     def malloc(self, length):
         # XXX find out why test_zjit explodes with tracking of allocations
         elif space.isinstance_w(w_dtype, space.w_str):
             name = space.str_w(w_dtype)
             for dtype in cache.builtin_dtypes:
-                if dtype.name == name or dtype.char == name:
+                if dtype.name == name or dtype.char == name or name in dtype.aliases:
                     return dtype
         else:
             for dtype in cache.builtin_dtypes:
             kind=BOOLLTR,
             name="bool",
             char="?",
-            w_box_type = space.gettypefor(interp_boxes.W_BoolBox),
+            w_box_type=space.gettypefor(interp_boxes.W_BoolBox),
             alternate_constructors=[space.w_bool],
         )
         self.w_int8dtype = W_Dtype(
             kind=SIGNEDLTR,
             name="int8",
             char="b",
-            w_box_type = space.gettypefor(interp_boxes.W_Int8Box)
+            w_box_type=space.gettypefor(interp_boxes.W_Int8Box)
         )
         self.w_uint8dtype = W_Dtype(
             types.UInt8(),
             kind=UNSIGNEDLTR,
             name="uint8",
             char="B",
-            w_box_type = space.gettypefor(interp_boxes.W_UInt8Box),
+            w_box_type=space.gettypefor(interp_boxes.W_UInt8Box),
         )
         self.w_int16dtype = W_Dtype(
             types.Int16(),
             kind=SIGNEDLTR,
             name="int16",
             char="h",
-            w_box_type = space.gettypefor(interp_boxes.W_Int16Box),
+            w_box_type=space.gettypefor(interp_boxes.W_Int16Box),
         )
         self.w_uint16dtype = W_Dtype(
             types.UInt16(),
             kind=UNSIGNEDLTR,
             name="uint16",
             char="H",
-            w_box_type = space.gettypefor(interp_boxes.W_UInt16Box),
+            w_box_type=space.gettypefor(interp_boxes.W_UInt16Box),
         )
         self.w_int32dtype = W_Dtype(
             types.Int32(),
             kind=SIGNEDLTR,
             name="int32",
             char="i",
-             w_box_type = space.gettypefor(interp_boxes.W_Int32Box),
+            w_box_type=space.gettypefor(interp_boxes.W_Int32Box),
        )
         self.w_uint32dtype = W_Dtype(
             types.UInt32(),
             kind=UNSIGNEDLTR,
             name="uint32",
             char="I",
-            w_box_type = space.gettypefor(interp_boxes.W_UInt32Box),
+            w_box_type=space.gettypefor(interp_boxes.W_UInt32Box),
         )
         if LONG_BIT == 32:
             name = "int32"
             kind=SIGNEDLTR,
             name=name,
             char="l",
-            w_box_type = space.gettypefor(interp_boxes.W_LongBox),
+            w_box_type=space.gettypefor(interp_boxes.W_LongBox),
             alternate_constructors=[space.w_int],
         )
         self.w_ulongdtype = W_Dtype(
             kind=UNSIGNEDLTR,
             name="u" + name,
             char="L",
-            w_box_type = space.gettypefor(interp_boxes.W_ULongBox),
+            w_box_type=space.gettypefor(interp_boxes.W_ULongBox),
         )
         self.w_int64dtype = W_Dtype(
             types.Int64(),
             kind=SIGNEDLTR,
             name="int64",
             char="q",
-            w_box_type = space.gettypefor(interp_boxes.W_Int64Box),
+            w_box_type=space.gettypefor(interp_boxes.W_Int64Box),
             alternate_constructors=[space.w_long],
         )
         self.w_uint64dtype = W_Dtype(
             kind=UNSIGNEDLTR,
             name="uint64",
             char="Q",
-            w_box_type = space.gettypefor(interp_boxes.W_UInt64Box),
+            w_box_type=space.gettypefor(interp_boxes.W_UInt64Box),
         )
         self.w_float32dtype = W_Dtype(
             types.Float32(),
             kind=FLOATINGLTR,
             name="float32",
             char="f",
-            w_box_type = space.gettypefor(interp_boxes.W_Float32Box),
+            w_box_type=space.gettypefor(interp_boxes.W_Float32Box),
         )
         self.w_float64dtype = W_Dtype(
             types.Float64(),
             char="d",
             w_box_type = space.gettypefor(interp_boxes.W_Float64Box),
             alternate_constructors=[space.w_float],
+            aliases=["float"],
         )
 
         self.builtin_dtypes = [

pypy/module/micronumpy/test/test_dtypes.py

         # You can't subclass dtype
         raises(TypeError, type, "Foo", (dtype,), {})
 
-    def test_new(self):
-        import _numpypy as np
-        assert np.int_(4) == 4
-        assert np.float_(3.4) == 3.4
+    def test_aliases(self):
+        from _numpypy import dtype
 
-    def test_pow(self):
-        from _numpypy import int_
-        assert int_(4) ** 2 == 16
+        assert dtype("float") is dtype(float)
+
 
 class AppTestTypes(BaseNumpyAppTest):
     def test_abstract_types(self):
         raises(TypeError, numpy.floating, 0)
         raises(TypeError, numpy.inexact, 0)
 
+    def test_new(self):
+        import _numpypy as np
+        assert np.int_(4) == 4
+        assert np.float_(3.4) == 3.4
+
+    def test_pow(self):
+        from _numpypy import int_
+        assert int_(4) ** 2 == 16
+
     def test_bool(self):
         import _numpypy as numpy
 
         else:
             raises(OverflowError, numpy.int64, 9223372036854775807)
             raises(OverflowError, numpy.int64, '9223372036854775807')
-        
+
         raises(OverflowError, numpy.int64, 9223372036854775808)
         raises(OverflowError, numpy.int64, '9223372036854775808')
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.