Resource Description
The archive contains two .py files: complete code for the cross-entropy method and for the policy gradient method in deep reinforcement learning. The readme file gives the concrete usage steps and an explanation of the algorithms.
Code Snippet and File Information
#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import numpy as np
import gym
from gym.spaces import Discrete, Box
# ================================================================
# Policies
# ================================================================
class DeterministicDiscreteActionLinearPolicy(object):
    def __init__(self, theta, ob_space, ac_space):
        """
        dim_ob: dimension of observations
        n_actions: number of actions
        theta: flat vector of parameters
        """
        dim_ob = ob_space.shape[0]
        n_actions = ac_space.n
        assert len(theta) == (dim_ob + 1) * n_actions
        self.W = theta[0 : dim_ob * n_actions].reshape(dim_ob, n_actions)
        self.b = theta[dim_ob * n_actions : None].reshape(1, n_actions)

    def act(self, ob):
        """
        Pick the discrete action with the highest linear score ob.W + b.
        """
        y = ob.dot(self.W) + self.b
        a = y.argmax()
        return a
class DeterministicContinuousActionLinearPolicy(object):
    def __init__(self, theta, ob_space, ac_space):
        """
        dim_ob: dimension of observations
        dim_ac: dimension of action vector
        theta: flat vector of parameters
        """
        self.ac_space = ac_space
        dim_ob = ob_space.shape[0]
        dim_ac = ac_space.shape[0]
        assert len(theta) == (dim_ob + 1) * dim_ac
        self.W = theta[0 : dim_ob * dim_ac].reshape(dim_ob, dim_ac)
        self.b = theta[dim_ob * dim_ac : None]

    def act(self, ob):
        # Linear action, clipped to the bounds of the Box action space
        a = np.clip(ob.dot(self.W) + self.b, self.ac_space.low, self.ac_space.high)
        return a
def do_episode(policy, env, num_steps, render=False):
    total_rew = 0
    ob = env.reset()
    for t in range(num_steps):
        a = policy.act(ob)
        (ob, reward, done, _info) = env.step(a)
        total_rew += reward
        if render and t % 3 == 0: env.render()
        if done: break
    return total_rew
env = None

def noisy_evaluation(theta):
    policy = make_policy(theta)
    rew = do_episode(policy, env, num_steps)
    return rew

def make_policy(theta):
    if isinstance(env.action_space, Discrete):
        return DeterministicDiscreteActionLinearPolicy(theta,
            env.observation_space, env.action_space)
    elif isinstance(env.action_space, Box):
        return DeterministicContinuousActionLinearPolicy(theta,
            env.observation_space, env.action_space)
    else:
        raise NotImplementedError
# Task settings:
# env = gym.make('CartPole-v0')
env = gym.make('MountainCar-v0')
# Change as needed
num_steps = 500  # maximum length of episode
# Alg settings:
n_iter = 100       # number of iterations of CEM
batch_size = 25    # number of samples per batch
elite_frac = 0.2   # fraction of samples used as elite set
if isinstance(env.action_space, Discrete):
    dim_theta = (env.observation_space.shape[0] + 1) * env.action_space.n
elif isinstance(env.action_space, Box):
    dim_theta = (env.observation_space.shape[0] + 1) * env.action_space.shape[0]
else:
    raise NotImplementedError
# Initialize me
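The snippet above breaks off at the parameter-initialization comment, before the cross-entropy optimization loop in CEM.py. For reference, a minimal sketch of a standard cross-entropy-method loop built on the functions defined above (noisy_evaluation, make_policy, do_episode) is given below; the initialization values and the printout are assumptions, not necessarily the author's exact code.

# Hypothetical continuation: a standard CEM loop, not necessarily the author's exact code.
theta_mean = np.zeros(dim_theta)   # assumed initial mean of the parameter distribution
theta_std = np.ones(dim_theta)     # assumed initial standard deviation
n_elite = int(batch_size * elite_frac)
for it in range(n_iter):
    # Sample a batch of parameter vectors from the current diagonal Gaussian
    thetas = theta_mean + theta_std * np.random.randn(batch_size, dim_theta)
    rewards = np.array([noisy_evaluation(theta) for theta in thetas])
    # Keep the elite fraction with the highest episode return
    elite_inds = rewards.argsort()[-n_elite:]
    elite_thetas = thetas[elite_inds]
    # Refit the mean and standard deviation to the elite set
    theta_mean = elite_thetas.mean(axis=0)
    theta_std = elite_thetas.std(axis=0)
    print("Iteration %2i. mean reward: %8.3g. max reward: %8.3g"
          % (it, rewards.mean(), rewards.max()))
    # Watch one episode with the current mean parameters
    do_episode(make_policy(theta_mean), env, num_steps, render=True)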
Attribute    Size       Date        Time   Name
-----------  ---------  ----------  -----  ----
Directory    0          2018-03-08  09:44  CEM_PolicyGradient_gym\
Directory    0          2018-03-08  09:42  CEM_PolicyGradient_gym\.idea\
File         455        2018-03-08  09:42  CEM_PolicyGradient_gym\.idea\CEM_Implementation.iml
File         212        2018-03-08  09:42  CEM_PolicyGradient_gym\.idea\misc.xm
File         288        2018-03-08  09:42  CEM_PolicyGradient_gym\.idea\modules.xm
File         13488      2018-03-08  09:42  CEM_PolicyGradient_gym\.idea\workspace.xm
File         4393       2018-03-08  09:42  CEM_PolicyGradient_gym\CEM.py
File         6727       2018-03-08  09:42  CEM_PolicyGradient_gym\Policy_Gradient.py
File         88         2018-03-08  09:45  CEM_PolicyGradient_gym\readme.txt