
A tiny policy-gradient solution for CartPole-v0 (OpenAI Gym)

Created by Eugene
import gym
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import value_and_grad

ENV = gym.make('CartPole-v0')
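# Note: this script targets the classic Gym API, where reset() returns an
# observation and step() returns a 4-tuple (observation, reward, done, info).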

def play_visualise(current_policy):
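    # Run one rendered episode, acting greedily (most likely action) under the current policy.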
    total_reward = 0
    observation = ENV.reset()
    while True:
        ENV.render()
        p_0 = sigmoid(np.dot(current_policy, observation))
        action = 0 if p_0 > 0.5 else 1
        observation, reward, done, info = ENV.step(action)
        total_reward += reward
        if done:
            print("Episode finished with reward {}".format(total_reward))
            break

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def do_episode(current_policy):
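    # Roll out one stochastic episode and return the REINFORCE surrogate sum_t G_t * log pi_t.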
    rewards = []
    pi = []
    observation = ENV.reset()
    for _ in range(250):
        p_0 = sigmoid(np.dot(current_policy, observation))
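        # p_0 is the policy's probability of action 0; sample an action and record the probability of the one taken.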
        if npr.uniform() < p_0:
            action = 0
            pi.append(p_0)
        else:
            action = 1
            pi.append(1 - p_0)
        observation, reward, done, info = ENV.step(action)
        rewards.append(reward)
        if done:
            break
    rewards = np.array(rewards)
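    # Reverse cumulative sum converts per-step rewards into returns-to-go G_t.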
    rewards = np.cumsum(rewards[::-1])[::-1]
    pi = np.array(pi)
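    # Surrogate objective; its gradient w.r.t. the policy is sum_t G_t * grad(log pi_t).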
    return np.dot(rewards, np.log(pi))

def learn(init_step_size, n_learning_iterations, visualisation_freq):
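    # Linear policy: one weight per observation dimension, initialised to small random values.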
    policy = npr.random(4) / 100
    vg = value_and_grad(do_episode)
    for i in range(n_learning_iterations):
        step_size = init_step_size * 1.0 / np.sqrt(i + 1)
        value, grad = vg(policy)
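        # Gradient ascent on the surrogate, with step size decaying as 1/sqrt(i + 1).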
        policy += step_size * grad

        if (visualisation_freq > 0) and (i % visualisation_freq == 0):
            print('policy', policy)
            play_visualise(policy)


if __name__ == '__main__':
    init_step_size = 1e-1
    n_learning_iterations = 10000
    visualisation_freq = 200
    learn(init_step_size, n_learning_iterations, visualisation_freq)
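
What value_and_grad differentiates above is the standard score-function (REINFORCE) surrogate: the returns-to-go G_t enter as constants, so the resulting gradient is sum_t G_t * grad(log pi_t) with respect to the policy weights. Below is a minimal single-step sanity check of that identity; the function name and all numbers (weights, observation, return-to-go) are hypothetical, chosen only to show that autograd's gradient matches the analytic form G * (1 - sigmoid(policy . obs)) * obs for action 0.

import autograd.numpy as np
from autograd import grad

def one_step_surrogate(policy, obs, G):
    # G * log pi(action 0 | obs) for the same sigmoid policy as above.
    p_0 = 1.0 / (1.0 + np.exp(-np.dot(policy, obs)))
    return G * np.log(p_0)

policy = np.array([0.01, -0.02, 0.03, 0.005])    # hypothetical weights
obs = np.array([0.1, -0.4, 0.02, 0.3])           # hypothetical observation
G = 7.0                                          # hypothetical return-to-go

auto = grad(one_step_surrogate)(policy, obs, G)  # gradient w.r.t. the first argument
p_0 = 1.0 / (1.0 + np.exp(-np.dot(policy, obs)))
analytic = G * (1.0 - p_0) * obs                 # analytic score-function gradient
print(np.allclose(auto, analytic))               # expect True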
