Pytorch:手把手教你搭建简单的全连接网络

这篇具有很好参考价值的文章主要介绍了Pytorch:手把手教你搭建简单的全连接网络。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

代码里的注释一定要看!!!里面包括了一些基本知识和原因

可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行

第一步:一些库的导入

import torch#深度学习的pytoch平台
import torch.nn as nn
import numpy as np
import random
import time#可以用来简单地记录时间
import matplotlib.pyplot as plt#画图
#随机种子
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

第二步:构建简单的数据集,这里利用sinx函数作为例子

x = np.linspace(-np.pi,np.pi).astype(np.float32)
y = np.sin(x)
#随机取25个点
x_train = random.sample(x.tolist(),25)    #x_train 就相当于网络的输入
y_train = np.sin(x_train)                 #y_train 就相当于输入对应的标签,每一个输入都会对应一个标签
plt.scatter(x_train,y_train,c="r")
plt.plot(x,y)

Pytorch:手把手教你搭建简单的全连接网络

 红色的点就是我在sinx函数上取的已知点作为网络的训练点。

第三步:用pytorch搭建简单的全连接网络

class DNN(nn.Module):
    def __init__(self):
        super().__init__()
        layers =  [1,20,1]   #网络每一层的神经元个数,[1,10,1]说明只有一个隐含层,输入的变量是一个,也对应一个输出。如果是两个变量对应一个输出,那就是[2,10,1]
        self.layer1 = nn.Linear(layers[0],layers[1])  #用torh.nn.Linear构建线性层,本质上相当于构建了一个维度为[layers[0],layers[1]]的矩阵,这里面所有的元素都是权重
        self.layer2 = nn.Linear(layers[1],layers[2])
        self.elu = nn.ELU()       #非线性的激活函数。如果只有线性层,那么相当于输出只是输入做了了线性变换的结果,对于线性回归没有问题。但是非线性回归我们需要加入激活函数使输出的结果具有非线性的特征
    def forward(self,d):#d就是整个网络的输入
        d1 = self.layer1(d)
        d1 = self.elu(d1)#每一个线性层之后都需要加入一个激活函数使其非线性化。
        d2 = self.layer2(d1)#但是在网络的最后一层可以不用激活函数,因为有些激活函数会使得输出结果限定在一定的值域里。
        return d2

第四步:一些基本参数变量的确定以及数据格式的转换

device = torch.device("cuda") #在跑深度学习的时候最好使用GPU,这样速度会很快。不要的话默认用cpu跑
epochs = 10000                #这是迭代次数,把所有的训练数据输入到网络里去就叫完成了一次epoch。
learningrate = 1e-4           #学习率,相当于优化算法里的步长,学习率越大,网络参数更新地更加激进。学习率越小,网络学习地更加稳定。
net = DNN().to(device=device) #网络的初始化
optimizer = torch.optim.Adam(net.parameters(), lr=learningrate)#优化器,不同的优化器选择的优化方式不同,这里用的是随机梯度下降SGD的一种类型,Adam自适应优化器。需要输入网络的参数以及学习率,当然还可以设置其他的参数
mseloss  = nn.MSELoss()      #损失函数,这里选用的是MSE。损失函数也就是用来计算网络输出的结果与对应的标签之间的差距,差距越大,说明网络训练不够好,还需要继续迭代。
MinTrainLoss = 1e10          
train_loss =[]               #用一个空列表来存储训练时的损失,便于画图
pt_x_train = torch.from_numpy(np.array(x_train)).to(device=device,dtype = torch.float32).reshape(-1,1)  #这里需要把我们的训练数据转换为pytorch tensor的类型,并且把它变成gpu能运算的形式。
pt_y_train = torch.from_numpy(np.array(y_train)).to(device=device,dtype = torch.float32).reshape(-1,1) #reshap的目的是把维度变成(25,1),这样25相当于是batch,我们就可以一次性把所有的点都输入到网络里去,最后网络输出的结果也不是(1,1)而是(25,1),我们就能直接计算所有点的损失
print(pt_x_train.dtype)
print(pt_x_train.shape)

第五步:网络训练过程

start = time.time()
start0=time.time()
for epoch in range(1,epochs+1):
    net.train()    #net.train():在这个模式下,网络的参数会得到更新。对应的还有net.eval(),这就是在验证集上的时候,我们只评价模型,并不对网络参数进行更新。
    pt_y_pred = net(pt_x_train) #将tensor放入网络中得到预测值
    loss = mseloss(pt_y_pred,pt_y_train)  #用mseloss计算预测值和对应标签的差别
    optimizer.zero_grad()      #在每一次迭代梯度反传更新网络参数时,需要把之前的梯度清0,不然上一次的梯度会累积到这一次。
    loss.backward()  # 反向传播
    optimizer.step() #优化器进行下一次迭代
    if epoch % 10 == 0:#每10个epoch保存一次loss
        end = time.time()
        print("epoch:[%5d/%5d] time:%.2fs current_loss:%.5f"
          %(epoch,epochs,(end-start),loss.item()))
        start = time.time()
    train_loss.append(loss.item())
    if train_loss[-1] < MinTrainLoss:
        torch.save(net.state_dict(),"model.pth") #保存每一次loss下降的模型
        MinTrainLoss = train_loss[-1]
end0 = time.time()
print("训练总用时: %.2fmin"%((end0-start0)/60)) 

Pytorch:手把手教你搭建简单的全连接网络

 训练过程如上,时间我这里设置的比较简单,除了分钟,之后的时间没有按照60进制规定。

第六步:查看loss下降情况

plt.plot(range(epochs),train_loss)
plt.xlabel("epoch")
plt.ylabel("loss")

Pytorch:手把手教你搭建简单的全连接网络

可以看到收敛的还是比较好的。

第七步:导入网络模型,输入验证数据,预测结果 

x_test = np.linspace(-np.pi,np.pi).astype(np.float32)
pt_x_test = torch.from_numpy(x_test).to(device=device,dtype=torch.float32).reshape(-1,1)
Dnn = DNN().to(device)
Dnn.load_state_dict(torch.load("model.pth",map_location=device))#pytoch 导入模型
Dnn.eval()#这里指评价模型,不反传,所以用eval模式
pt_y_test = Dnn(pt_x_test) 
y_test = pt_y_test.detach().cpu().numpy()#输出结果torch tensor,需要转化为numpy类型来进行可视化
plt.scatter(x_train,y_train,c="r")
plt.plot(x_test,y_test)

Pytorch:手把手教你搭建简单的全连接网络

这里红色的点为训练用的数据,蓝色为我们的预测曲线,可以看到整体上拟合的是比较好的。

以上就是用pytorch搭建的简单全连接网络的基本步骤,希望可以给到初学者一些帮助!文章来源地址https://www.toymoban.com/news/detail-465859.html

实验室网址:CIG | zhixiang

Github网址:ProgrammerZXG (Zhixiang Guo) · GitHub

到了这里,关于Pytorch:手把手教你搭建简单的全连接网络的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包赞助服务器费用

相关文章

  • 手把手教你搭建自己本地的ChatGLM

    手把手教你搭建自己本地的ChatGLM

    如果能够本地自己搭建一个ChatGPT的话,训练一个属于自己知识库体系的人工智能AI对话系统,那么能够高效的处理应对所属领域的专业知识,甚至加入职业思维的意识,训练出能够结合行业领域知识高效产出的AI。这必定是十分高效的生产力工具,且本地部署能够保护个人数

    2024年02月03日
    浏览(28)
  • 手把手教你5分钟搭建RabbitMq开发环境

    手把手教你5分钟搭建RabbitMq开发环境

    演示环境 1、使用Vagrant 和 VirtualBox创建linux虚拟机 不知道Vagrant怎么使用的可以看这里。 ①在cmd窗口执行命令 vagrant init generic/centos7 ,初始化linux启动环境 ②执行启动命令 vagrant up 启动Linux虚拟机 ③修改当前目录的Vagrantfile文件,为虚拟机配置内网ip,后面登录的时候会用到

    2023年04月12日
    浏览(9)
  • 手把手教你,本地RabbitMQ服务搭建(windows)

    手把手教你,本地RabbitMQ服务搭建(windows)

    前面已经对RabbitMQ介绍了很多内容,今天主要是和大家搭建一个可用的RabbitMQ服务端,方便后续进一步实操与细节分析 跟我们跑java项目,要装jdk类似。rabbitMQ是基于Erlang开发的,因此安装rabbitMQ服务器之前,需要先安装Erlang环境。 【PS: 我已经上传了对应资源,windows可直接下载

    2024年02月14日
    浏览(14)
  • 手把手教你搭建 Webpack 5 + React 项目

    手把手教你搭建 Webpack 5 + React 项目

    在平时工作中,为减少开发成本,一般都会使用脚手架来进行开发,比如 create-react-app 。脚手架都会帮我们配置好了 webpack,但如果想自己搭建 webpack 项目要怎么做呢?这边文章将介绍如何使用 webpack 5 来搭建 react 项目,项目地址在文末。 1.1 Webpack 的好处 试想在不使用任何打

    2024年02月08日
    浏览(15)
  • 手把手教你搭建一个Minecraft 服务器

    手把手教你搭建一个Minecraft 服务器

    这次,我们教大家如何搭建一个我的世界服务器 首先,我们来到这个网站 MCVersions.net - Minecraft Versions Download List MCVersions.net offers an archive of Minecraft Client and Server jars to download, for both current and old releases! https://mcversions.net/   在这里,我们点击对应的版本,从左到右依次是稳定版

    2024年02月09日
    浏览(16)
  • 手把手教你搭建ARM32 QEMU环境

    手把手教你搭建ARM32 QEMU环境

    我们知道嵌入式开发调试就要和各种硬件打交道,所以学习就要专门购买各种开发版,浪费资金,开会演示效果还需要携带一大串的板子和电线,不胜其烦。然而Qemu的使用可以避免频繁在开发板上烧写版本,如果进行的调试工作与外设无关,仅仅是内核方面的调试,Qemu模拟

    2024年02月19日
    浏览(14)
  • 手把手教你搭建内网穿透服务器

    手把手教你搭建内网穿透服务器

    有时候我们需要把外网可以访问自己的内网,比如在微信公众号开发调用接口时为了方便调试就需要配置回调地址或者是想把自己的nas可以在不在家就能访问,这时候就需要内网穿透。使用内网穿透主要有几种方式,1.使用内网穿透服务商提供的服务,但是这种需要付费,免

    2024年04月23日
    浏览(13)
  • 手把手教你在Windows下搭建Vue开发环境

    手把手教你在Windows下搭建Vue开发环境

    最近有小伙伴不会Vue环境的部署,小孟亲自测试了下,大家有需要的可以按照下面的学习。 如果想看视频的,也可以看视频的教程: https://www.bilibili.com/video/BV1if4y1X7BS/?spm_id_from=333.788.recommend_more_video.-1vd_source=e64f225fc5daf048d2687502cb23bb3b 在Windows下搭建Vue开发环境: 官网https://n

    2024年02月08日
    浏览(14)
  • 1. [手把手教你搭建] 之 在linux上搭建java环境

    1. [手把手教你搭建] 之 在linux上搭建java环境

    当我们要在服务器上部署自己的java服务时,首先我们需要安装和配置好java环境,那么我们现需要在服务器上下载java1.8版本的安装包,之后再完成环境配置,服务部署这一套流程,本文会讲解java安装包的下载及环境配置,这里使用的是压缩包的安装方式: 首先创建package目录

    2023年04月11日
    浏览(5)
  • 【Docker】手把手教你搭建好玩的docker项目合集

    【Docker】手把手教你搭建好玩的docker项目合集

    这是我在使用docker后,慢慢一个个累计起来的项目,觉得还挺有意思的。 之后我会持续慢慢的更新新的项目,大伙如何有好玩的docker项目,欢迎来找我讨论哇,我每天都会看私信的 docker搭建数据库 使用docker安装数据库是非常省事的,而且想安什么类型的数据就安什么类型的

    2024年02月07日
    浏览(7)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包