Commits

Roberto De Almeida committed ce48c19

Improved type conversion between numpy and netcdf.

  • Participants
  • Parent commits 3cbede8

Comments (0)

Files changed (1)

 NC_ATTRIBUTE = asbytes('\x00\x00\x00\x0c')
 
 
-TYPEMAP = { NC_BYTE:   ('b', 1),
-            NC_CHAR:   ('c', 1),
-            NC_SHORT:  ('h', 2),
-            NC_INT:    ('i', 4),
-            NC_FLOAT:  ('f', 4),
-            NC_DOUBLE: ('d', 8) }
+TYPEMAP = { NC_BYTE:   dtype(np.byte),
+            NC_CHAR:   dtype('c'),
+            NC_SHORT:  dtype(np.int16).newbyteorder('>'),
+            NC_INT:    dtype(np.int32).newbyteorder('>'),
+            NC_FLOAT:  dtype(np.float32).newbyteorder('>'),
+            NC_DOUBLE: dtype(np.float64).newbyteorder('>'),
+            }
 
-REVERSE = { ('b', 1): NC_BYTE,
-            ('B', 1): NC_CHAR,
-            ('c', 1): NC_CHAR,
-            ('h', 2): NC_SHORT,
-            ('i', 4): NC_INT,
-            ('f', 4): NC_FLOAT,
-            ('d', 8): NC_DOUBLE,
-
-            # these come from asarray(1).dtype.char and asarray('foo').dtype.char,
-            # used when getting the types from generic attributes.
-            ('l', 4): NC_INT,
-            ('S', 1): NC_CHAR }
+REVERSE = { dtype(np.byte):    NC_BYTE,
+            dtype('c'):        NC_CHAR,
+            dtype(np.int16):   NC_SHORT,
+            dtype(np.int32):   NC_INT,
+            dtype(np.int64):   NC_INT,  # will be converted to int32
+            dtype(np.float32): NC_FLOAT,
+            dtype(np.float64): NC_DOUBLE,
+            }
 
 
 class netcdf_file(object):
         shape_ = tuple([dim or 0 for dim in shape])  # replace None with 0 for numpy
 
         if isinstance(type, basestring): type = dtype(type)
-        typecode, size = type.char, type.itemsize
-        if (typecode, size) not in REVERSE:
+        if type not in REVERSE:
             raise ValueError("NetCDF 3 does not support type %s" % type)
-        dtype_ = '>%s' % typecode
-        if size > 1: dtype_ += str(size)
 
-        data = empty(shape_, dtype=dtype_)
-        self.variables[name] = netcdf_variable(data, typecode, size, shape, dimensions)
+        data = empty(shape_, type)
+        self.variables[name] = netcdf_variable(data, type, shape, dimensions)
         return self.variables[name]
 
     def flush(self):
 
         self._write_att_array(var._attributes)
 
-        nc_type = REVERSE[var.typecode(), var.itemsize()]
+        nc_type = REVERSE[var.dtype]
         self.fp.write(asbytes(nc_type))
 
         if not var.isrec:
         self.fp.seek(the_beguine)
 
         # Write data.
+        if (var.data.dtype.byteorder == '<' or 
+                (var.data.dtype.byteorder == '=' and LITTLE_ENDIAN)):
+            var.data = var.data.byteswap()
+
         if not var.isrec:
             self.fp.write(var.data.tostring())
             count = var.data.size * var.data.itemsize
 
             pos0 = pos = self.fp.tell()
             for rec in var.data:
-                # Apparently scalars cannot be converted to big endian. If we
-                # try to convert a ``=i4`` scalar to, say, '>i4' the dtype
-                # will remain as ``=i4``.
-                if not rec.shape and (rec.dtype.byteorder == '<' or
-                        (rec.dtype.byteorder == '=' and LITTLE_ENDIAN)):
-                    rec = rec.byteswap()
                 self.fp.write(rec.tostring())
                 # Padding
                 count = rec.size * rec.itemsize
 
     def _write_values(self, values):
         if hasattr(values, 'dtype'):
-            nc_type = REVERSE[values.dtype.char, values.dtype.itemsize]
+            nc_type = REVERSE[values.dtype]
         else:
             types = [
                     (int, NC_INT),
             for class_, nc_type in types:
                 if isinstance(sample, class_): break
 
-        typecode, size = TYPEMAP[nc_type]
-        dtype_ = '>%s' % typecode
-
-        values = asarray(values, dtype=dtype_)
+        values = asarray(values, TYPEMAP[nc_type])
 
         self.fp.write(asbytes(nc_type))
 
         rec_vars = []
         count = self._unpack_int()
         for var in range(count):
-            (name, dimensions, shape, attributes,
-             typecode, size, dtype_, begin_, vsize) = self._read_var()
+            name, dimensions, shape, attributes, type, begin_, vsize = self._read_var()
             # http://www.unidata.ucar.edu/software/netcdf/docs/netcdf.html
             # Note that vsize is the product of the dimension lengths
             # (omitting the record dimension) and the number of bytes
                 self.__dict__['_recsize'] += vsize
                 if begin == 0: begin = begin_
                 dtypes['names'].append(name)
-                dtypes['formats'].append(str(shape[1:]) + dtype_)
+                dtypes['formats'].append(str(shape[1:]) + type)
 
                 # Handle padding with a virtual variable.
-                if typecode in 'bch':
-                    actual_size = reduce(mul, (1,) + shape[1:]) * size
+                if type.char in 'bch':
+                    actual_size = reduce(mul, (1,) + shape[1:]) * type.itemsize
                     padding = -actual_size % 4
                     if padding:
                         dtypes['names'].append('_padding_%d' % var)
                 data = None
             else:  # not a record variable
                 # Calculate size to avoid problems with vsize (above)
-                a_size = reduce(mul, shape, 1) * size
+                a_size = reduce(mul, shape, 1) * type.itemsize
                 pos = self.fp.tell()
                 if self.use_mmap:
                     mm = mmap(self.fp.fileno(), begin_+a_size, access=ACCESS_READ)
-                    data = ndarray.__new__(ndarray, shape, dtype=dtype_,
+                    data = ndarray.__new__(ndarray, shape, dtype=type,
                             buffer=mm, offset=begin_, order=0)
                 else:
                     self.fp.seek(begin_)
-                    data = fromstring(self.fp.read(a_size), dtype=dtype_)
+                    data = fromstring(self.fp.read(a_size), type)
                     data.shape = shape
                 self.fp.seek(pos)
 
             # Add variable.
-            self.variables[name] = netcdf_variable(
-                    data, typecode, size, shape, dimensions, attributes)
+            self.variables[name] = netcdf_variable(data, type, shape, dimensions, attributes)
 
         if rec_vars:
             # Remove padding when only one record variable.
         nc_type = self.fp.read(4)
         vsize = self._unpack_int()
         begin = [self._unpack_int, self._unpack_int64][self.version_byte-1]()
+        type = TYPEMAP[nc_type]
 
-        typecode, size = TYPEMAP[nc_type]
-        dtype_ = '>%s' % typecode
-
-        return name, dimensions, shape, attributes, typecode, size, dtype_, begin, vsize
+        return name, dimensions, shape, attributes, type, begin, vsize
 
     def _read_values(self):
         nc_type = self.fp.read(4)
         n = self._unpack_int()
 
-        typecode, size = TYPEMAP[nc_type]
+        type = TYPEMAP[nc_type]
 
-        count = n*size
+        count = n*type.itemsize
         values = self.fp.read(int(count))
         self.fp.read(-count % 4)  # read padding
 
-        if typecode is not 'c':
-            values = fromstring(values, dtype='>%s' % typecode)
+        if type.char is not 'c':
+            values = fromstring(values, type)
             if values.shape == (1,): values = values[0]
         else:
             ## text values are encoded via UTF-8, per NetCDF standard
     data : array_like
         The data array that holds the values for the variable.
         Typically, this is initialized as empty, but with the proper shape.
-    typecode : dtype character code
+    type : numpy dtype
         Desired data-type for the data array.
-    size : int
-        Desired element size for the data array.
     shape : sequence of ints
         The shape of the array.  This should match the lengths of the
         variable's dimensions.
     isrec, shape
 
     """
-    def __init__(self, data, typecode, size, shape, dimensions, attributes=None):
+    def __init__(self, data, type, shape, dimensions, attributes=None):
         self.data = data
-        self._typecode = typecode
-        self._size = size
+        self.dtype = type
         self._shape = shape
         self.dimensions = dimensions
 
             The character typecode of the variable (eg, 'i' for int).
 
         """
-        return self._typecode
+        return self.dtype.char
 
     def itemsize(self):
         """
             The element size of the variable (eg, 8 for float64).
 
         """
-        return self._size
+        return self.dtype.itemsize
 
     def __getitem__(self, index):
         return self.data[index]