【超详细小白必懂】Pytorch 直接加载ResNet50模型和参数实现迁移学习

这篇具有很好参考价值的文章主要介绍了【超详细小白必懂】Pytorch 直接加载ResNet50模型和参数实现迁移学习。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Torchvision.models包里面包含了常见的各种基础模型架构,主要包括以下几种:(我们以ResNet50模型作为此次演示的例子)

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNet v2
ResNeXt
Wide ResNet
MNASNet

首先加载ResNet50模型,如果如果需要加载模型本身的参数,需要使用pretrained=True,代码如下

import torchvision
from torchvision import models
resnet50 = models.resnet50(pretrained=True) #pretrained=True 加载模型以及训练过的参数
print(resnet50)  # 打印输出观察一下resnet50到底是怎么样的结构

打印输出后ResNet50部分结构如下图,其中红框的全连接层是需要关注的点。全连接层中,“resnet50” 的out_features=1000,这也就是说可以进行class=1000的分类。

models.resnet50,深度学习,机器学习,人工智能,pytorch,迁移学习

 由于我们正常所使用的分类场景大概率与resnet50的分类数不一样,所以在调用时,要使用out_features=分类数进行调整。假设我们采用CIFAR10数据集(10 class)进行测试,那么我们就需要修改全连接层,out_features=10。具体代码如下:

resnet50 = models.resnet50(pretrained=True)
num_ftrs = resnet50.fc.in_features 
for param in resnet50.parameters():
    param.requires_grad = False #False:冻结模型的参数,也就是采用该模型已经训练好的原始参数。只需要训练我们自己定义的Linear层

#保持in_features不变,修改out_features=10
resnet50.fc = nn.Sequential(nn.Linear(num_ftrs,10),
                            nn.LogSoftmax(dim=1))

一个简单完整的 CIFAR10+ResNet50 训练代码如下:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models

#下载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10(root="../data",train=True,transform=torchvision.transforms.ToTensor(),
                                          download=False)
test_data = torchvision.datasets.CIFAR10(root="../data",train=False,transform=torchvision.transforms.ToTensor(),
                                         download=False)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("The size of Train_data is {}".format(train_data_size))
print("The size of Test_data is {}".format(test_data_size))

#dataloder进行数据集的加载
train_dataloader = DataLoader(train_data,batch_size=128)
test_dataloader = DataLoader(test_data,batch_size=128)

resnet50 = models.resnet50(pretrained=True)
num_ftrs = resnet50.fc.in_features
for param in resnet50.parameters():
    param.requires_grad = False #False:冻结模型的参数,
                                # 也就是采用该模型已经训练好的原始参数。
                                #只需要训练我们自己定义的Linear层
resnet50.fc = nn.Sequential(nn.Linear(num_ftrs,10),
                            nn.LogSoftmax(dim=1))

# 网络模型cuda
if torch.cuda.is_available():
    resnet50 = resnet50.cuda()

#loss
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
#optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(resnet50.parameters(),lr=learning_rate,)

#设置网络训练的一些参数
#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch = 10

for i in range(epoch):
    print("-------第{}轮训练开始-------".format(i+1))
    resnet50.train()
    #训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        if torch.cuda.is_available():
            # 图像cuda;标签cuda
            # 训练集和测试集都要有
            imgs = imgs.cuda()
            targets = targets.cuda()
        outputs = resnet50(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))
            #writer.add_scalar("train_loss", loss.item(), total_train_step)

    #测试集
    total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            if torch.cuda.is_available():
                # 图像cuda;标签cuda
                # 训练集和测试集都要有
                imgs = imgs.cuda()
                targets = targets.cuda()
            outputs = resnet50(imgs)
            loss = loss_fn(outputs,targets)
            total_test_loss += loss.item()
            total_test_step += 1
            if total_test_step % 100 ==0:
                print("测试次数:{},Loss:{}".format(total_test_step,total_test_loss))

完美!!!!!

剩下的大家可以举一反三,继续探索。。。。文章来源地址https://www.toymoban.com/news/detail-563058.html

到了这里,关于【超详细小白必懂】Pytorch 直接加载ResNet50模型和参数实现迁移学习的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 卷积神经网络学习—Resnet50(论文精读+pytorch代码复现)

    如果说在CNN领域一定要学习一个卷积神经网络,那一定非Resnet莫属了。 接下来我将按照:Resnet论文解读、Pytorch实现ResNet50模型两部分,进行讲解,博主也是初学者,不足之处欢迎大家批评指正。 预备知识 :卷积网络的深度越深,提取的特征越高级,性能越好,但传统的卷积

    2024年01月19日
    浏览(45)
  • ResNet18、50模型结构

    论文地址: https://arxiv.org/pdf/1512.03385.pdf pytorch官方 预训练模型 地址: pytorch官方 resnet网络 代码(包括resnet18、34、50、101、152,resnext50_32x4d、resnext101_32x8d、wide_resnet50_2、wide_resnet101_2): torchvision.models.resnet — Torchvision 0.11.0 documentation https://pytorch.org/vision/stable/_modules/torchvis

    2024年02月06日
    浏览(41)
  • 什么是Resnet50模型?

     随着CNN的不断发展,为了获取深层次的特征,卷积的层数也越来越多。一开始的 LeNet 网络只有 5 层,接着 AlexNet 为 8 层,后来 VggNet 网络包含了 19 层,GoogleNet 已经有了 22 层。但仅仅通过增加网络层数的方法,来增强网络的学习能力的方法并不总是可行的,因为网络层数到

    2023年04月13日
    浏览(43)
  • pytorch实现AI小设计-1:Resnet50人脸68关键点检测

            本项目是AI入门的应用项目,后续可以补充内容完善作为满足个人需要。通过构建自己的人脸数据集,此项目训练集为4580张图片,测试集为2308张图片,使用resnet50网络进行训练,最后进行效果展示。本项目也提供了量化内容,便于在硬件上部署。         研究A

    2024年01月18日
    浏览(44)
  • FPGA上利用Vitis AI部署resnet50 TensorFlow神经网络模型

    参考Xilinx官方教程快速入门 • Vitis AI 用户指南 (UG1414) 克隆 Vitis AI 存储库以获取示例、参考代码和脚本(连接github失败可能需要科学上网)。 安装Docker如何在 Ubuntu 20.04 上安装和使用 Docker 安装完docker后,下载最新Vitis AI Docker, 将官方的指令 docker pull xilinx/vitis-ai-pytorch/tensorfl

    2024年02月04日
    浏览(46)
  • ResNet代码复现+超详细注释(PyTorch)

    关于ResNet的原理和具体细节,可参见上篇解读:经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现) 接下来我们就来复现一下代码。 源代码比较复杂,感兴趣的同学可以上官网学习:  https://github.com/pytorch/vision/tree/master/torchvision 本

    2024年02月11日
    浏览(42)
  • 图像分类:Pytorch图像分类之--ResNet模型

    前言  ResNet 网络是在 2015年 由微软实验室提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。 原论文地址:Deep Residual Learning for Image Recognition(作者是CV大佬何凯明团队) ResNet创新点介绍 在ResNet网络中创新点

    2023年04月11日
    浏览(36)
  • 人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型

    大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型6-使用Pytorch搭建卷积神经网络ResNet模型,在本文中,我们将学习如何使用PyTorch搭建卷积神经网络ResNet模型,并在生成的假数据上进行训练和测试。本文将涵盖这些内容:ResNet模型简介、ResNet模型结构、生成假

    2024年02月06日
    浏览(78)
  • PyTorch示例——ResNet34模型和Fruits图像数据

    ResNet34模型,做图像分类 数据使用水果图片数据集,下载见Kaggle Fruits Dataset (Images) Kaggle的Notebook示例见 PyTorch——ResNet34模型和Fruits数据 下面见代码 查看图像 展示多张图片 苹果 樱桃 直接使用ImageFolder加载数据,按目录解析水果类别 输出如下 ResidualBlock ResNet34 准备代码 开始

    2024年02月12日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包