博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
DQN初探之学习Breakout-v0
阅读量:2079 次
发布时间:2019-04-29

本文共 6481 字,大约阅读时间需要 21 分钟。

DQN初探之学习"Breakout-v0"

本文记录了我初次使用DQN训练agent完成Atari游戏之"Breakout-v0"的过程。整个过程仿照DeepMind在nature发表的论文"Human-level control through deep reinforcement learning"。

1.gym环境之"Breakout-v0"

1.1.环境的状态空间和动作空间

首先确定"Breakout-v0"的动作空间和状态空间,其状态是游戏截图。

env = gym.make('Breakout-v0')print(env.observation_space)  # Box(0, 255, (210, 160, 3), uint8)print(env.action_space)  # Discrete(4)env.reset()observation = env.reset()print(observation.shape)  # (210, 160, 3) h, w, c

由上可以知道,"Breakout-v0"的状态空间和动作空间的基本情况:

  • “observation space”:Box(0, 255, (210, 160, 3), uint8),每一个状态都是一张游戏截图,分辨率为 210 × 160 × 3 ( h , w , c ) 210 \times 160 \times 3(h,w,c) 210×160×3(h,w,c),每一个像素取值在0~255;
  • “action space”:Discrete(4),有四个动作 ( 0 , 1 , 2 , 3 ) (0,1,2,3) (0,1,2,3)

上述两个空间都是gym中的抽象类space的实例化。space有两个常用的方法:

  • contains(x):检查"x"是否属于该空间;
  • sample():随机从空间中返回一个采样值。
env = gym.make('Breakout-v0')action_space = env.action_spaceprint(action_space.contains(2))  # Trueprint(action_space.sample())  # 3

gym的第一步通常是使用make得到一个游戏环境,如"Breakout-v0"。这个环境是抽象类env的实例化,通常具有以下属性和方法:

  • observation_space:环境的状态空间;
  • action_space:环境的动作空间;
  • reset():重置环境为初始状态,并且返回该初始状态;
  • step():agent与环境交互的接口。它会接受一个动作,然后返回更新后的环境状态、奖励、结束标识符以及其它相关游戏信息。
env = gym.make('Breakout-v0')env.reset()observation, reward, done, info = env.step(env.action_space.sample())print(observation)  # (210, 160, 3) print(reward)  # scalarprint(done)  # boolprint(info)  # {'ale.lives': 5}

1.2.Wrapper

gym除了上述功能外,还有一个强大的自定义工具WrapperWrapperenv的子类,其构造函数只接受一个参数:env的实例化对象。通过这个抽象类,我们可以对env的功能进行一些扩展。比如,可以自定义step()reset(),改变环境的状态输出;也可以对奖励进行一些处理等等。

此外,为了处理更加特殊的情况,gym还对Wrapper扩展了三个子类,使其可以单独对状态、动作和奖励进行处理。

  • ObservationWrapperobservation(obs)
  • ActionWrapperaction(act)
  • RewardWrapperreward(rew)

在这里插入图片描述

接下来以ActionWrapper为例,说明如何使用该类对动作进行处理。以下代码是ActionWrapper的源码。可以发现,ActionWrapperstep()并没有直接接受输入的action,而是先使用action()进行了某些处理。 但是处理方法并没有定义,这里就需要我们对该方法根据使用需要重新定义该函数。

# gym源码class ActionWrapper(Wrapper):    def reset(self, **kwargs):        return self.env.reset(**kwargs)    def step(self, action):        return self.env.step(self.action(action))    def action(self, action):        raise NotImplementedError    def reverse_action(self, action):        raise NotImplementedError

以下是对action()的重定义,用定义好的ActionWrapper封装env后,程序就会在使用step()时自动调用action()对输入动作进行处理。

# 自定义动作class RandomActionWrapper(gym.ActionWrapper):    def __init__(self, env, epsilon=0.1):        super(RandomActionWrapper, self).__init__(env)        self.epsilon = epsilon    def action(self, action):        if random() > self.epsilon:            print("it is random!")            return self.env.action_space.sample()        else:            return action

举例和解读了常见的Wrapper,这些Wrapper就是DeepMind在论文中提出的一些处理过程。

除此之外,其实通过gym.make()获得的环境本身也是经过封装的。使用env.unwrapper可以获得最原始的环境。比如,使用gym.make('CartPole-v0')获得的环境存在最大迭代次数为200次的设置,但是env.unwrapper就没有这种限制了。

env.unwrapper存在一些方法可以让我们理解该环境的一些内容:

  • “env.unwrapped.get_action_meanings()”:返回字符列表,表示动作空间每一个对应动作的含义;

    env = gym.make("Breakout-v0")env.action_space  # Discrete(4) [0, 1, 2, 3]env.unwrapped.get_action_meanings()  # ['NOOP', 'FIRE', 'RIGHT', 'LEFT']

1.3.Monitor

如果你想记录agent的训练过程,那么Monitor是一个很好的工具。只需要对环境做一个简单的封装,Monitor就可以实现很多有用的功能。但需要注意的是,使用该工具必须要确保系统具有FFmpeg功能,否则程序会报错。最简单的做法就是安装FFmpeg。在Windows环境下,可以使用命令conda install -c conda-forge ffmpeg进行安装。

env = gym.make("Breakout-v0")# argument2: save path# argument3(force): if force=False,then the save folder shouldn't exist.env = gym.wrappers.Monitor(env, "recording", force=True)

2.创建经验回放类

DQN中的一个重要机制就是经验回放机制,它是为了打破训练数据之间的相关性而专门设置的。在深度学习中,通常会假设训练数据是相互独立的。但是在强化学习中,这种假设并不成立。因为,它的学习资料是各组相互关联的MDP序列。

为了打破训练数据之间的相关性,DQN会专门设置一个经验池,用于储存经历过的状态序列 ( x i , a i , r i + 1 , x i + 1 ) (x_i,a_i,r_{i+1},x_{i+1}) (xi,ai,ri+1,xi+1)。然后,DQN会从经验池中随机抽取"batch_size"组序列用于神经网络的训练。在提出DQN的论文中,把经验池的大小设置为1,000,000,即可以存放这个数量的序列。

First, we use a technique known as experience replay23 in which we store the agent’s experiences at each time-step, e t = ( s t , a t , r t , s t + 1 ) e_t=(s_t,a_t,r_t,s_{t+1}) et=(st,at,rt,st+1), in a data set D t = e 1 , … , e t D_t={e_1,…,e_t} Dt=e1,,et, pooled over many episodes (where the end of an episode occurs when a terminal state is reached) into a replay memory.

经验回放池需要实现以下功能:

  • 创建一个储存memory的变量作为经验池,用于储存所有的memory;
  • 实现经验池的随机采样功能,即可以从经验池中随机选择任意数量的样本作为训练资料;
  • 对memory进行重组,返回由最近的k个memory组成的obs ( k × c , h , w ) (k \times c,h,w) (k×c,h,w),训练时的一个batch的资料大小应为 ( b , k × c , h , w ) (b,k\times c,h,w) (b,k×c,h,w)。以"Breakout-v0"为例,每次会以当前的最近的k帧图片作为输入,使用网络找到价值最高的动作。那么同样地,下一个状态的维度也应该是 ( k × c , h , w ) (k \times c,h,w) (k×c,h,w)

关于第三个功能,即把k帧图像重构为一组状态输入,存在以下三种情况特殊处理:

  1. 若索引号小于k,且经验池未填满,该如何处理?
  2. 若索引号小于k,但经验池已填满,该如何处理?
  3. 若k帧图像中,其中的某帧图像是终止图像(即游戏结束),该如何处理之后的图像?

由于经验池的容量是固定的,当经验池填满后,算法会根据先进先出的原则移除最早的数据。因此,对于第2种情况,我们需要进行经验的首尾拼接。

关于剩余2种情况,一种方法是用0填充不足的图像;当然,也可以不处理这种情况。

下面详细分析以上几种特殊情况的处理方法(注:从后往前检测):

  • situation 1 (no situation 3)

    假设需要4帧图像,但是目前经验池中只有2帧,则取该2帧图像,剩余2帧图像用0补足。

  • situation 1 (with situation 3)

    假设需要4帧图像,但是目前经验池中只有3帧。取停止标志后的那些帧,剩余的用0补足。

  • situation 2 (no situation 3)

    假设需要4帧图像,但是前2帧是经验池的最后两帧,后2帧是经验池的开始2帧。将这4帧拼接。

  • situation 2 (no situation 3)

    假设需要4帧图像,但是前2帧是经验池的最后两帧,后2帧是经验池的开始2帧。

    • 停止标志出现在前2帧

      取后两帧和停止标志后的那些帧,剩余的用0补足。

    • 停止标志出现在后2帧

      取停止标志后的那些帧,剩余的用0补足。

  • situation 3

    取停止标志后的那些帧,剩余的用0补足。

3.target net

DQN与其它使用神经网络训练强化学习的方法不同,它把target网络和预测网络分开了。两个网络采样相同的网络结构,但是target网络不参与模型的训练,而是每隔一定时间加载预测网络的参数。

The second modification to online Q-learning aimed at further improving the stability of our method with neural networks is to use a separate network for generating the targets yj in the Q-learning update. More precisely, every C updates we clone the network Q to obtain a target network Q ^ \hat{Q} Q^ and use Q ^ \hat{Q} Q^ for generating the Q-learning targets y j y_j yj for the following C updates to Q.

这样处理可以让算法更加稳定,减少策略更新时的振荡。

4.其它技巧

DeepMind在论文中还提出了很多可以提高算法性能的trick。这些数据处理的过程通常都可以使用Wrapper对环境进行改写,上文的重组k帧图像的过程也可以放在Wrapper中处理。

下面简要介绍DeepMind的一些小技巧:

  • 对Atari游戏截图 ( 210 × 160 × 3 ) (210\times 160 \times 3) (210×160×3)做预处理,灰度化和尺寸裁剪后,新图像的维度为 ( 84 × 84 × 1 ) (84\times 84 \times 1) (84×84×1)
  • 取两帧相邻的图像,取两者之间的最大值作为输出图像,这样可以消除那些闪烁的像素(仅出现在奇数帧或偶数帧);
  • 把环境输出的奖励裁剪到 ( − 1 , 1 ) (-1,1) (1,1)
  • 使用frame-skipping技术,对于k帧相邻图像只计算一个动作。比如,取一帧图像使用神经网络得到一个动作,之后的 ( k − 1 ) (k-1) (k1)帧仍然使用该动作。这样可以减少计算代价,提高训练效率;
  • 把网络的训练误差 ∣ r + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ i − ) − Q ( s , a ; θ i ) ∣ |r+\gamma \max_{a'}Q(s',a';\theta^-_i)-Q(s,a;\theta_i)| r+γmaxaQ(s,a;θi)Q(s,a;θi)裁剪到 ( − 1 , 1 ) (-1,1) (1,1)

其它内容可以查看DeepMind发表的论文"Human-level control through deep reinforcement learning"。

关于DQN的代码实现,我放在了。我是基于Pytorch实现的,torch版本为1.8.1。由于本人没有专门跑深度学习的配置,因此只跑了900000次就结束了。从结果看,agent的得到的平均奖励还是有明显提升的。

转载地址:http://csuqf.baihongyu.com/

你可能感兴趣的文章
fastjson使用(三) -- 序列化
查看>>
浅谈使用单元素的枚举类型实现单例模式
查看>>
Java 利用枚举实现单例模式
查看>>
Java 动态代理作用是什么?
查看>>
Java动态代理机制详解(JDK 和CGLIB,Javassist,ASM) (清晰,浅显)
查看>>
三种线程安全的单例模式
查看>>
Spring AOP 和 动态代理技术
查看>>
从 volatile 说起,可见性和有序性是什么
查看>>
如何开始接手一个项目
查看>>
Netty 5用户指南
查看>>
Java实现简单的RPC框架
查看>>
一个用消息队列 的人,不知道为啥用 MQ,这就有点尴尬
查看>>
从零手写RPC
查看>>
高并发和多线程的关系
查看>>
Java并发与多线程
查看>>
对于多线程程序,单核cpu与多核cpu是怎么工作的
查看>>
多线程和CPU的关系
查看>>
认识cpu、核与线程
查看>>
关于Java健壮性的一些思考与实践!
查看>>
如何避免自己写的代码成为别人眼中的一坨屎!
查看>>