前言
今天在这里纪录一下如何对torch网络的层进行更改:变更,增加,删除与查找
这里拿VGG16网络举例,先看一下网络结构
import torch
import torch.nn as nn
from torchvision import models
net = models.vgg11(pretrained=True)
一、在网络中添加一层:
net网络是一个树型结构, net下面有三个结点,分别是(features, avgpoll, classifier), 我们先在features结点添加一层’lastlayer’层
net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))
- 在classifier结点添加一个线性层:
net.classifier.add_module('Linear', nn.Linear(1000, 10))
二、修改网络中的某一层
- 以features 结点举例
net.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
- 以classifier结点举例
net.classifier[6] = nn.Linear(1000, 5)
注意: 这里我尝试对Linear这一层进行更新, 但是Linear名字是字符串, 提取不出来,所以应该在之前添加网络时候, 名字不要取字符串, 否则会报错 ‘ 'str' object cannot be interpreted as an integer’。
三、网络层的删除
方法一:使用关键字del删除层(推荐)
删除前
model = prepare_vitmodel('mae_visualize_vit_large_ganloss.pth', 'vit_large_patch16')
del model.head # 删除层
model
删除后
方法二:将层设置为空层
以features举例 classifier结点的操作相同,这里直接使用nn.Sequential()对改层设置为空即可
net.features[13] = nn.Sequential()
文章来源:https://www.toymoban.com/news/detail-649834.html
四、网络层的切片
net.features = nn.Sequential(*list(net.features.children())[:-4])
可以看到后面4层被去除了, 就是说可以使用列表切片的方法来删除网络层
net.classifier 对应 net.classifier.children()
net.features 对应 net.features.children()
文章来源地址https://www.toymoban.com/news/detail-649834.html
五、网络层的冻结
#冻结指定层的预训练参数:
net.feature[26].weight.requires_grad = False
到了这里,关于pytorch对网络层的增加,删除,变更和切片的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!