Source

PyRNN / pyrnn / learn.py

# -*- coding: utf-8 -*-
import numpy
import pylab

class Learn(object):
  """Run a fixed number of training iterations on ``nn`` and record stats.

  After ``learn()``, ``self.rec`` is a float64 array of shape
  ``(iter_times, 1 + nn.record().shape[0])``:
  column 0 is the iteration index (0 .. iter_times-1) and columns 1.. hold
  the values returned by ``nn.record()`` after that iteration.
  """

  def __init__(self, nn, iter_times):
    """Bind the network and pre-allocate the record table.

    :param nn: network object exposing ``record``, ``reset_grad``, ``fptt``,
      ``set_E``, ``bptt`` and ``change_net`` (all used by ``learn``).
      NOTE(review): ``record()`` is assumed to return a 1-D numpy array of
      fixed length -- only ``.shape[0]`` is read here.
    :param iter_times: number of iterations ``learn()`` will run.
    """
    self.nn = nn
    num_rec = nn.record().shape[0] + 1
    # numpy.float was merely an alias of the builtin float (i.e. float64)
    # and was removed in NumPy 1.24; numpy.float64 is the same dtype.
    self.rec = numpy.zeros((iter_times, num_rec), dtype=numpy.float64)
    self.rec[:, 0] = numpy.arange(iter_times)
    # No callback until set_callback() is called; learn() then skips it
    # instead of failing with AttributeError.
    self.cb_func = None
    self.cb_step = 1

  def set_callback(self, cb_func, cb_step):
    """Register ``cb_func`` to be called (with no arguments) during learning.

    The callback runs on every iteration ``i`` with ``i % cb_step == 0``,
    including the first one (``i == 0``).
    """
    self.cb_func = cb_func
    self.cb_step = cb_step

  def learn(self):
    """Run ``rec.shape[0]`` training iterations, filling ``self.rec``.

    Each iteration calls, in order: ``reset_grad``, ``fptt``, ``set_E``,
    ``bptt``, ``change_net``, then stores ``nn.record()`` in the row.
    """
    nn = self.nn
    rec = self.rec
    for i in range(rec.shape[0]):
      nn.reset_grad()
      nn.fptt()
      nn.set_E()
      nn.bptt()
      nn.change_net()
      rec[i, 1:] = nn.record()
      # Read the callback attributes each iteration so a callback may
      # re-register itself via set_callback() mid-run.
      if self.cb_func is not None and i % self.cb_step == 0:
        self.cb_func()

def phase_plot_tixo(ln, xomin=-1, xomax=1, xcmin=-1, xcmax=1):
  """Clear the current pylab figure and draw a 2x2 summary of ``ln``.

  Panels:
    221 -- log-log plot of ``ln.rec[:, 1]`` against the iteration index.
    222 -- log-log plot of each ``ln.rec`` column from index 3 onward.
    223 -- phase plots of ``nn.ns.xo`` (lines) and ``nn.ns.to`` (dotted),
           pairing columns (0, 1), (2, 3), ...; with an odd column count,
           column 0 is additionally paired with the last column. Row 0 is
           skipped. Both axes are limited to [xomin, xomax].
    224 -- the same column pairing for ``nn.ns.xc`` (all rows), with both
           axes limited to [xcmin, xcmax].

  :param ln: a ``Learn`` instance (reads ``ln.nn`` and ``ln.rec``).
  NOTE(review): xo/to/xc are assumed to be 2-D (time, unit) arrays; xo/to
  presumably hold network output and target -- confirm.
  """
  nn = ln.nn
  xo = nn.ns.xo
  to = nn.ns.to
  xc = nn.ns.xc
  iters = ln.rec[:, 0]
  pylab.clf()

  pylab.subplot(221)
  pylab.loglog(iters, ln.rec[:, 1])

  pylab.subplot(222)
  # NOTE(review): rec column 2 is not drawn in either panel -- confirm this
  # is intentional.
  for i in range(3, ln.rec.shape[1]):
    pylab.loglog(iters, ln.rec[:, i])

  pylab.subplot(223)
  for i in range(1, xo.shape[1], 2):
    pylab.plot(xo[1:, i - 1], xo[1:, i])
    pylab.plot(to[1:, i - 1], to[1:, i], '.-')
  if 1 == xo.shape[1] % 2:
    pylab.plot(xo[1:, 0], xo[1:, -1])
    pylab.plot(to[1:, 0], to[1:, -1], '.-')
  pylab.xlim(xomin, xomax)
  pylab.ylim(xomin, xomax)

  pylab.subplot(224)
  for i in range(1, xc.shape[1], 2):
    pylab.plot(xc[:, i - 1], xc[:, i])
  if 1 == xc.shape[1] % 2:
    pylab.plot(xc[:, 0], xc[:, -1])
  pylab.xlim(xcmin, xcmax)
  pylab.ylim(xcmin, xcmax)

def get_cb_phase_plot(ln, xomin=-1, xomax=1, xcmin=-1, xcmax=1):
  """Return a no-argument callback that redraws ``phase_plot_tixo(ln, ...)``.

  The returned function redraws the figure with the given axis limits and
  then calls ``pylab.draw()`` and ``pylab.show()``; it can be passed to
  ``Learn.set_callback``.
  NOTE(review): in non-interactive matplotlib mode ``pylab.show()`` may
  block -- presumably the caller enables interactive mode.
  """
  limits = (xomin, xomax, xcmin, xcmax)

  def cb_phase_plot():
    phase_plot_tixo(ln, *limits)
    pylab.draw()
    pylab.show()

  return cb_phase_plot


def tms_plot_tixo(ln, xomin=-1, xomax=1, xcmin=-1, xcmax=1):
  """Clear the current pylab figure and draw a 2x2 time-series summary.

  Panels:
    221 -- log-log plot of ``ln.rec[:, 1]`` against the iteration index.
    222 -- log-log plot of each ``ln.rec`` column from index 3 onward.
    223 -- every column of ``nn.ns.xo`` (lines) and ``nn.ns.to`` (dotted)
           against ``nn.ns.time_step``, skipping row 0; y limited to
           [xomin, xomax].
    224 -- every column of ``nn.ns.xc`` against ``nn.ns.time_step`` (all
           rows); y limited to [xcmin, xcmax].

  :param ln: a ``Learn`` instance (reads ``ln.nn`` and ``ln.rec``).
  NOTE(review): xo/to/xc are assumed to be 2-D (time, unit) arrays with as
  many rows as ``time_step`` -- confirm.
  """
  nn = ln.nn
  xo = nn.ns.xo
  to = nn.ns.to
  xc = nn.ns.xc
  time_step = nn.ns.time_step
  iters = ln.rec[:, 0]
  pylab.clf()

  pylab.subplot(221)
  pylab.loglog(iters, ln.rec[:, 1])

  pylab.subplot(222)
  # NOTE(review): rec column 2 is not drawn in either panel -- confirm this
  # is intentional.
  for i in range(3, ln.rec.shape[1]):
    pylab.loglog(iters, ln.rec[:, i])

  pylab.subplot(223)
  for i in range(xo.shape[1]):
    pylab.plot(time_step[1:], xo[1:, i])
    pylab.plot(time_step[1:], to[1:, i], '.-')
  pylab.ylim(xomin, xomax)

  pylab.subplot(224)
  for i in range(xc.shape[1]):
    pylab.plot(time_step, xc[:, i])
  pylab.ylim(xcmin, xcmax)

def get_cb_tms_plot(ln, xomin=-1, xomax=1, xcmin=-1, xcmax=1):
  """Return a no-argument callback that redraws ``tms_plot_tixo(ln, ...)``.

  The returned function redraws the figure with the given y limits and
  then calls ``pylab.draw()`` and ``pylab.show()``; it can be passed to
  ``Learn.set_callback``.
  NOTE(review): in non-interactive matplotlib mode ``pylab.show()`` may
  block -- presumably the caller enables interactive mode.
  """
  limits = (xomin, xomax, xcmin, xcmax)

  def cb_tms_plot():
    tms_plot_tixo(ln, *limits)
    pylab.draw()
    pylab.show()

  return cb_tms_plot