# stockwelltransform/stockwell/plots.py
"""
provide matplotlib-based visualization functions for stockwell transforms
"""
import stockwell 
import matplotlib.pyplot as plt

from matplotlib.offsetbox import TextArea, DrawingArea, OffsetImage, \
     AnnotationBbox

from mpl_toolkits.axes_grid1 import host_subplot

def plotspec(psx, fs=2.0, lofreq=None, hifreq=None, t0=None, t1=None):
    """
    Display a 2-D Stockwell spectrum (e.g. abs(st(x))) as an image.

    It relies upon matplotlib (pyplot's current axes) for display.

    psx    -- 2-D array; rows are drawn along the y (frequency) axis with
              row 0 at the bottom, columns along the x (time) axis
    fs     -- sampling frequency in Hz; the default upper y limit is fs/2
    lofreq -- optional lower y-axis limit in Hz (default 0.0)
    hifreq -- optional upper y-axis limit in Hz (default fs/2)
    t0, t1 -- optional x-axis limits; they are only used when BOTH are
              given, otherwise the x axis spans 0..psx.shape[1]
              (column index)

    Returns the AxesImage created by pyplot.imshow.

    example:
    # for a signal x, with sampling frequency 200
    >>> import numpy as np
    >>> import stockwell
    >>> import stockwell.plots as plots
    >>> fs = 200  # sample rate in Hz
    >>> x = np.zeros(1000)  # 5 seconds at fs=200
    >>> x[250:350] = 1.0  # step function
    >>> sx = stockwell.st(x)
    >>> psx = abs(sx)  # magnitude of the transform
    >>> r = plots.plotspec(psx, fs)
    >>> # import matplotlib.pyplot as plt; plt.show()  # to visualize this
    """
    # Default extents: column index along x, 0..fs/2 along y.
    extent = [0, psx.shape[1], 0.0, fs / 2.0]
    # A single time limit is ambiguous (the default x units are columns),
    # so the time axis is only relabelled when both limits are supplied.
    if t0 is not None and t1 is not None:
        extent[0] = t0
        extent[1] = t1
    if lofreq is not None:
        extent[2] = lofreq
    if hifreq is not None:
        extent[3] = hifreq
    plt.ylabel('Hz')
    return plt.imshow(psx, extent=extent, aspect='auto', origin='lower')


def stspecgram(x, fs, lofreq=None, hifreq=None, t0=None, t1=None):
    """
    Plot the Stockwell spectrum abs(st(x)) of a signal x sampled at fs Hz.

    x      -- signal; its length n is taken from x.shape[0]
    fs     -- sampling frequency in Hz
    lofreq, hifreq -- optional frequency band limits, used directly as the
              Hz y-axis limits of the plot and converted to transform rows
              with stockwell.stfreq(freq, n, fs).  If neither is given,
              stockwell.st(x) is called without row limits.  If only one
              is given, the other defaults to 0.0 (lofreq) or fs/2
              (hifreq), matching plotspec's default y extent.
    t0, t1 -- x-axis time limits in seconds; t0 defaults to 0.0 and t1
              defaults to t0 + n/fs.

    Returns the AxesImage produced by plotspec.
    """
    n = x.shape[0]
    if t0 is None:
        t0 = 0.0
    if t1 is None:
        t1 = n / float(fs) + t0

    if lofreq is None and hifreq is None:
        sx = stockwell.st(x)
        return plotspec(abs(sx), fs, t0=t0, t1=t1)

    # Only one bound may have been supplied: fill in the missing one rather
    # than handing None to stockwell.stfreq.
    if lofreq is None:
        lofreq = 0.0
    if hifreq is None:
        hifreq = fs / 2.0

    lorow = stockwell.stfreq(lofreq, n, fs)
    hirow = stockwell.stfreq(hifreq, n, fs)
    sx = stockwell.st(x, lorow, hirow)
    return plotspec(abs(sx), fs, lofreq=lofreq, hifreq=hifreq, t0=t0, t1=t1)


def stpowerstack(x, stx):
    """
    Stack a signal above the magnitude of its Stockwell transform.

    Creates two host subplots in the current figure: the upper one shows
    the raw signal x, the lower one shows abs(stx) as an image.

    x   -- signal to plot
    stx -- 2-D Stockwell transform of x (e.g. stockwell.st(x))

    Returns the lower (spectrum) axes.

    TODO: the y ticks are transform row indices (st-row, f*L/fs), not Hz;
    proper frequency row labels still need to be added.
    """
    ax1 = host_subplot(211)
    ax1.plot(x)
    ax2 = host_subplot(212)
    # origin='lower' draws row 0 at the bottom, consistent with plotspec.
    ax2.imshow(abs(stx), aspect='auto', origin='lower')
    ax2.set_ylabel('frequency(Hz)')
    return ax2
    
if __name__ == "__main__":
    # Executing this module directly runs the examples in its docstrings.
    from doctest import testmod
    testmod()