Commits

Matthew Turk  committed 7f49529 Merge

Merged in samskillman/yt (pull request #435)

Upgrades to ImageArray, Camera image saving, handling arbitrary background colors.

  • Participants
  • Parent commits 5fb2661, 0d2a983

Comments (0)

Files changed (7)

File yt/data_objects/image_array.py

 
     >>> im = np.zeros([64,128,3])
     >>> for i in xrange(im.shape[0]):
-    >>>     for k in xrange(im.shape[2]):
-    >>>         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
+    ...     for k in xrange(im.shape[2]):
+    ...         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
 
     >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-    >>>     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
-    >>>     'width':0.245, 'units':'cm', 'type':'rendering'}
+    ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+    ...     'width':0.245, 'units':'cm', 'type':'rendering'}
 
     >>> im_arr = ImageArray(im, info=myinfo)
     >>> im_arr.save('test_ImageArray')
         -------- 
         >>> im = np.zeros([64,128,3])
         >>> for i in xrange(im.shape[0]):
-        >>>     for k in xrange(im.shape[2]):
-        >>>         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
+        ...     for k in xrange(im.shape[2]):
+        ...         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
 
         >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        >>>     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
-        >>>     'width':0.245, 'units':'cm', 'type':'rendering'}
+        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+        ...     'width':0.245, 'units':'cm', 'type':'rendering'}
 
         >>> im_arr = ImageArray(im, info=myinfo)
         >>> im_arr.write_hdf5('test_ImageArray.h5')
             d.attrs.create(k, v)
         f.close()
 
-    def write_png(self, filename, clip_ratio=None):
+    def add_background_color(self, background='black', inline=True):
+        r"""Adds a background color to a 4-channel ImageArray
+
+        This adds a background color to a 4-channel ImageArray, by default
+        doing so inline.  The ImageArray must already be normalized to the
+        [0,1] range.
+
+        Parameters
+        ----------
+        background: 
+            This can be used to set a background color for the image, and can
+            take several types of values:
+
+               * ``white``: white background, opaque
+               * ``black``: black background, opaque
+               * ``None``: transparent background
+               * 4-element array [r,g,b,a]: arbitrary rgba setting.
+
+            Default: 'black'
+        inline: boolean, optional
+            If True, original ImageArray is modified. If False, a copy is first
+            created, then modified. Default: True
+
+        Returns
+        -------
+        out: ImageArray
+            The modified ImageArray with a background color added.
+       
+        Examples
+        --------
+        >>> im = np.zeros([64,128,4])
+        >>> for i in xrange(im.shape[0]):
+        ...     for k in xrange(im.shape[2]):
+        ...         im[i,:,k] = np.linspace(0.,10.*k, im.shape[1])
+
+        >>> im_arr = ImageArray(im)
+        >>> im_arr.rescale()
+        >>> new_im = im_arr.add_background_color([1.,0.,0.,1.], inline=False)
+        >>> new_im.write_png('red_bg.png')
+        >>> im_arr.add_background_color('black')
+        >>> im_arr.write_png('black_bg.png')
+        """
+        assert(self.shape[-1] == 4)
+        
+        if background == None:
+            background = (0., 0., 0., 0.)
+        elif background == 'white':
+            background = (1., 1., 1., 1.)
+        elif background == 'black':
+            background = (0., 0., 0., 1.)
+
+        # Alpha blending to background
+        if inline:
+            out = self
+        else:
+            out = self.copy()
+
+        for i in range(3):
+            out[:,:,i] = self[:,:,i]*self[:,:,3] + \
+                    background[i]*background[3]*(1.0-self[:,:,3])
+        out[:,:,3] = self[:,:,3] + background[3]*(1.0-self[:,:,3]) 
+        return out 
+
+
+    def rescale(self, cmax=None, amax=None, inline=True):
+        r"""Rescales the image to be in [0,1] range.
+
+        Parameters
+        ----------
+        cmax: float, optional
+            Normalization value to use for rgb channels. Defaults to None,
+            corresponding to using the maximum value in the rgb channels.
+        amax: float, optional
+            Normalization value to use for alpha channel. Defaults to None,
+            corresponding to using the maximum value in the alpha channel.
+        inline: boolean, optional
+            Specifies whether or not the rescaling is done inline. If false,
+            a new copy of the ImageArray will be created, returned. 
+            Default:True.
+
+        Returns
+        -------
+        out: ImageArray
+            The rescaled ImageArray, clipped to the [0,1] range.
+
+        Notes
+        -----
+        This requires that the shape of the ImageArray to have a length of 3,
+        and for the third dimension to be >= 3.  If the third dimension has
+        a shape of 4, the alpha channel will also be rescaled.
+       
+        Examples
+        -------- 
+        >>> im = np.zeros([64,128,4])
+        >>> for i in xrange(im.shape[0]):
+        ...     for k in xrange(im.shape[2]):
+        ...         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
+
+        >>> im_arr.write_png('original.png')
+        >>> im_arr.rescale()
+        >>> im_arr.write_png('normalized.png')
+
+        """
+        assert(len(self.shape) == 3)
+        assert(self.shape[2] >= 3)
+        if inline:
+            out = self
+        else:
+            out = self.copy()
+        if cmax is None: 
+            cmax = self[:,:,:3].sum(axis=2).max()
+
+        np.multiply(self[:,:,:3], 1./cmax, out[:,:,:3])
+
+        if self.shape[2] == 4:
+            if amax is None:
+                amax = self[:,:,3].max()
+            if amax > 0.0:
+                np.multiply(self[:,:,3], 1./amax, out[:,:,3])
+        
+        np.clip(out, 0.0, 1.0, out)
+        return out
+
+    def write_png(self, filename, clip_ratio=None, background='black',
+                 rescale=True):
         r"""Writes ImageArray to png file.
 
         Parameters
         ----------
         filename: string
             Note filename not be modified.
+        clip_ratio: float, optional
+            Image will be clipped before saving to the standard deviation
+            of the image multiplied by this value.  Useful for enhancing 
+            images. Default: None
+        background: 
+            This can be used to set a background color for the image, and can
+            take several types of values:
+
+               * ``white``: white background, opaque
+               * ``black``: black background, opaque
+               * ``None``: transparent background
+               * 4-element array [r,g,b,a]: arbitrary rgba setting.
+
+            Default: 'black'
+        rescale: boolean, optional
+            If True, will write out a rescaled image (without modifying the
+            original image). Default: True
        
         Examples
         --------
-        
-        >>> im = np.zeros([64,128,3])
+        >>> im = np.zeros([64,128,4])
         >>> for i in xrange(im.shape[0]):
-        >>>     for k in xrange(im.shape[2]):
-        >>>         im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
+        ...     for k in xrange(im.shape[2]):
+        ...         im[i,:,k] = np.linspace(0.,10.*k, im.shape[1])
 
-        >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        >>>     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
-        >>>     'width':0.245, 'units':'cm', 'type':'rendering'}
-
-        >>> im_arr = ImageArray(im, info=myinfo)
-        >>> im_arr.write_png('test_ImageArray.png')
+        >>> im_arr = ImageArray(im)
+        >>> im_arr.write_png('standard.png')
+        >>> im_arr.write_png('non-scaled.png', rescale=False)
+        >>> im_arr.write_png('black_bg.png', background='black')
+        >>> im_arr.write_png('white_bg.png', background='white')
+        >>> im_arr.write_png('green_bg.png', background=[0,1,0,1])
+        >>> im_arr.write_png('transparent_bg.png', background=None)
 
         """
+        if rescale:
+            scaled = self.rescale(inline=False)
+        else:
+            scaled = self
+
+        if self.shape[-1] == 4:
+            out = scaled.add_background_color(background, inline=False)
+        else:
+            out = scaled
+
         if filename[-4:] != '.png': 
             filename += '.png'
 
         if clip_ratio is not None:
-            return write_bitmap(self.swapaxes(0, 1), filename,
-                                clip_ratio * self.std())
+            nz = out[:,:,:3][out[:,:,:3].nonzero()]
+            return write_bitmap(out.swapaxes(0, 1), filename,
+                                nz.mean() + \
+                                clip_ratio * nz.std())
         else:
-            return write_bitmap(self.swapaxes(0, 1), filename)
+            return write_bitmap(out.swapaxes(0, 1), filename)
 
     def write_image(self, filename, color_bounds=None, channel=None,  cmap_name="algae", func=lambda x: x):
         r"""Writes a single channel of the ImageArray to a png file.
         
         >>> im = np.zeros([64,128])
         >>> for i in xrange(im.shape[0]):
-        >>>     im[i,:] = np.linspace(0.,0.3*k, im.shape[1])
+        ...     im[i,:] = np.linspace(0.,0.3*k, im.shape[1])
 
         >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        >>>     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
-        >>>     'width':0.245, 'units':'cm', 'type':'rendering'}
+        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+        ...     'width':0.245, 'units':'cm', 'type':'rendering'}
 
         >>> im_arr = ImageArray(im, info=myinfo)
         >>> im_arr.write_image('test_ImageArray.png')
 
     __doc__ += np.ndarray.__doc__
 
-if __name__ == "__main__":
-    im = np.zeros([64,128,3])
-    for i in xrange(im.shape[0]):
-        for k in xrange(im.shape[2]):
-            im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
-
-    myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
-        'width':0.245, 'units':'cm', 'type':'rendering'}
-
-    im_arr = ImageArray(im, info=myinfo)
-    im_arr.save('test_3d_ImageArray')
-
-    im = np.zeros([64,128])
-    for i in xrange(im.shape[0]):
-        im[i,:] = np.linspace(0.,0.3*k, im.shape[1])
-
-    myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
-        'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
-        'width':0.245, 'units':'cm', 'type':'rendering'}
-
-    im_arr = ImageArray(im, info=myinfo)
-    im_arr.save('test_2d_ImageArray')
-

File yt/data_objects/tests/test_image_array.py

+from yt.testing import *
+from yt.data_objects.image_array import ImageArray
+import numpy as np
+import os
+import tempfile
+import shutil
+
+def setup():
+    from yt.config import ytcfg
+    ytcfg["yt","__withintesting"] = "True"
+    np.seterr(all = 'ignore')
+
+def test_rgba_rescale():
+    im = np.zeros([64,128,4])
+    for i in xrange(im.shape[0]):
+        for k in xrange(im.shape[2]):
+            im[i,:,k] = np.linspace(0.,10.*k, im.shape[1])
+    im_arr = ImageArray(im)
+
+    new_im = im_arr.rescale(inline=False)
+    yield assert_equal, im_arr[:,:,:3].max(), 2*10.
+    yield assert_equal, im_arr[:,:,3].max(), 3*10.
+    yield assert_equal, new_im[:,:,:3].sum(axis=2).max(), 1.0 
+    yield assert_equal, new_im[:,:,3].max(), 1.0
+
+    im_arr.rescale()
+    yield assert_equal, im_arr[:,:,:3].sum(axis=2).max(), 1.0
+    yield assert_equal, im_arr[:,:,3].max(), 1.0
+
+def test_image_array_hdf5():
+    # Perform I/O in safe place instead of yt main dir
+    tmpdir = tempfile.mkdtemp()
+    curdir = os.getcwd()
+    os.chdir(tmpdir)
+
+    im = np.zeros([64,128,3])
+    for i in xrange(im.shape[0]):
+        for k in xrange(im.shape[2]):
+            im[i,:,k] = np.linspace(0.,0.3*k, im.shape[1])
+
+    myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
+        'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+        'width':0.245, 'units':'cm', 'type':'rendering'}
+
+    im_arr = ImageArray(im, info=myinfo)
+    im_arr.save('test_3d_ImageArray')
+
+    im = np.zeros([64,128])
+    for i in xrange(im.shape[0]):
+        im[i,:] = np.linspace(0.,0.3*k, im.shape[1])
+
+    myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]), 
+        'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),  
+        'width':0.245, 'units':'cm', 'type':'rendering'}
+
+    im_arr = ImageArray(im, info=myinfo)
+    im_arr.save('test_2d_ImageArray')
+
+    os.chdir(curdir)
+    # clean up
+    shutil.rmtree(tmpdir)
+
+def test_image_array_rgb_png():
+    # Perform I/O in safe place instead of yt main dir
+    tmpdir = tempfile.mkdtemp()
+    curdir = os.getcwd()
+    os.chdir(tmpdir)
+
+    im = np.zeros([64,128,3])
+    for i in xrange(im.shape[0]):
+        for k in xrange(im.shape[2]):
+            im[i,:,k] = np.linspace(0.,10.*k, im.shape[1])
+
+    im_arr = ImageArray(im)
+    im_arr.write_png('standard.png')
+
+def test_image_array_rgba_png():
+    # Perform I/O in safe place instead of yt main dir
+    tmpdir = tempfile.mkdtemp()
+    curdir = os.getcwd()
+    os.chdir(tmpdir)
+
+    im = np.zeros([64,128,4])
+    for i in xrange(im.shape[0]):
+        for k in xrange(im.shape[2]):
+            im[i,:,k] = np.linspace(0.,10.*k, im.shape[1])
+
+    im_arr = ImageArray(im)
+    im_arr.write_png('standard.png')
+    im_arr.write_png('non-scaled.png', rescale=False)
+    im_arr.write_png('black_bg.png', background='black')
+    im_arr.write_png('white_bg.png', background='white')
+    im_arr.write_png('green_bg.png', background=[0.,1.,0.,1.])
+    im_arr.write_png('transparent_bg.png', background=None)
+
+
+def test_image_array_background():
+    # Perform I/O in safe place instead of yt main dir
+    tmpdir = tempfile.mkdtemp()
+    curdir = os.getcwd()
+    os.chdir(tmpdir)
+
+    im = np.zeros([64,128,4])
+    for i in xrange(im.shape[0]):
+        for k in xrange(im.shape[2]):
+            im[i,:,k] = np.linspace(0.,10.*k, im.shape[1])
+
+    im_arr = ImageArray(im)
+    im_arr.rescale()
+    new_im = im_arr.add_background_color([1.,0.,0.,1.], inline=False)
+    new_im.write_png('red_bg.png')
+    im_arr.add_background_color('black')
+    im_arr.write_png('black_bg2.png')
+ 
+    os.chdir(curdir)
+    # clean up
+    shutil.rmtree(tmpdir)
+
+
+
+
+
+
+
+
+
+
+
+
+

File yt/utilities/amr_kdtree/amr_kdtree.py

 from yt.utilities.lib.grid_traversal import PartitionedGrid
 from yt.utilities.math_utils import periodic_position
 
-import pdb
-
-def my_break():
-    my_debug = False 
-    if my_debug: pdb.set_trace()
-
 steps = np.array([[-1, -1, -1], [-1, -1,  0], [-1, -1,  1],
                   [-1,  0, -1], [-1,  0,  0], [-1,  0,  1],
                   [-1,  1, -1], [-1,  1,  0], [-1,  1,  1],
         self.build(grids)
 
     def add_grids(self, grids):
-        my_break() 
         lvl_range = range(self.min_level, self.max_level+1)
         if grids is None:
             level_iter = self.pf.hierarchy.get_levels()
                 gles =  np.array([g.LeftEdge for g in grids])[gmask]
                 gres =  np.array([g.RightEdge for g in grids])[gmask]
                 gids = np.array([g.id for g in grids])[gmask]
-                my_break()
                 add_grids(self.trunk, gles, gres, gids, self.comm_rank, self.comm_size)
                 del gles, gres, gids, grids
         else:

File yt/utilities/lib/misc_utilities.pyx

     cdef int nx = image.shape[0]
     cdef int ny = image.shape[1]
     cdef int nl = xs.shape[0]
-    cdef np.float64_t alpha[3], nalpha 
+    cdef np.float64_t alpha[4]
     cdef int i, j
     cdef int dx, dy, sx, sy, e2, err
     cdef np.int64_t x0, x1, y0, y1
+    cdef int has_alpha = (image.shape[-1] == 4)
     for j in range(0, nl, 2):
         # From wikipedia http://en.wikipedia.org/wiki/Bresenham's_line_algorithm
         x0 = xs[j]; y0 = ys[j]; x1 = xs[j+1]; y1 = ys[j+1]
         dx = abs(x1-x0)
         dy = abs(y1-y0)
         err = dx - dy
-        for i in range(3):
-            alpha[i] = colors[j/points_per_color,3]*colors[j/points_per_color,i]
-        nalpha = 1.0-colors[j/points_per_color,3]
+        if has_alpha:
+            for i in range(4):
+                alpha[i] = colors[j/points_per_color,i]
+        else:
+            for i in range(3):
+                alpha[i] = colors[j/points_per_color,3]*\
+                        colors[j/points_per_color,i]
         if x0 < x1: 
             sx = 1
         else:
             elif (y0 < 0 and sy == -1): break
             elif (y0 >= nx and sy == 1): break
             if (x0 >=0 and x0 < nx and y0 >= 0 and y0 < ny):
-                for i in range(3):
-                    image[x0,y0,i] = (1.-alpha[i])*image[x0,y0,i] + alpha[i]
+                if has_alpha:
+                    for i in range(4):
+                        image[x0,y0,i] = (1.-alpha[i])*image[x0,y0,i] + alpha[i]
+                else:
+                    for i in range(3):
+                        image[x0,y0,i] = (1.-alpha[i])*image[x0,y0,i] + alpha[i]
+
             if (x0 == x1 and y0 == y1):
                 break
             e2 = 2*err

File yt/visualization/image_writer.py

 from yt.funcs import *
 import _colormap_data as cmd
 import yt.utilities.lib as au
+import __builtin__
 
 def scale_image(image, mi=None, ma=None):
     r"""Scale an image ([NxNxM] where M = 1-4) to be uint8 and values scaled 
     r"""Write out a bitmapped image directly to a PNG file.
 
     This accepts a three- or four-channel `bitmap_array`.  If the image is not
-    already uint8, it will be scaled and converted.  If it is not four channel, a
-    fourth alpha channel will be added and set to fully opaque.  The resultant
-    image will be directly written to `filename` as a PNG with no colormap
-    applied.  `max_val` is a value used if the array is passed in as anything
-    other than uint8; it will be the value used for scaling and clipping when the
-    array is converted.  Additionally, the minimum is assumed to be zero; this
-    makes it primarily suited for the results of volume rendered images, rather
-    than misaligned projections.
+    already uint8, it will be scaled and converted.  If it is four channel,
+    only the first three channels will be scaled, while the fourth channel is
+    assumed to be in the range of [0,1]. If it is not four channel, a fourth
+    alpha channel will be added and set to fully opaque.  The resultant image
+    will be directly written to `filename` as a PNG with no colormap applied.
+    `max_val` is a value used if the array is passed in as anything other than
+    uint8; it will be the value used for scaling and clipping in the first
+    three channels when the array is converted.  Additionally, the minimum is
+    assumed to be zero; this makes it primarily suited for the results of
+    volume rendered images, rather than misaligned projections.
 
     Parameters
     ----------
         The upper limit to clip values to in the output, if converting to uint8.
         If `bitmap_array` is already uint8, this will be ignore.
     """
-    if bitmap_array.dtype != np.uint8:
-        if max_val is None: max_val = bitmap_array.max()
-        bitmap_array = np.clip(bitmap_array / max_val, 0.0, 1.0) * 255
-        bitmap_array = bitmap_array.astype("uint8")
     if len(bitmap_array.shape) != 3 or bitmap_array.shape[-1] not in (3,4):
         raise RuntimeError
-    if bitmap_array.shape[-1] == 3:
+    if bitmap_array.dtype != np.uint8:
         s1, s2 = bitmap_array.shape[:2]
-        alpha_channel = 255*np.ones((s1,s2,1), dtype='uint8')
-        bitmap_array = np.concatenate([bitmap_array, alpha_channel], axis=-1)
+        if bitmap_array.shape[-1] == 3:
+            alpha_channel = 255*np.ones((s1,s2,1), dtype='uint8')
+        else:
+            alpha_channel = (255*bitmap_array[:,:,3]).astype('uint8')
+            alpha_channel.shape = s1, s2, 1
+        if max_val is None: max_val = bitmap_array[:,:,:3].max()
+        bitmap_array = np.clip(bitmap_array[:,:,:3] / max_val, 0.0, 1.0) * 255
+        bitmap_array = np.concatenate([bitmap_array.astype('uint8'),
+                                       alpha_channel], axis=-1)
     if transpose:
         bitmap_array = bitmap_array.swapaxes(0,1)
     if filename is not None:
 
 
 def write_fits(image, filename_prefix, clobber=True, coords=None, gzip_file=False) :
-
     """
     This will export a FITS image of a floating point array. The output filename is
     *filename_prefix*. If clobber is set to True, this will overwrite any existing
         clob = ""
         if (clobber) : clob="-f"
         system("gzip "+clob+" %s.fits" % (filename_prefix))
+
+def display_in_notebook(image, max_val=None):
+    """
+    A helper function to display images in an IPython notebook
     
+    Must be run from within an IPython notebook, or else it will raise
+    a YTNotInsideNotebook exception.
+        
+    Parameters
+    ----------
+    image : array_like
+        This is an (unscaled) array of floating point values, shape (N,N,3) or
+        (N,N,4) to display in the notebook. The first three channels will be
+        scaled automatically.  
+    max_val : float, optional
+        The upper limit to clip values of the image.  Only applies to the first
+        three channels.
+    """
+ 
+    if "__IPYTHON__" in dir(__builtin__):
+        from IPython.core.displaypub import publish_display_data
+        data = write_bitmap(image, None, max_val=max_val)
+        publish_display_data(
+            'yt.visualization.image_writer.display_in_notebook',
+            {'image/png' : data}
+        )
+    else:
+        raise YTNotInsideNotebook
+

File yt/visualization/volume_rendering/blenders.py

             del nz
     np.clip(im, 0.0, 1.0, im)
 
+def enhance_rgba(im, stdval=6.0):
+    nzc = im[:,:,:3][im[:,:,:3]>0.0]
+    cmax = nzc.mean()+stdval*nzc.std()
+
+
+    nza = im[:,:,3][im[:,:,3]>0.0]
+    if len(nza) == 0:
+        im[:,:,3]=1.0
+        amax = 1.0
+    else:
+        amax = nza.mean()+stdval*nza.std()
+
+    im.rescale(amax=amax, cmax=cmax, inline=True)
+    np.clip(im, 0.0, 1.0, im)
+

File yt/visualization/volume_rendering/camera.py

 from yt.utilities.parallel_tools.parallel_analysis_interface import \
     ParallelAnalysisInterface, ProcessorPool, parallel_objects
 from yt.utilities.amr_kdtree.api import AMRKDTree
-from .blenders import  enhance
+from .blenders import  enhance_rgba
 from numpy import pi
 
 def get_corners(le, re):
         px, py, dz = self.project_to_plane(vertices, res=im.shape[:2])
         
         # Must normalize the image
-        ma = im.max()
-        if ma > 0.0: 
-            enhance(im)
+        nim = im.rescale(inline=False)
+        enhance_rgba(nim)
+        nim.add_background_color('black', inline=True)
        
-        lines(im, px, py, colors, 24)
+        lines(nim, px, py, colors, 24)
+        return nim
 
     def draw_line(self, im, x0, x1, color=None):
         r"""Draws a line on an existing volume rendering.
         >>> write_bitmap(im, 'render_with_domain_boundary.png')
 
         """
-
-        ma = im.max()
-        if ma > 0.0: 
-            enhance(im)
-        self.draw_box(im, self.pf.domain_left_edge, self.pf.domain_right_edge,
+        # Must normalize the image
+        nim = im.rescale(inline=False)
+        enhance_rgba(nim)
+        nim.add_background_color('black', inline=True)
+ 
+        self.draw_box(nim, self.pf.domain_left_edge, self.pf.domain_right_edge,
                         color=np.array([1.0,1.0,1.0,alpha]))
+        return nim
 
     def draw_box(self, im, le, re, color=None):
         r"""Draws a box on an existing volume rendering.
     def finalize_image(self, image):
         view_pos = self.front_center + self.orienter.unit_vectors[2] * 1.0e6 * self.width[2]
         image = self.volume.reduce_tree_images(image, view_pos)
+        if self.transfer_function.grey_opacity is False:
+            image[:,:,3]=1.0
         return image
 
     def _render(self, double_check, num_threads, image, sampler):
         self.annotate(ax.axes, enhance)
         self._pylab.savefig(fn, bbox_inches='tight', facecolor='black', dpi=dpi)
         
-    def save_image(self, fn, clip_ratio, image, transparent=False):
-        if self.comm.rank is 0 and fn is not None:
+    def save_image(self, image, fn=None, clip_ratio=None, transparent=False):
+        if self.comm.rank == 0 and fn is not None:
             if transparent:
-                image.write_png(fn, clip_ratio=clip_ratio)
+                image.write_png(fn, clip_ratio=clip_ratio, rescale=True,
+                                background=None)
             else:
-                image[:,:,:3].write_png(fn, clip_ratio=clip_ratio)
+                image.write_png(fn, clip_ratio=clip_ratio, rescale=True,
+                                background='black')
 
     def initialize_source(self):
         return self.volume.initialize_source()
         return info_dict
 
     def snapshot(self, fn = None, clip_ratio = None, double_check = False,
-                 num_threads = 0):
+                 num_threads = 0, transparent=False):
         r"""Ray-cast the camera.
 
         This method instructs the camera to take a snapshot -- i.e., call the ray
             If supplied, will use 'num_threads' number of OpenMP threads during
             the rendering.  Defaults to 0, which uses the environment variable
             OMP_NUM_THREADS.
+        transparent: bool, optional
+            Optionally saves out the 4-channel rgba image, which can appear 
+            empty if the alpha channel is low everywhere. Default: False
 
         Returns
         -------
         image = ImageArray(self._render(double_check, num_threads, 
                                         image, sampler),
                            info=self.get_information())
-        self.save_image(fn, clip_ratio, image)
+        self.save_image(image, fn=fn, clip_ratio=clip_ratio, 
+                       transparent=transparent)
         return image
 
     def show(self, clip_ratio = None):
         image = ImageArray(self._render(double_check, num_threads, 
                                         image, sampler),
                            info=self.get_information())
-        self.save_image(fn, clim, image, label = label)
+        self.save_image(image, fn=fn, clim=clim, label = label)
         return image
 
-    def save_image(self, fn, clim, image, label = None):
-        if self.comm.rank is 0 and fn is not None:
+    def save_image(self, image, fn=None, clim=None, label = None):
+        if self.comm.rank == 0 and fn is not None:
             # This assumes Density; this is a relatively safe assumption.
             import matplotlib.figure
             import matplotlib.backends.backend_agg
             sto.id = self.imj*self.nimx + self.imi
             sto.result = image
         image = self.reduce_images(my_storage)
-        self.save_image(fn, clip_ratio, image)
+        self.save_image(image, fn=fn, clip_ratio=clip_ratio)
         return image
 
     def reduce_images(self,im_dict):
         image = self.finalize_image(sampler.aimage)
         return image
 
-    def save_image(self, fn, clip_ratio, image):
+    def save_image(self, image, fn=None, clip_ratio=None):
         if self.pf.field_info[self.field].take_log:
             im = np.log10(image)
         else:
             im = image
-        if self.comm.rank is 0 and fn is not None:
+        if self.comm.rank == 0 and fn is not None:
             if clip_ratio is not None:
                 write_image(im, fn)
             else:
                                         image, sampler),
                            info=self.get_information())
 
-        self.save_image(fn, clip_ratio, image)
+        self.save_image(image, fn=fn, clip_ratio=clip_ratio)
 
         return image
     snapshot.__doc__ = Camera.snapshot.__doc__