Source

privateer / privateer_pyevolve / privateer_sim.py

Full commit
# encoding: utf-8
"""

.. module:: privateer_sim
   :synopsis: evolve a one-gene genome scored by market.value_of_lifetime with pyevolve, plotting each generation

.. moduleauthor:: Dan MacKinlay <fillmewithspam@email.possumpalace.org>

"""
from pyevolve import G1DList
from pyevolve import GSimpleGA
from pyevolve import Selectors
from pyevolve import Statistics
# from pyevolve import DBAdapters
from pyevolve import Selectors, Crossovers, Consts, Initializators, Mutators, Scaling


import numpy as np
import matplotlib.pyplot as plt
import market
import PrivateerCrossovers
import PrivateerMutators

import argparse

   
def main(argv=None):
    """Parse command-line options and run one simulation.

    Options:
      --trader-class  name of an attribute of the ``privateers`` module
                      (default ``CARATrader``); it is handed to ``run_sim``
                      as the trader class to evaluate.
      --seed          seed value forwarded unchanged to ``run_sim``
                      (default ``None``).

    :param argv: argument list to parse; ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, matching the previous behaviour.
    """
    import privateers
    parser = argparse.ArgumentParser(description="simulate some traders in a  stochastic market")
    parser.add_argument('--trader-class', help='Trader class name', default='CARATrader')
    parser.add_argument('--seed', help='seed value', default=None)
    values = parser.parse_args(argv)
    try:
        trader_class = getattr(privateers, values.trader_class)
    except AttributeError:
        # Report a misspelt class name as a usage error (exit status 2)
        # instead of an uncaught traceback.
        parser.error("unknown trader class: %r" % values.trader_class)
    run_sim(trader_class, values.seed)
    
def run_sim(trader_class, seed, generations=500, mutation_rate=0.01):
    """Run the pyevolve GA for ``trader_class`` and print the best individual.

    A ``GSimpleGA`` maximises ``market.value_of_lifetime`` over a one-element
    real-valued genome constrained to [-20.0, 20.0]. ``update_plot`` is
    registered as the per-generation step callback.

    :param trader_class: passed to ``market.value_of_lifetime`` as ``Klass``.
    :param seed: forwarded unchanged as ``GSimpleGA(..., seed=seed)``.
    :param generations: number of generations to evolve (default 500).
    :param mutation_rate: value passed to ``setMutationRate`` (default 0.01).
    :returns: the engine's best individual, which is also printed.
    """
    # Genome: a 1D list holding a single real-valued gene.
    genome = G1DList.G1DList(1)
    genome.setParams(rangemin=-20.0, rangemax=20.0)
    genome.initializator.set(Initializators.G1DListInitializatorReal)

    def value_of_lifetime(gen):
        # Objective is delegated to the market model for the chosen class.
        return market.value_of_lifetime(gen, Klass=trader_class)

    genome.evaluator.set(value_of_lifetime)

    # Previously commented-out mutator alternatives:
    # Mutators.G1DListMutatorRealGaussian, Mutators.G1DListMutatorRealRange.
    genome.mutator.set(PrivateerMutators.G1DListMutatorRealPseudoBitwise)

    # The crossover must cope with a 1-element list. Previously commented-out
    # alternatives: Crossovers.G1DListCrossoverUniform,
    # PrivateerCrossovers.G1DListCrossoverMean.
    genome.crossover.set(PrivateerCrossovers.G1DListCrossoverBitwise)

    ga_engine = GSimpleGA.GSimpleGA(genome, seed=seed)
    ga_engine.selector.set(Selectors.GTournamentSelector)
    ga_engine.setGenerations(generations)

    # Deliberately no convergence termination criterion
    # (GSimpleGA.ConvergenceCriteria): it converges on false minima all the time.
    ga_engine.setMinimax(Consts.minimaxType["maximize"])

    pop = ga_engine.getPopulation()
    pop.scaleMethod.set(Scaling.SigmaTruncScaling)

    ga_engine.stepCallback.set(update_plot)
    ga_engine.setMutationRate(mutation_rate)

    # freq_stats=0: no periodic statistics dump during evolution.
    ga_engine.evolve(freq_stats=0)

    best = ga_engine.bestIndividual()
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(best)
    return best


# Line artist shared with update_plot; None until update_plot first draws it.
plot_line = None
# Turn on matplotlib interactive mode at import time, presumably so the
# plt.draw() calls in update_plot refresh the figure while evolution runs.
plt.ion()

def update_plot(ga_engine):
    """Plot fitness against the first gene for the current population.

    Used as the GA's step callback. The first call creates the line and
    stores it in the module-level ``plot_line``. Later calls replace that
    line's data, then rescale the axes so the whole population stays visible.

    :param ga_engine: the running pyevolve engine.
    :returns: ``False``; pyevolve stops evolving when a step callback returns
        a true value, so this keeps the run going.
    """
    global plot_line

    # Sort by gene value so the line traces fitness as a function of the
    # gene, rather than zigzagging in population order.
    samples = sorted(
        ((individual.genomeList[0], individual.fitness)
         for individual in ga_engine.getPopulation()),
        key=lambda sample: sample[0],
    )
    genomes = [gene for gene, _ in samples]
    fitnesses = [fitness for _, fitness in samples]

    if plot_line is None:
        plot_line, = plt.plot(genomes, fitnesses)
    else:
        plot_line.set_data(genomes, fitnesses)

    # set_data does not touch the view limits; without this the points
    # leave the frame as the population drifts.
    axes = plot_line.axes
    axes.relim()
    axes.autoscale_view()
    plt.draw()

    return False

if __name__ == "__main__":
    # Only run the simulation when executed as a script, not on import.
    main()