手把手带你实现DQN(TensorFlow2)

这篇具有很好参考价值的文章主要介绍了手把手带你实现DQN(TensorFlow2)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

        大家好,今天给大家带来DQN的思路及实现方法。

        关于DQN,就不用我多做介绍了,我会以最简短明白的阐述讲解DQN,尽量让你在10分钟内理清思路。

        非常重要的一点!!!

        非常重要的一点!!!我在GitHub上下载了DQN代码,跑完后,我重写一次,删改了其中一些东西。比如epsilon,至于原因还有带来的一些结果,我放到后面来说,总之,我们先进入代码环节。先给出参考网址:

https://github.com/marload/DeepRL-TensorFlow2/blob/master/DQN/DQN_Discrete.py

下面是在 gym.make('CartPole-v1) 中实现的效果:

手把手带你实现DQN(TensorFlow2)

下面我们按照DQN的算法来实现DQN

DQN的算法如下图:

手把手带你实现DQN(TensorFlow2)

 

第一步:初始化一个容量为N的存取器D

        我们希望D能够存放(def        insert)、拿取(def        get_sample)数据。

下面我们用collections.deque来实现:

from collections import deque
import numpy as np
import random

class buffer():

    def __init__(self):
        self.buffer = deque(maxlen=1000)

    def insert(self,state,next_state,action,reward,is_over):
        self.buffer.append([state,next_state,action,reward,is_over])

    def get_sample(self):
        sample = random.sample(self.buffer,32)
        state,next_state,action,reward,is_over = map(np.asarray,zip(*sample))
        state = np.array(state).reshape(32,-1)
        next_state = np.array(next_state).reshape(32,-1)
        return state,next_state,action,reward,is_over

    def length(self):
        return len(self.buffer)

好了,我们已经有了一个buffer可以存放数据和拿取数据了。

第二步:初始化动作奖励函数Q

        对于Q函数,我们输入一个状态(state)返回action-value。在 gym.make('CartPole-v1) 中,状态的shape为(4,),动作数量为2。你可以用下列代码查看:

import gym
env = gym.make('CartPole-v1',render_mode='human')
print(env.observation_space.shape)# 查看观测空间
print(env.action_space.n)# 查看动作数

       因为奖励要对应其中的一个动作, 所以,我们的Q函数输入大小为(None,4),输出为(None,2)就能够确定下来了。下面我们用全连接层来实现这个网络:

import tensorflow as tf
from tensorflow.keras import layers,Input,Model,optimizers

def model():
    size = (4,)
    input = Input(size)
    x = layers.Dense(32,activation='relu')(input)
    x = layers.Dense(16,activation='relu')(x)
    out = layers.Dense(2)(x)
    model = Model(inputs=input,outputs=out)
    model.compile(loss='mse',
                  optimizer=optimizers.Adam(0.005))
    return model

这里action-value函数也实现了。

第三步:得到一个初始化的游戏状态

       首先我们需要建立一个游戏,不然怎么得到游戏状态,对吗?关于游戏这块,我们不需要深究,我们只要把注意力放在算法上就行了,它如何实现的,不关我们的事。代码如下:

import gym

env = gym.make('CartPole-v1',render_mode='human')
env.reset() # 初始化游戏
state = env.state # 得到初始状态

第四步:理清并实现训练流程

       我们再理清一下思路:

       1、从上面我们已经得到一个state状态了,接下来我们将state状态reshape后放入Q函数内就能得到action-value (None,2)。

        2、我们要从这两个里面选取奖励最大的value值的index(数值为0或者1),这个index就对应着游戏里的操作action了。

        3、我们将得到的操作action传入游戏,就会得到一个新的state、奖励reward、是否结束is_over等等参数。

        4、我们将这些参数保存到最开始定义的buffer中,在从buffer中随机选取参数出来训练就可以了。

代码如下:

env = gym.make('CartPole-v1',render_mode='human')
buf = buffer()# 初始化buffer
network = model()# 初始化训练网络
lable_network = model()# 初始化标签网络
#这里我感觉定义lable_network网络有些多余
#因为network的参数给了lable_network
#我没有深究,DQN权当了解,因为还有DQN变种的模型
for i in range(1000):
    env.reset() # 初始化游戏
    state = env.state # 得到初始状态
    is_over = False #
    weights = network.get_weights()
    lable_network.set_weights(weights)
    while not is_over:
        state = np.reshape(state, [1, 4]) 
        action = np.argmax(network.predict(state)[0])
        next_state,reward,is_over,_,_ = env.step(action)#传入操作获取新的数据
        buf.insert(state,next_state,action,reward,is_over)#参数保存进buffer
        state = next_state 
    if buf.length()>=32:
        train()

好了,到这里DQN的流程算是走完了,但是这里还有一个函数train()没有定义。

第五步:定义train()函数

       我们要通过train()函数来更新network的权重weights,这是必需的。

       buffer里面的get_sample我们是不是还没用过呢?这就来了。

       按照DQN算法:当is_over为True的时候,奖励为reward本身

                                 当is_over为False的时候,奖励为reward+ k*next_reward

实现如下:

def train():
    for i in range(10):
        state,next_state,action,reward,is_over = buf.get_sample()
        value = lable_network.predict(state)
        next_max_value = np.max(lable_network.predict(next_state),axis=1)
        value[range(32),action] = reward+0.95*next_max_value*(1-is_over)
        network.train_on_batch(state,value)

终于,我们完成了DQN网络。

最后一些想说的话(选择阅读)

       如果你已经按照上面的代码完整打了下来,我想说的是,很遗憾,你不一定能训练处上面的效果。我假设你看过上面GitHub中的代码(那里没有去掉参数)。现在,我们回到最开始的那里,我说过,我去除掉了epsilon参数。因为我在训练出一个了一个稳定的模型后,大概ai玩了10分钟的时候,输掉了游戏。这是令我震惊的,因为前面10分钟,ai分明陷入了循环(周期假设为60s),按理来说会一直保持下去。我认为,输掉的原因就是epsilon。因为这导致了一种随机性,虽然很小,但玩的时间久了,总会出现随件的action,也就是很可能稳定的模型里,本来应该是0,但突然给你换成了1,打破了原有的模型。所以我去除了,之后,这个ai虽然确实没有输掉,但也出现了新的问题。

        这个问题就是,训练的效果不稳定了。最开始那个稳定的效果是我在晚上花30分钟跑出来的。出于严谨,我在第二天重新跑了一次。而这次,一个多小时也没有跑到稳定的效果。我认为,因为get_sample函数里是随机选取数据的,所以可能选不到合适的数据去train,不能让模型很好的稳定下来。我尝试过,不去随机选取,将每一个数据拿去train,但效果也不尽如人意。最后按我的经历来看,如果你想要得到一个稳定的模型,我建议先保留epsilon,在代码中添加一个保留weights的代码,等ai玩的稳定的时候,停住程序,删除epsilon,导入保存的weights,再运行程序,应该能得到想要的效果。

希望这篇文章能对您有帮助,感谢您的观看!文章来源地址https://www.toymoban.com/news/detail-401397.html

到了这里,关于手把手带你实现DQN(TensorFlow2)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【是C++,不是C艹】 手把手带你实现Date类(附源码)

    💞💞 欢迎来到 Claffic 的博客 💞💞  👉  专栏: 《是C++,不是C艹》👈 前言: 恍惚间,已经两个月没更新了 (;´д`)ゞ 我忏悔...  但C++的学习不能停止!这期带大家实践一波,手把手教大家实现一个Date类,以感受C++类的魅力 注: 你最好是学完了C语言,并学过一些初

    2024年02月10日
    浏览(16)
  • 【数据结构】—手把手带你用C语言实现栈和队列(超详细!)

                                       食用指南:本文在有C基础的情况下食用更佳                                     🔥 这就不得不推荐此专栏了:C语言                                   ♈️ 今日夜电波:Tell me —milet                    

    2024年02月14日
    浏览(22)
  • 手把手带你实现ChatGLM2-6B的P-Tuning微调

    参考文献:chatglm2ptuning 注意问题1:AttributeError: ‘Seq2SeqTrainer’ object has no attribute \\\'is_deepspeed_enabl torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 可能是版本太高,可以参考chatglm2的环境 1. ChatGLM2-6B的P-Tuning微调 ChatGLM2-6B :https://github.com/THUDM/ChatGLM2-6B 模型地址 :https://hug

    2024年02月17日
    浏览(22)
  • 从0到1,手把手带你开发截图工具ScreenCap------001实现基本的截图功能

    从0到1,手把手带你开发windows端的截屏软件ScreenCap 当前版本:ScreenCap---001 支持全屏截图 支持鼠标拖动截图区域 支持拖拽截图 支持保存全屏截图 支持另存截图到其他位置 注:博主所有资源永久免费,若有帮助,请点赞转发是对我莫大的帮助 注:博主本人学习过程的分享,

    2024年02月05日
    浏览(25)
  • 从0到1,手把手带你开发截图工具ScreenCap------002实现设置默认保存的图片位置

    在ScreenCap实现截图功能后增加设置图片默认保存位置的功能 实现选择文件夹作为截图的默认保存位置 注:博主所有资源永久免费,若有帮助,请点赞转发是对我莫大的帮助 注:博主本人学习过程的分享,引用他人的文章皆会标注原作者 注:本人文章非盈利性质,若有侵权请

    2024年02月05日
    浏览(52)
  • 从0到1,手把手带你开发截图工具ScreenCap------003实现最小化程序到托盘运行

    为了方便截图干净,实现最小化程序到托盘运行,简洁,勿扰 实现最小化程序到托盘运行 实现托盘菜单功能 实现回显主窗体 实现托盘开始截屏 实现气泡信息提示 实现托盘程序提示 实现托盘退出程序 封装完好,可复用 注:博主所有资源永久免费,若有帮助,请点赞转发是

    2024年02月05日
    浏览(25)
  • 手把手带你使用ESP8266 与 STM32F103C8实现网络服务器

    随着现在物联网设备的而越来越多,现在市场上出现越来越多的物联网设备,其中 ESP8266 是最受欢迎、价格便宜且易于使用的模块,它可以将您的硬件连接到互联网。 今天我们就以ESP8266和STM32来实现一台网络服务器,我们使用 ESP8266 将 STM32F103C8 连接到互联网。 ESP8266 Wi-Fi 模

    2024年01月23日
    浏览(18)
  • 手把手带你搞懂AMS启动原理

    彻底搞懂AMS即ActivityManagerService,看这一篇就够了 最近那么多教学视频(特别是搞车载的)都在讲AMS,可能这也跟要快速启动一个app(甚至是提高安卓系统启动速度有关),毕竟作为安卓系统的核心系统服务之一,AMS以及PMS都是很重要的,而我之前在 应用的开端–PackageManag

    2024年02月12日
    浏览(24)
  • 手把手解决module ‘tensorflow‘ has no attribute ‘placeholder

    1、问题背景 :构建神经网络在加入卷积层时出现报错 face_recigntion_model.add(Conv2D(32,3,3,input_shape=(IMAGE_SIZE,IMAGE_SIZE,3),activation=\\\'relu\\\')) AttributeError: module \\\'tensorflow\\\' has no attribute \\\'placeholder\\\' 2、报错原因 :可能是由于tf.placeholder的版本问题,tf.placeholder是tensorflow1.x版本的东西,tensorflow

    2024年01月21日
    浏览(19)
  • 【Python】OpenAI:基于 Gym-CarRacing 的自动驾驶项目(1) | 前置知识介绍 | 项目环境准备 | 手把手带你一步步实现

     猛戳!跟哥们一起玩蛇啊  👉 《一起玩蛇》🐍  💭 写在前面:  本篇是关于自动驾驶专业项目 Gym-CarRacing 的博客。GYM-Box2D CarRacing 是一种在 OpenAI Gym 平台上开发和比较强化学习算法的模拟环境。 本专栏提供完整可运行代码,包括环境安装的详细讲解,将通过 \\\"理论+实践

    2024年02月02日
    浏览(30)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包