torch中如何使用预训练权重

这篇具有很好参考价值的文章主要介绍了torch中如何使用预训练权重。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

解释说明:目前很多主流的网络模型主要包含backbone+其他结构(分类,回归),那么如何在训练自己的网络模型时使用别人已经训练好的网络模型权重呢??本文以Resnet50为例,构建一个基于resnet50的网络模型预训练过程。

1. Torchvision中封装的主流网络模型

  • torchvision中封装了Resnet系列、vgg系列、inception系列等网络模型,切内部给出了每个网络模型预训练权重的url路径

  • 如下图所示,为torchvison官方封装的Resnet系列网络

    resnet50预训练权重,目标检测,计算机视觉,深度学习,人工智能

2. 如何使用预训练权重

解释说明:根据自己的理解,使用预训练权重过程主要包含以下几个步骤

  • 创建自己的网络模型:前文说道,网络模型主要包含backbone+其他部分(分类、回归等),因此对于任意一个网络模型而言,只要对backbone做预训练处理就行了(即网络backbone部分载入官方训练好的权重,只训练后续的其他部分)
  • 从torch官方中载入训练权重字典
  • 将torch官方的预训练权重中需要的部分载入进自己的网络模型

模型权重载入完毕后,这是需要根据个人需要,训练时候选择更新网络全部参数还是冻结部分参数值更新后续的其他部分

下面开始撸代码

2.1 创建自己的网络模型

解释说明:这里我创建了一个基于resnet50网络的模型(这个网络是干什么的在此不做解释),网络结构如下

import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, ReLU, BatchNorm2d
from torch import nn
from torch.utils import model_zoo

CLASS_NUM = 20  # 使用其他训练集需要更改
class Bottleneck(nn.Module):  # 定义基本块
    def __init__(self, in_channel, out_channel, stride, downsample):
        super(Bottleneck, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.bottleneck = Sequential(

            Conv2d(in_channel, out_channel, kernel_size=1, stride=stride[0], padding=0, bias=False),
            BatchNorm2d(out_channel),
            ReLU(inplace=True),

            Conv2d(out_channel, out_channel, kernel_size=3, stride=stride[1], padding=1, bias=False),
            BatchNorm2d(out_channel),
            ReLU(inplace=True),

            Conv2d(out_channel, out_channel * 4, kernel_size=1, stride=stride[2], padding=0, bias=False),
            BatchNorm2d(out_channel * 4),
        )
        if self.downsample is False:  # 如果 downsample = True则为Conv_Block 为False为Identity_Block
            self.shortcut = Sequential()
        else:
            self.shortcut = Sequential(
                Conv2d(self.in_channel, self.out_channel * 4, kernel_size=1, stride=stride[0], bias=False),
                BatchNorm2d(self.out_channel * 4)
            )

    def forward(self, x):
        out = self.bottleneck(x)
        out += self.shortcut(x)
        out = self.relu(out)
        return out


class output_net(nn.Module):
    # no expansion
    # dilation = 2
    # type B use 1x1 conv
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, block_type='A'):
        super(output_net, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=2, bias=False, dilation=2)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.downsample = nn.Sequential()
        self.relu = nn.ReLU(inplace=True)
        if stride != 1 or in_planes != self.expansion * planes or block_type == 'B':
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        out = self.relu(out)
        return out


class ResNet50(nn.Module):
    def __init__(self, block):
        super(ResNet50, self).__init__()
        self.block = block
        self.layer0 = Sequential(
            Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            BatchNorm2d(64),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.layer1 = self.make_layer(self.block, channel=[64, 64], stride1=[1, 1, 1], stride2=[1, 1, 1], n_re=3)
        self.layer2 = self.make_layer(self.block, channel=[256, 128], stride1=[2, 1, 1], stride2=[1, 1, 1], n_re=4)
        self.layer3 = self.make_layer(self.block, channel=[512, 256], stride1=[2, 1, 1], stride2=[1, 1, 1], n_re=6)
        self.layer4 = self.make_layer(self.block, channel=[1024, 512], stride1=[2, 1, 1], stride2=[1, 1, 1], n_re=3)
        self.layer5 = self._make_output_layer(in_channels=2048)
        self.avgpool = nn.AvgPool2d(2)  # kernel_size = 2  , stride = 2
        self.conv_end = nn.Conv2d(256, int(CLASS_NUM + 10), kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_end = nn.BatchNorm2d(int(CLASS_NUM + 10))

    def make_layer(self, block, channel, stride1, stride2, n_re):
        layers = []
        for num_layer in range(0, n_re):
            if num_layer == 0:
                layers.append(block(channel[0], channel[1], stride1, downsample=True))
            else:
                layers.append(block(channel[1] * 4, channel[1], stride2, downsample=False))
        return Sequential(*layers)

    def _make_output_layer(self, in_channels):
        layers = []
        layers.append(
            output_net(
                in_planes=in_channels,
                planes=256,
                block_type='B'))
        layers.append(
            output_net(
                in_planes=256,
                planes=256,
                block_type='A'))
        layers.append(
            output_net(
                in_planes=256,
                planes=256,
                block_type='A'))
        return nn.Sequential(*layers)

    def forward(self, x):
        # print(x.shape) # 3*448*448
        out = self.layer0(x)
        # print(out.shape) # 64*112*112
        out = self.layer1(out)
        # print(out.shape)  # 256*112*112
        out = self.layer2(out)
        # print(out.shape) # 512*56*56
        out = self.layer3(out)
        # print(out.shape) # 1024*28*28
        out = self.layer4(out)  # 2048*14*14


        out = self.layer5(out)  # batch_size*256*14*14
        out = self.avgpool(out)  # batch_size*256*7*7
        out = self.conv_end(out)  # batch_size*30*7*7
        out = self.bn_end(out)
        out = torch.sigmoid(out)
        out = out.permute(0, 2, 3, 1)  # bitch_size*7*7*30
        return out


def resnet50():
    model = ResNet50(Bottleneck)
    return model

通过下面代码,分别载入自己的网络模型和torch官方的网络模型,看看模型结构有什么不同

from torchvision import models
import torch
from new_resnet import resnet50

# 获取torch官方restnet50的预训练网络权重参数
# pretrained表示是否在内部直接载入resnet50的权重,在这里我们不载入(下载太慢了,我们先现在到本地然后自己手动载入)
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# 获取自己创建的resnet50无训练的空权重
net = resnet50()
op = net.state_dict()


print(len(new_state_dict.keys()))  # 输出torch官方网络模型字典长度
print(len(op.keys()))# 输出自己网络模型字典长度

resnet50预训练权重,目标检测,计算机视觉,深度学习,人工智能
从图中可以看出,torch官方网络模型主要有320个key,我们创建的网络模型有384个key

分别输出两种key有什么不同

from torchvision import models
import torch
from new_resnet import resnet50

# 获取torch官方restnet50的预训练网络权重参数
# pretrained表示是否在内部直接载入resnet50的权重,在这里我们不载入(下载太慢了,我们先现在到本地然后自己手动载入)
resnet = models.resnet50(pretrained=False)
state_dict = torch.load(r"resnet50-0676ba61.pth")
resnet.load_state_dict(state_dict)
new_state_dict = resnet.state_dict()

# 获取自己创建的resnet50无训练的空权重
net = resnet50()
op = net.state_dict()

print(len(new_state_dict.keys()))
print(len(op.keys()))

for i in new_state_dict.keys():   # 查看网络结构的名称 并且得出一共有320个key
    print(i)
for j in op.keys():   # 查看网络结构的名称 并且得出一共有384个key
    print(j)

resnet50预训练权重,目标检测,计算机视觉,深度学习,人工智能
从图中可以看出,我们创建的网络模型和torch官方的网络模型在前318层的结构都是一样的(即网络的backbone),官方的网络模型主要使用两层全连接层做分类,因此我们预训练是不需要这两层参数的,我们只要前面的backbone参数。

2.2 权重参数的载入

两种载入方式,通过2.1可以知道,网络的backbone结构是一样的,在318层后是不一样的。通过观察网络的key可以发现,torch官方的resnet网络模型的key名字和我们自己创建的基于resnet50网络模型的key名字不一样,因此参数的载入主要有两种:

  • 当权重字典中的key名字一样时

    from torchvision import models
    import torch
    from new_resnet import resnet50
    
    # 获取torch官方restnet50的预训练网络权重参数
    # pretrained表示是否在内部直接载入resnet50的权重,在这里我们不载入(下载太慢了,我们先现在到本地然后自己手动载入)
    resnet = models.resnet50(pretrained=False)
    state_dict = torch.load(r"resnet50-0676ba61.pth")
    resnet.load_state_dict(state_dict)
    new_state_dict = resnet.state_dict()
    
    # 获取自己创建的resnet50无训练的空权重
    net = resnet50()
    op = net.state_dict()
    
    # 将new_state_dict里不属于op的键剔除掉
    pretrained_dict = {k: v for k, v in new_state_dict.items() if k in op}
    
    # 更新现有的model_dict
    op.update(pretrained_dict)
    # 加载真正需要的state_dict
    net.load_state_dict(op)
    
  • 当权重字典中的key名字不一样时

    from torchvision import models
    import torch
    from new_resnet import resnet50
    
    # 获取torch官方restnet50的预训练网络权重参数
    # pretrained表示是否在内部直接载入resnet50的权重,在这里我们不载入(下载太慢了,我们先现在到本地然后自己手动载入)
    resnet = models.resnet50(pretrained=False)
    state_dict = torch.load(r"resnet50-0676ba61.pth")
    resnet.load_state_dict(state_dict)
    new_state_dict = resnet.state_dict()
    
    # 获取自己创建的resnet50无训练的空权重
    net = resnet50()
    op = net.state_dict()
    
    # 无论名称是否相同都可以使用
    for new_state_dict_num, new_state_dict_value in enumerate(new_state_dict.values()):
        for op_num, op_key in enumerate(op.keys()):
            if op_num == new_state_dict_num and op_num <= 317:  # 320个key中不需要最后的全连接层的两个参数
                op[op_key] = new_state_dict_value
    net.load_state_dict(op)  # 更改了state_dict的值记得把它导入网络中
    
    

从上面两种方式可以看出,第二种方式更适合我们。综上所述,参数的载入构成主要分为

  1. 构建自己的网络模型,并转换成参数字典格式
  2. 创建官方的网络模型,并载入字典格式
  3. 将官方的网络模型字典于自己的网络模型字典做比较,确定需要载入的具体参数数量。
  4. 载入过后一定要导入网络中,即 net.load_state_dict(op)

2.2 训练方式选取(冻结or不冻结训练)

解释说明:预训练参数载入后,我们可以选取在网络模型训练过程过程中,我们是选取让这部分参数参与参数更新,还是不参与参数更新。

  • 如果参与参数更新的话直接进行后续的网络训练就行了,无处理操作

  • 若不参与网络的更新,需要将参与网络更新的bool值设为False. 通过key.requires_grad获取当前字典参数的参与更新状态的bool值。对应的,在训练时候,optimizer里面只能更新requires_grad = True的参数,于是文章来源地址https://www.toymoban.com/news/detail-722941.html

    from torchvision import models
    import torch
    from new_resnet import resnet50
    
    # 获取torch官方restnet50的预训练网络权重参数
    # pretrained表示是否在内部直接载入resnet50的权重,在这里我们不载入(下载太慢了,我们先现在到本地然后自己手动载入)
    resnet = models.resnet50(pretrained=False)
    state_dict = torch.load(r"resnet50-0676ba61.pth")
    resnet.load_state_dict(state_dict)
    new_state_dict = resnet.state_dict()
    
    # 获取自己创建的resnet50无训练的空权重
    net = resnet50()
    op = net.state_dict()
    
    # 无论名称是否相同都可以使用
    for new_state_dict_num, new_state_dict_value in enumerate(new_state_dict.values()):
        for op_num, op_key in enumerate(op.keys()):
            if op_num == new_state_dict_num and op_num <= 317:  # 320个key中不需要最后的全连接层的两个参数
                op[op_key] = new_state_dict_value
    net.load_state_dict(op)  # 更改了state_dict的值记得把它导入网络中
    
    for i, p in enumerate(net.parameters()): # 将前100层参数冻结
        if i < 100:
            p.requires_grad = False
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.001)
    

到了这里,关于torch中如何使用预训练权重的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • ResNet50的猫狗分类训练及预测

    相比于之前写的ResNet18,下面的ResNet50写得更加工程化一点,这还适用与其他分类,就是换一个分类训练只需要修改图片数据的路径即可。 我的代码文件结构   1. 数据处理 首先已经对数据做好了分类       文件夹结构是这样 开始划分数据集 split_data.py 运行完以上代码的到的

    2023年04月12日
    浏览(50)
  • 3D-Resnet-50 医学图像分类(二分类任务)torch代码(精简版)-图像格式为NIFTI

    img_list格式如下 E:...3.nrrd E:...3.nrrd 0 E:...4.nrrd E:...4.nrrd 1 训练代码

    2024年02月12日
    浏览(42)
  • pytorch实现AI小设计-1:Resnet50人脸68关键点检测

            本项目是AI入门的应用项目,后续可以补充内容完善作为满足个人需要。通过构建自己的人脸数据集,此项目训练集为4580张图片,测试集为2308张图片,使用resnet50网络进行训练,最后进行效果展示。本项目也提供了量化内容,便于在硬件上部署。         研究A

    2024年01月18日
    浏览(45)
  • 3D-Resnet-50 医学图像分类(二分类任务,需要mask)训练代码-图像格式为nrrd(附带验证代码)

    img_list格式如下 E:...3.nrrd E:...3.nrrd 0 E:...4.nrrd E:...4.nrrd 1 训练代码 验证代码

    2024年02月09日
    浏览(37)
  • torchvision pytorch预训练模型目标检测使用

    参考: https://pytorch.org/vision/0.13/models.html https://blog.csdn.net/weixin_42357472/article/details/131747022 有分类、检测、分割相关预训练模型 https://pytorch.org/vision/0.13/models.html#object-detection-instance-segmentation-and-person-keypoint-detection https://h-huang.github.io/tutorials/intermediate/torchvision_tutorial.html https

    2024年03月19日
    浏览(42)
  • ylov8的训练和预测使用(目标检测)

    1-配置数据集的yaml文件: 目录在ultralytics/cfg/datasets/下面: 例如我的: (这里面的yaml文件在/ultralytics/cfg/datasets下面有很多,可以找几个参考一下) 2- 配置.config/Ultralytics/settings.yaml 文件(/root/.config/Ultralytics/settings.yaml) 例如我的(更改了datasets_dir、weights_dir、runs_dir的路径):

    2024年01月21日
    浏览(27)
  • 机器学习笔记 - 基于PyTorch + 类似ResNet的单目标检测

            我们将处理年龄相关性黄斑变性 (AMD) 患者的眼部图像。          数据集下载地址,从下面的地址中,找到iChallenge-AMD,然后下载。 Baidu Research Open-Access Dataset - Download Download Baidu Research Open-Access Dataset https://ai.baidu.com/broad/download         这里也提供了百度网盘下

    2024年02月12日
    浏览(43)
  • 【pytorch】目标检测:一文搞懂如何利用kaggle训练yolov5模型

    笔者的运行环境:python3.8+pytorch2.0.1+pycharm+kaggle。 yolov5对python和pytorch版本是有要求的,python=3.8,pytorch=1.6。yolov5共有5种类型nslmx,参数量依次递增,对训练设备的要求也是递增。本文以yolov5_6s为切入点,探究yolov5如何在实战种运用。 roboflow是一个公开数据集网站,里面有很

    2024年02月12日
    浏览(50)
  • 使用YOLOv8训练自己的【目标检测】数据集

    随着深度学习技术在计算机视觉领域的广泛应用,行人检测和车辆检测等任务已成为热门研究领域。然而,实际应用中,可用的预训练模型可能并不适用于所有应用场景。 例如,虽然预先训练的模型可以检测出行人,但它无法区分“好人”和“坏人”,因为它没有接受相关的

    2024年04月10日
    浏览(54)
  • 目标检测笔记(十五): 使用YOLOX完成对图像的目标检测任务(从数据准备到训练测试部署的完整流程)

    目标检测(Object Detection)是计算机视觉领域的一项重要技术,旨在识别图像或视频中的特定目标并确定其位置。通过训练深度学习模型,如卷积神经网络(CNN),可以实现对各种目标的精确检测。常见的目标检测任务包括:人脸检测、行人检测、车辆检测等。目标检测在安防

    2024年02月09日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包