privateer / privateer_pyevolve / privateer_sim.py

# encoding: utf-8
"""

.. module:: privateer_sim
   :synopsis: run a pyevolve GA over a one-parameter trader genome, plotting the population each generation

.. moduleauthor:: Dan MacKinlay <fillmewithspam@email.possumpalace.org>

"""
from pyevolve import G1DList
from pyevolve import GSimpleGA
from pyevolve import Selectors
from pyevolve import Statistics
# from pyevolve import DBAdapters
from pyevolve import Selectors, Crossovers, Consts, Initializators, Mutators, Scaling


import numpy as np
import matplotlib.pyplot as plt
import market
import PrivateerCrossovers

import argparse

   
def main():
    """Parse command-line options and run the simulation.

    Options:
      --trader-class  name of a class in the ``privateers`` module
                      (default ``CARATrader``).
      --seed          integer seed forwarded to the GA engine (default: None).
    """
    import privateers
    parser = argparse.ArgumentParser(description="simulate some traders in a  stochastic market")
    parser.add_argument('--trader-class', help='Trader class name', default='CARATrader')
    # type=int so that e.g. "--seed 42" reaches GSimpleGA as the integer 42,
    # not the string '42' (which would seed the RNG differently).
    parser.add_argument('--seed', help='seed value', type=int, default=None)
    values = parser.parse_args()
    try:
        trader_class = getattr(privateers, values.trader_class)
    except AttributeError:
        # Report a mistyped class name as a usage error instead of a traceback.
        parser.error("unknown trader class: %s" % values.trader_class)
    run_sim(trader_class, values.seed)

def run_sim(trader_class, seed, generations=500):
    """Evolve a one-element real-valued genome with pyevolve and print the best.

    Each candidate is a one-element ``G1DList`` evaluated by
    ``market.value_of_lifetime(genome, Klass=trader_class)``; the GA maximizes
    that value.  ``update_plot`` is registered as the per-generation callback.

    :param trader_class: class forwarded to ``market.value_of_lifetime`` as ``Klass``.
    :param seed: seed passed to ``GSimpleGA`` (``None`` to leave it unset).
    :param generations: number of generations to evolve (default 500).
    :returns: the best individual found; it is also printed to stdout.
    """
    # Genome instance: a 1D list holding a single element.
    genome = G1DList.G1DList(1)

    # Value range and Gaussian-mutation parameters.  Sigma is delicate: unless
    # it is set high, convergence is to a non-zero value.
    # TODO: explore the sensitivity of this.
    genome.setParams(rangemin=-20.0, rangemax=20.0,
                     gauss_mu=0.00, gauss_sigma=10.)
    genome.initializator.set(Initializators.G1DListInitializatorReal)

    def value_of_lifetime(gen):
        """Evaluation function: delegate to market.value_of_lifetime for trader_class."""
        return market.value_of_lifetime(gen, Klass=trader_class)

    genome.evaluator.set(value_of_lifetime)

    # Alternative tried: Mutators.G1DListMutatorRealRange.
    genome.mutator.set(Mutators.G1DListMutatorRealGaussian)

    # The crossover must be able to handle a 1-element list.  Alternatives tried:
    # PrivateerCrossovers.G1DListCrossoverMean / G1DListCrossoverBitwise.
    genome.crossover.set(Crossovers.G1DListCrossoverUniform)

    ga_engine = GSimpleGA.GSimpleGA(genome, seed=seed)

    # Selector (alternatives tried: Selectors.GRouletteWheel,
    # Selectors.GTournamentSelectorAlternative).
    ga_engine.selector.set(Selectors.GTournamentSelector)

    ga_engine.setGenerations(generations)
    # Deliberately no convergence termination criterion: the GA converges on
    # false optima all the time, so always run the full number of generations.
    ga_engine.setMinimax(Consts.minimaxType["maximize"])

    pop = ga_engine.getPopulation()
    pop.scaleMethod.set(Scaling.SigmaTruncScaling)

    # Redraw the population plot after every generation.
    ga_engine.stepCallback.set(update_plot)

    # Other knobs explored earlier: setMutationRate(0.05), setElitism(True) with
    # setElitismReplacement(80), a DBAdapters.DBSQLite adapter, and
    # setInteractiveMode(True).

    # freq_stats=0: no periodic statistics dump during evolution.
    ga_engine.evolve(freq_stats=0)

    best = ga_engine.bestIndividual()
    print(best)
    return best


# Line artist reused across generations; created on the first callback.
plot_line = None
# Interactive mode so the figure can be redrawn while the GA is still running.
plt.ion()


def update_plot(ga_engine):
    """Plot fitness against genome value for the current population.

    Used as the GA's per-generation step callback.  The first call creates the
    line; later calls update its data in place and rescale the axes.

    :param ga_engine: the running ``GSimpleGA`` engine.
    :returns: ``False`` so the step callback never requests that evolution stop.
    """
    global plot_line

    # Sort by genome value so the line runs left to right instead of joining
    # points in population order, which need not follow the genome value.
    points = sorted((individual.genomeList[0], individual.fitness)
                    for individual in ga_engine.getPopulation())
    genomes = [genome for genome, _ in points]
    fitnesses = [fitness for _, fitness in points]

    if plot_line is None:
        plot_line, = plt.plot(genomes, fitnesses)
    else:
        plot_line.set_xdata(genomes)
        plot_line.set_ydata(fitnesses)

    # Updating a line's data does not rescale the axes; without this, points
    # that move outside the first generation's limits are drawn off-screen.
    axes = plot_line.axes
    axes.relim()
    axes.autoscale_view()
    plt.draw()

    return False

if __name__ == "__main__":
    # Run the simulation only when executed as a script, not when imported.
    main()

    
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.