Continuing from the previous article, let's look at the code. First, the construction of the Q-Network and the Target Q-Network:
def build_graph(num_actions):
    # Build the Q-Network in the graph; it is a CNN model.
    # For the build_network implementation, see "[Playing Atari Games with TensorFlow + Keras](http://mp.weixin.qq.com/s?__biz=MzI2MzYwNzUyNg==&mid=2247483889&idx=1&sn=ec1f89fcb81d26410be29ca19d7d6a6d&chksm=eab80478ddcf8d6e62f9a910810aab77fcb03c5b5bb870839a50c6bd042edb2755ad7ed2266d&scene=21#wechat_redirect)".
    s, q_network = build_network(
        num_actions=num_actions,
        agent_history_length=FLAGS.agent_history_length,
        resized_width=FLAGS.resized_width,
        resized_height=FLAGS.resized_height)
    network_params = q_network.trainable_weights  # Q-Network weights
    q_values = q_network(s)  # Q values of all possible actions in state s

    # Build the Target Q-Network in the graph: a CNN with the same structure as the Q-Network.
    st, target_q_network = build_network(
        num_actions=num_actions,
        agent_history_length=FLAGS.agent_history_length,
        resized_width=FLAGS.resized_width,
        resized_height=FLAGS.resized_height)
    # Target Q-Network weights
    target_network_params = target_q_network.trainable_weights
    # Target Q values of all possible actions in state st
    target_q_values = target_q_network(st)

    # Periodically copy the Q-Network weights into the Target Q-Network
    reset_target_network_params = [target_network_params[i].assign(network_params[i])
                                   for i in range(len(target_network_params))]

    # Define the loss function
    a = tf.placeholder("float", [None, num_actions])  # action batch (one-hot)
    y = tf.placeholder("float", [None])               # y targets, fed in from outside
    action_q_values = tf.reduce_sum(tf.multiply(q_values, a), reduction_indices=1)  # Q(s, a)
    cost = tf.reduce_mean(tf.square(y - action_q_values))  # loss L
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)  # Adam optimizer
    grad_update = optimizer.minimize(cost, var_list=network_params)  # adjust the Q-Network weights to minimize the loss

    # Export the graph nodes to the upper-level actor-learner controller
    graph_ops = {"s": s,
                 "q_values": q_values,
                 "st": st,
                 "target_q_values": target_q_values,
                 "reset_target_network_params": reset_target_network_params,
                 "a": a,
                 "y": y,
                 "grad_update": grad_update}
    return graph_ops
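The graph_ops dictionary returned here is what the upper-level controller feeds and runs. As a minimal sketch of how those handles could be consumed, the snippet below first synchronizes the Target Q-Network and then runs a single gradient step on a hand-made batch. It assumes a TF 1.x session and the same FLAGS as above; the dummy shapes (history of agent_history_length frames at resized_width x resized_height) and values are illustrative assumptions, not part of the original code.

import numpy as np
import tensorflow as tf

# Hypothetical driver code, assuming build_graph() and FLAGS as defined above.
num_actions = 4
graph_ops = build_graph(num_actions)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    # Copy the Q-Network weights into the Target Q-Network once at start-up.
    session.run(graph_ops["reset_target_network_params"])

    # One illustrative gradient step on a fake batch of two transitions.
    # The state shape here is an assumption about what build_network expects.
    fake_states = np.zeros([2, FLAGS.agent_history_length,
                            FLAGS.resized_width, FLAGS.resized_height])
    fake_actions = np.eye(num_actions)[[0, 1]]   # one-hot encoded actions
    fake_targets = np.array([0.5, -0.5])         # y values computed elsewhere
    session.run(graph_ops["grad_update"],
                feed_dict={graph_ops["s"]: fake_states,
                           graph_ops["a"]: fake_actions,
                           graph_ops["y"]: fake_targets})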
The next function is the core of the actor-learner architecture. It uses an epsilon-greedy policy: pick a probability value 0 <= epsilon <= 1, and at each step take a random action with probability epsilon (exploration) or the action with the highest current Q value with probability (1 - epsilon) (exploitation).
def actor_learner_thread(thread_id, env, session, graph_ops, num_actions, summary_ops, saver):
    global TMAX, T
    # Unpack the graph nodes exported by the build_graph function above (omitted here)
    # ...
    # Wrapper object around the Atari emulator
    env = AtariEnvironment(gym_env=env, resized_width=FLAGS.resized_width, resized_height=FLAGS.resized_height, agent_history_length=FLAGS.agent_history_length)

    s_batch = []  # observation batch
    a_batch = []  # action batch
    y_batch = []  # y-target batch, see the formula above

    final_epsilon = sample_final_epsilon()
    initial_epsilon = 1.0
    epsilon = 1.0
    print("Starting thread", thread_id, "with final epsilon", final_epsilon)

    time.sleep(3 * thread_id)
    t = 0
    while T < TMAX:
        # Get the initial observation of a new game
        s_t = env.get_initial_state()
        terminal = False

        # Set up the per-episode counters
        ep_reward = 0          # accumulated reward of this episode
        episode_ave_max_q = 0
        ep_t = 0

        while True:
            # Feed the current observation s_t into the DQN (forward pass) to get the Q values
            readout_t = q_values.eval(session=session, feed_dict={s: [s_t]})

            # Choose the next action a_t with the epsilon-greedy policy:
            # with probability epsilon take a random action (exploration),
            # with probability (1 - epsilon) take the best action under the current Q values (exploitation)
            a_t = np.zeros([num_actions])
            action_index = 0
            if random.random() <= epsilon:
                action_index = random.randrange(num_actions)
            else:
                action_index = np.argmax(readout_t)
            a_t[action_index] = 1  # one-hot encode the chosen action

            # Anneal epsilon: it starts large and decays as training proceeds
            if epsilon > final_epsilon:
                epsilon -= (initial_epsilon - final_epsilon) / FLAGS.anneal_epsilon_timesteps

            # The emulator takes the chosen action and advances one step,
            # returning the next observation s_{t+1} and the reward r_t
            s_t1, r_t, terminal, info = env.step(action_index)

            # Feed s_{t+1} into the Target Q-Network to get the Target Q values
            readout_j1 = target_q_values.eval(session=session, feed_dict={st: [s_t1]})

            # Clip r_t to the range [-1, 1]
            clipped_r_t = np.clip(r_t, -1, 1)

            # Compute the target y = r_t + gamma * max_a' Q_target(s_{t+1}, a')
            if terminal:
                # If the game is over, there is no next observation to evaluate,
                # so simply y = r_t
                y_batch.append(clipped_r_t)
            else:
                y_batch.append(clipped_r_t + FLAGS.gamma * np.max(readout_j1))

            a_batch.append(a_t)
            s_batch.append(s_t)

            # Update the observation and the counters
            s_t = s_t1
            T += 1
            t += 1
            ep_t += 1
            ep_reward += r_t                        # accumulate the reward
            episode_ave_max_q += np.max(readout_t)  # accumulate max Q for per-episode averaging

            # Periodically copy the Q-Network weights into the Target Q-Network (low frequency)
            if T % FLAGS.target_network_update_frequency == 0:
                session.run(reset_target_network_params)

            # Update the Q-Network weights with SGD (high frequency)
            if t % FLAGS.network_update_frequency == 0 or terminal:
                if s_batch:
                    session.run(grad_update, feed_dict={y: y_batch,
                                                        a: a_batch,
                                                        s: s_batch})
                # Batch consumed, clear the inputs
                s_batch = []
                a_batch = []
                y_batch = []

            # Save the model
            if t % FLAGS.checkpoint_interval == 0:
                saver.save(session, FLAGS.checkpoint_dir + "/" + FLAGS.experiment + ".ckpt", global_step=t)

            # Log statistics
            if terminal:
                stats = [ep_reward, episode_ave_max_q / float(ep_t), epsilon]
                for i in range(len(stats)):
                    session.run(update_ops[i], feed_dict={summary_placeholders[i]: float(stats[i])})
                print("THREAD:", thread_id, "/ TIME", T, "/ TIMESTEP", t, "/ EPSILON", epsilon, "/ REWARD", ep_reward, "/ Q_MAX %.4f" % (episode_ave_max_q / float(ep_t)), "/ EPSILON PROGRESS", t / float(FLAGS.anneal_epsilon_timesteps))
                break
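For completeness, here is a rough sketch of the upper-level controller that build_graph and actor_learner_thread hand off to: it builds the graph once, starts several actor-learner threads sharing one session, and joins them at the end. The helper names (train, num_concurrent, the "Breakout-v0" environment) and the use of Python's threading module are my assumptions for illustration; they are not taken from the original code, and the real version also builds the TensorBoard summary ops that actor_learner_thread logs to.

import threading
import tensorflow as tf
import gym

# Hypothetical top-level actor-learner controller (illustrative only).
def train(num_concurrent=4, game="Breakout-v0"):
    envs = [gym.make(game) for _ in range(num_concurrent)]
    num_actions = envs[0].action_space.n
    graph_ops = build_graph(num_actions)
    saver = tf.train.Saver()
    summary_ops = None  # placeholder; the real code sets up TensorBoard summaries here

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Start with the Target Q-Network equal to the Q-Network.
        session.run(graph_ops["reset_target_network_params"])

        threads = [threading.Thread(target=actor_learner_thread,
                                    args=(i, envs[i], session, graph_ops,
                                          num_actions, summary_ops, saver))
                   for i in range(num_concurrent)]
        for th in threads:
            th.start()
        for th in threads:
            th.join()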
That wraps up the DQN walkthrough. If anything is still a little confusing, I recommend reading reference 【2】; when you get tired of reading, play the game for a while, then come back to the text. It works wonders!
References
【1】DeepMind, "Playing Atari with Deep Reinforcement Learning", https://arxiv.org/abs/1312.5602
【2】Richard S. Sutton and Andrew G. Barto, "Reinforcement Learning: An Introduction", The MIT Press.
【3】Arun Nair et al., "Massively Parallel Methods for Deep Reinforcement Learning".
【4】Tom M. Mitchell, "Machine Learning", Chapter 13, "Reinforcement Learning".
【5】Kevin Chen, "Deep Reinforcement Learning for Flappy Bird", http://cs229.stanford.edu/proj2015/362_report.pdf
If you found this article helpful, please follow the official account. More and better articles are on the way!