DQN 原理(三):DQN 训练代码实现

火山方舟向量数据库

picture.image

接着上文,我们看下代码实现。首先看 Q-Network 和 Target Q-Network 构建过程:

def build_graph(num_actions):

# 构建图中的 Q Network,是一个 CNN 模型




# build\_network 函数实现见《[利用 TensorFlow + Keras 玩 Atari 游戏](http://mp.weixin.qq.com/s?__biz=MzI2MzYwNzUyNg==&mid=2247483889&idx=1&sn=ec1f89fcb81d26410be29ca19d7d6a6d&chksm=eab80478ddcf8d6e62f9a910810aab77fcb03c5b5bb870839a50c6bd042edb2755ad7ed2266d&scene=21#wechat_redirect)》




s, q\_network = build\_network(    \




        num\_actions=num\_actions,  \




        agent\_history\_length=FLAGS.agent\_history\_length,   \




        resized\_width=FLAGS.resized\_width,   \




        resized\_height=FLAGS.resized\_height)     




network\_params = q\_network.trainable\_weights # Q Network 权值




q\_values = q\_network(s) # 计算 s 状态下所有可能行为的 Q 函数










# 构建图中的 Target Q Network,与 Q-Network 中同样结构的 CNN 模型




st, target\_q\_network = build\_network(     \




        num\_actions=num\_actions,     \




        agent\_history\_length=FLAGS.agent\_history\_length,    \




        resized\_width=FLAGS.resized\_width,         \




        resized\_height=FLAGS.resized\_height)




# Target Q Network 权值




target\_network\_params = target\_q\_network.trainable\_weights 




# 计算 st 状态下所有可能行为的 Target Q 函数




target\_q\_values = target\_q\_network(st) 










# 周期性地用 Q Network 权值更新 Target Q Network 权值




reset\_target\_network\_params = [target\_network\_params[i].assign(network\_params[i]) for i in range(len(target\_network\_params))]









# 定义损失函数




a = tf.placeholder("float", [None, num\_actions]) # 行为序列




y = tf.placeholder("float", [None]) # y 值从外部馈入




action\_q\_values = tf.reduce\_sum(tf.multiply(q\_values, a), reduction\_indices=1) # 计算 Q(s, a) 




cost = tf.reduce\_mean(tf.square(y - action\_q\_values)) # 计算损失函数 L




optimizer = tf.train.AdamOptimizer(FLAGS.learning\_rate) # 使用 Adam 优化器




grad\_update = optimizer.minimize(cost, var\_list=network\_params) # 调整 Q-Network 权值,让损失函数最小










# 导出 graph 给上层 actor-learner 控制器




graph\_ops = {"s" : s,  




             "q\_values" : q\_values,




             "st" : st, 




             "target\_q\_values" : target\_q\_values,




             "reset\_target\_network\_params" : reset\_target\_network\_params,




             "a" : a,




             "y" : y,




             "grad\_update" : grad\_update}










return graph\_ops

接下来的函数是实现 actor-learner 架构的核心,使用了 epsilon-贪婪策略,即选择概率值 0 <= epsilon <= 1,每次

def actor_learner_thread(thread_id, env, session, graph_ops, num_actions, summary_ops, saver):

global TMAX, T




# 获取上面 build\_graph 函数中导出的 graph 节点,略去不表




# ……  










# Atari 模拟器封装类对象




env = AtariEnvironment(gym\_env=env, resized\_width=FLAGS.resized\_width, resized\_height=FLAGS.resized\_height, agent\_history\_length=FLAGS.agent\_history\_length)










s\_batch = []  # 观测序列




a\_batch = []  # 行为序列




y\_batch = []  # y 序列,公式见上










final\_epsilon = sample\_final\_epsilon()




initial\_epsilon = 1.0




epsilon = 1.0










print "Starting thread ", thread\_id, "with final epsilon ", final\_epsilon










time.sleep(3*thread\_id)




t = 0




while T < TMAX:




    # 获取游戏初始观测




    s\_t = env.get\_initial\_state()




    terminal = False










    # 设置计数器




    ep\_reward = 0    # 每局游戏奖励值




    episode\_ave\_max\_q = 0    




    ep\_t = 0










    while True:




        # 向 DQN 送入初始观测 s\_t,执行 forward 计算,得到 Q 函数




        readout\_t = q\_values.eval(session = session, feed\_dict = {s : [s\_t]})




        




        # Choose next action based on e-greedy policy




        # 利用 epsilon-贪婪策略选择下一行为 a\_t  





        # 即以概率 epsilon 选择随机行为(探索过程)  





        # 而以 (1 - epsilon) 概率选择当前 Q 函数中最优行为(经验指导)  





        a\_t = np.zeros([num\_actions])




        action\_index = 0




        if random.random() <= epsilon:




            action\_index = random.randrange(num\_actions)




        else:




            action\_index = np.argmax(readout\_t)




        a\_t[action\_index] = 1  # 相应行为置 1 表示选中的下一步动作










        # 模拟退火,epsilon 在开始训练时较大,随着训练进行不断衰减




        if epsilon > final\_epsilon:




            epsilon -= (initial\_epsilon - final\_epsilon) / FLAGS.anneal\_epsilon\_timesteps









        # 模拟器按照上面得到的行为 

picture.image 进行一步更新,得到 picture.imagepicture.image

        s\_t1, r\_t, terminal, info = env.step(action\_index)










        # 将 

picture.image 送入 Target Q-Network,得到 Target Q 函数值

        readout\_j1 = target\_q\_values.eval(session = session, feed\_dict = {st : [s\_t1]})










        # 将 

picture.image 范围限制在 [-1, 1] 之间

        clipped\_r\_t = np.clip(r\_t, -1, 1)










        # 计算 y 值,即

picture.image

        if terminal:  





            # 如果 Game Over,就不用再算下一步观测的 Q 函数了




            # 直接 y = r\_t  





            y\_batch.append(clipped\_r\_t) 




        else:




            y\_batch.append(clipped\_r\_t + FLAGS.gamma * np.max(readout\_j1))









        a\_batch.append(a\_t)    




        s\_batch.append(s\_t)









        # 更新观测和计数器




        s\_t = s\_t1




        T += 1




        t += 1










        ep\_t += 1




        ep\_reward += r\_t    # 奖励累积




        episode\_ave\_max\_q += np.max(readout\_t)    #  Q 值平滑










        # 周期性同步Q 网络权值到 Target Q 网络权值,频次低




        if T % FLAGS.target\_network\_update\_frequency == 0:




            session.run(reset\_target\_network\_params)









        # 利用 SGD更新 Q 网络权值,频次高




        if t % FLAGS.network\_update\_frequency == 0 or terminal:




            if s\_batch:




                session.run(grad\_update, feed\_dict = {y : y\_batch,




                                                      a : a\_batch,




                                                      s : s\_batch})




            # 更新完毕,清除输入




            s\_batch = []




            a\_batch = []




            y\_batch = []









        # 保存模型




        if t % FLAGS.checkpoint\_interval == 0:




            saver.save(session, FLAGS.checkpoint\_dir+"/"+FLAGS.experiment+".ckpt", global\_step = t)









        # 统计日志信息




        if terminal:




            stats = [ep\_reward, episode\_ave\_max\_q/float(ep\_t), epsilon]




            for i in range(len(stats)):




                session.run(update\_ops[i], feed\_dict={summary\_placeholders[i]:float(stats[i])})




            print "THREAD:", thread\_id, "/ TIME", T, "/ TIMESTEP", t, "/ EPSILON", epsilon, "/ REWARD", ep\_reward, "/ Q\_MAX %.4f" % (episode\_ave\_max\_q/float(ep\_t)), "/ EPSILON PROGRESS", t/float(FLAGS.anneal\_epsilon\_timesteps)




            break

以上内容已经把 DQN 介绍完毕,如果仍有点小困惑,推荐大家阅读参考文献【2】,看累了就玩会小游戏,玩累了再看,有奇效!

Reference

【1】DeepMind, Playing Atari with Deep Reinforcement Learning, https://arxiv.org/abs/1312.5602

【2】Richard S. S. Reinforcement Learning : An Introduction, The MIT Press.

【3】Arun Nair, Massively Parallel Methods for Deep Reinforcement Learning

【4】Tom M. Michell, Machine Learning, Chapter 13, "Reinforcement Learning".

【5】Kevin Chen, Deep Reinforcement Learning for Flappy Bird, http://cs229.stanford.edu/proj2015/362\_report.pdf


如果你觉得本文对你有帮助,请关注公众号,将来会有更多更好的文章推送!

picture.image

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论