import gym
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import value_and_grad

# Single shared environment; the functions below mutate its state via reset()/step().
ENV = gym.make('CartPole-v0')


def sigmoid(x):
    """Logistic sigmoid: map a real-valued score to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def play_visualise(current_policy):
    """Run one rendered episode acting greedily under `current_policy`.

    `current_policy` is a weight vector dotted with the observation;
    p_0 = sigmoid(policy . obs) is the probability of action 0, and the
    greedy (argmax) action is taken — no sampling during visualisation.
    """
    total_reward = 0
    observation = ENV.reset()
    while True:
        ENV.render()
        p_0 = sigmoid(np.dot(current_policy, observation))
        action = 0 if p_0 > 0.5 else 1
        # NOTE(review): assumes the classic Gym step API returning a 4-tuple
        # (obs, reward, done, info); newer gymnasium returns 5 values.
        observation, reward, done, info = ENV.step(action)
        total_reward += reward
        if done:
            print("Episode finished with reward {}".format(total_reward))
            break


def do_episode(current_policy):
    """Sample one episode and return the REINFORCE surrogate objective.

    Returns sum_t G_t * log pi(a_t | s_t), where G_t is the reward-to-go
    from step t and pi(a_t | s_t) is the probability the stochastic policy
    assigned to the action actually taken. Differentiating this value with
    respect to `current_policy` (via autograd) yields a policy-gradient
    estimate, which is why the whole episode is written in autograd.numpy.
    """
    rewards = []
    pi = []  # probability assigned to the action actually taken, per step
    observation = ENV.reset()
    for _ in range(250):  # hard cap on episode length (CartPole-v0 ends at 200)
        p_0 = sigmoid(np.dot(current_policy, observation))
        if npr.uniform() < p_0:
            action = 0
            pi.append(p_0)
        else:
            action = 1
            pi.append(1 - p_0)
        observation, reward, done, info = ENV.step(action)
        rewards.append(reward)
        if done:
            break
    # Reward-to-go: G_t = sum of rewards from step t to the end of the episode.
    rewards = np.cumsum(np.array(rewards)[::-1])[::-1]
    pi = np.array(pi)
    return np.dot(rewards, np.log(pi))


def learn(init_step_size, n_learning_iterations, visualisation_freq):
    """Train a linear policy by gradient ASCENT on the REINFORCE objective.

    init_step_size: base learning rate, decayed as 1/sqrt(iteration + 1).
    n_learning_iterations: number of training episodes.
    visualisation_freq: render a greedy episode every this many iterations;
        a value <= 0 disables visualisation entirely.
    """
    policy = npr.random(4) / 100  # small random init; CartPole observations are 4-dim
    vg = value_and_grad(do_episode)
    for i in range(n_learning_iterations):
        step_size = init_step_size / np.sqrt(i + 1)
        value, grad = vg(policy)
        policy += step_size * grad  # ascent: we maximise expected return
        if visualisation_freq > 0 and i % visualisation_freq == 0:
            print('policy', policy)
            play_visualise(policy)


if __name__ == '__main__':
    init_step_size = 1e-1
    n_learning_iterations = 10000
    visualisation_freq = 200
    learn(init_step_size, n_learning_iterations, visualisation_freq)
Comments (0)
HTTPS / SSH
You can clone a snippet to your computer for local editing.
Learn more.