Commits

Sam Skillman committed ddc0c5c

A bit of updating

Comments (0)

Files changed (2)

multiplot_tools/shared_axes.py

 import numpy as na
 import h5py as h5
 from matplotlib import pylab as pl
-from mpl_toolkits.axes_grid import ImageGrid 
+from mpl_toolkits.axes_grid import ImageGrid
 from matplotlib.ticker import FuncFormatter
 
 def make_axes_formatter(vmin, vmax):
 
 
 class Multipanel(object):
-    def __init__(self,nx=1,ny=1,n_colorbar=1,padding=0.0):
+    def __init__(self,nx=1,ny=1,n_colorbar=1,pad=0.0,frame_size=3.0, xpad=None,
+            ypad = None):
         self.nx = nx
         self.ny = ny
         self.n_colorbar = n_colorbar
-        self.padding = 0.0
-            
+        self.pad = pad
+        self.xpad = pad
+        self.ypad = pad
+        if xpad is not None:
+            self.xpad = xpad
+        if ypad is not None:
+            self.ypad = ypad
+        self.frame_size = frame_size
         self.axes = {}
         self.fig = None
-        
+
         self._build()
         #self.plot_extents()
-        self.subdivide()
-        self.turn_off_internal_labels()
+        if self.nx > 1 or self.ny > 1:
+            self.subdivide()
+            self.turn_off_internal_labels()
 
         if n_colorbar == 2:
             self.turn_off_left_labels()
-    
+
     def _build(self):
-        self.fig = pl.figure(figsize=[4.*self.nx,4.*self.ny])
+        if isinstance(self.frame_size,tuple):
+            self.fig = pl.figure(figsize=(self.frame_size[0]*self.nx,
+                self.frame_size[1]*self.ny))
+        else:
+            self.fig = pl.figure(figsize=(self.frame_size*self.nx,
+                self.frame_size*self.ny))
         cf = 0.02
         if self.n_colorbar == 1:
             fig_axes = self.fig.add_axes([0.0, 0.0, 1.-cf, 1.-cf])
             cbar_axes = self.fig.add_axes([1.-cf, 0.0, cf, 1.-cf])
             cbar2_axes = None
         elif self.n_colorbar == 2:
-            cbar_axes = self.fig.add_axes([1.-cf, 0.0, cf, 1.-1.9*cf])
-            cbar2_axes = self.fig.add_axes([0.0, 0.0, cf, 1.-1.9*cf])
+            cbar_axes = self.fig.add_axes([1.-cf, 0.1*cf, cf, 1.-2.0*cf])
+            cbar2_axes = self.fig.add_axes([0.0, 0.1*cf, cf, 1.-2.0*cf])
             fig_axes = self.fig.add_axes([cf, 0.0, 1.-2*cf, 1.-1.9*cf])
             print fig_axes._position
+        else:
+            cbar_axes = None
+            cbar2_axes = None
+            fig_axes = self.fig.add_axes([0.0,0.0,1.0,1.0])
         self.axes['cbar'] = cbar_axes
         self.axes['left_cbar'] = cbar2_axes
         self.axes['figure'] = fig_axes
         op = ax._position
         w = op.width
         h = op.height
+        print w, h
         xi = op.x0
         yi = op.y0
-        dw = w/self.nx
+        dw = w/self.nx 
         dh = h/self.ny
         #dw = min(dw,dh)
         #dh = min(dw,dh)
         for j in range(self.ny):
             for i in range(self.nx):
-                self.axes[(i,j)] = self.fig.add_axes([xi+i*dw, yi+j*dh, dw, dh])
+                self.axes[(i,j)] = self.fig.add_axes([xi+i*(dw+self.xpad),
+                    yi+j*(dh+self.ypad), dw, dh])
         self.fig.delaxes(ax)
         self.axes['figure']=self.axes[(0,0)]
-    
-    def save_tmp(self,tight=True):
+
+    def make_cbar_axes(self, axkey, cf=0.02):
+        if axkey is None:
+            axkey = 'figure'
+        ax = self.axes[axkey]
+        op = ax._position
+        cw = op.width*cf
+        ch = op.height*(1.-cf)
+        rat = 1.0*self.nx/self.ny
+        w = op.width*(1.-cf)
+        h = op.height*(1.-cf)
+        xi = op.x0
+        yi = op.y0
+        offh = op.height*cf*0.5
+
+        bn = self.get_keyname(axkey)
+        axname = bn + '_cbar'
+
+        #self.fig.delaxes(ax)
+        ax.set_position([xi,yi+offh, w, h])
+        self.axes[axname] = self.fig.add_axes([xi+w,yi+offh,cw,ch])
+        self.axes['figure']=self.axes[(0,0)]
+
+    def save_tmp(self,fn='tmp.png', tight=True):
         if tight:
-            self.fig.savefig('tmp.png',bbox_inches='tight')
+            self.fig.savefig(fn,bbox_inches='tight')
         else:
-            self.fig.savefig('tmp.png',bbox_inches=None)
+            self.fig.savefig(fn,bbox_inches=None)
 
     def turn_off_internal_labels(self):
         for k,ax in self.axes.iteritems():
             if isinstance(k,tuple):
                 if k[0] != 0:
                     ax.set_yticks([])
-                if k[1] !=0:
+                if k[1] != 0:
                     ax.set_xticks([])
-    
+
     def turn_off_left_labels(self):
         for k,ax in self.axes.iteritems():
             if isinstance(k,tuple):
                 if k[0] == 0:
                     ax.set_yticks([])
 
-    def remove_last_tick(self):
+    def remove_last_tick(self, do_x=True, do_y=True):
          for k,ax in self.axes.iteritems():
             if isinstance(k,tuple):
-                if k[0] != self.nx-1:
+                if k[0] != self.nx-1 and do_x:
                     tks = ax.get_xticks()
                     if len(tks)>1:
                         ax.set_xticks(tks[:-1])
 
-                if k[1] != self.ny-1:
+                if k[1] != self.ny-1 and do_y:
                     tks = ax.get_yticks()
                     if len(tks)>1:
                         ax.set_yticks(tks[:-1])
 
+    def get_keyname(self,key):
+        if isinstance(key,tuple):
+            name = '%i_%i' % (key[0],key[1])
+        elif isinstance(key,string):
+            name = key
+        else:
+            name = 'None'
+        return name
+
     def colorbar(self, axkey, caxkey):
-        self.fig.colorbar(self.axes[axkey].images[0], cax=self.axes[caxkey])
-        if 'left' in caxkey:
-            self.axes[caxkey].yaxis.set_ticks_position('left')
+        if caxkey is None:
+            self.make_cbar_axes(axkey)
+            caxkey = self.get_keyname(axkey)+'_cbar'
 
+        cb = self.fig.colorbar(self.axes[axkey].images[0],cax=self.axes[caxkey])
+        if caxkey is not None:
+            if 'left' in caxkey:
+                self.axes[caxkey].yaxis.set_ticks_position('left')
+                self.axes[caxkey].yaxis.set_label_position('left')
+        return cb
 
-    def finalize(self):
-        self.remove_last_tick()
+    def finalize(self,*args,**kwargs):
+        self.remove_last_tick(args, kwargs)
+

multiplot_tools/test_share.py

 from shared_axes import Multipanel
 import numpy as np
-mp = Multipanel(nx=2,ny=1,n_colorbar=2)
+mp = Multipanel(nx=2,ny=2,n_colorbar=2,frame_size=5.)
 
 mp.axes[(0,0)].imshow(np.random.random((16,16)),aspect='auto',extent=[-3.,3.,-4.,4.])
 mp.axes[(1,0)].imshow(np.random.random((32,32)),aspect='auto',extent=[-3.,3.,-4.,4.])
-#mp.axes[(1,0)].imshow(np.random.random((64,64)),aspect='auto',extent=[-3.,3.,-4.,4.])
-#mp.axes[(1,1)].imshow(np.random.random((128,128)),aspect='auto',extent=[-3.,3.,-4.,4.])
+mp.axes[(0,1)].imshow(np.random.random((64,64)),aspect='auto',extent=[-3.,3.,-4.,4.])
+mp.axes[(1,1)].imshow(np.random.random((128,128)),aspect='auto',extent=[-3.,3.,-4.,4.])
 
 mp.colorbar((0,0), 'cbar')
-mp.colorbar((0,0), 'left_cbar')
+mp.colorbar((0,1), 'left_cbar')
 mp.finalize()
 
 mp.axes[(0,0)].set_xlabel('x [Mpc]')