深度学习笔记--解决GPU显存使用量不断增加的问题

这篇具有很好参考价值的文章主要介绍了深度学习笔记--解决GPU显存使用量不断增加的问题。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

1--问题描述

2--问题解决

3--代码


1--问题描述

        基于 Pytorch 使用 VGG16 预训练模型进行分类预测时,出现 GPU 显存使用量不断增加,最终出现 cuda out of memory 的问题;

        出现上述问题的原因在于:输入数据到网络模型进行推理时,会默认构建计算图,便于后续反向传播进行梯度计算。而构建完整的计算图,会增加计算和累积内存消耗,从而导致 GPU显存使用量不断增加;

        由于博主只使用 VGG16 预训练模型进行分类预测,不需要训练和反向传播更新参数,所以不用构建完整的计算图。

2--问题解决

        在推理代码中增加以下指令,表明当前计算不需要进行反向传播,即强制不进行完整计算图的构建:

with torch.no_grad():
    ...
    ...

3--代码

        问题代码:

def extract_rgb_feature(rgb_data):
    data = rgb_data.to(device_id[0]) # [40, 40, 3]
    data = data.permute(2, 0, 1).unsqueeze(0) # [1, 3, 40, 40]
    data = F.interpolate(data, size = (224, 224), mode='nearest').float() #[1, 3, 224, 224]
    data = model(data) # [1, linear_Class]
    return data

        修正代码:

def extract_rgb_feature(rgb_data):
    with torch.no_grad():
        data = rgb_data.to(device_id[0]) # [40, 40, 3]
        data = data.permute(2, 0, 1).unsqueeze(0) # [1, 3, 40, 40]
        data = F.interpolate(data, size = (224, 224), mode='nearest').float() #[1, 3, 224, 224]
        data = model(data) # [1, linear_Class]
        return data

        完整代码:文章来源地址https://www.toymoban.com/news/detail-582128.html

from torchvision import models
import torch.nn as nn
import torch
import numpy as np
import cv2
import torch.nn.functional as F

class My_Net(nn.Module):
    def __init__(self, linear_Class):
        super(My_Net, self).__init__()
        self.linear_Class = linear_Class
        self.backbone = models.vgg16(pretrained=True) # 以 vgg16 作为 backbone
        self.backbone = self.process_backbone(self.backbone) # 对预训练模型进行处理
 
        self.linear1 = nn.Linear(in_features = 4096, out_features = self.linear_Class)
 
    def process_backbone(self, model):
 
        # 固定预训练模型的参数
        for param in model.parameters():
            param.requires_grad = False
        
        # 删除最后预测层    
        del model.classifier[6]
 
        return model
    
    def forward(self, x):
        x = self.backbone(x)
        x = self.linear1(x)
        return x

linear_Class = 2
device_id = [7]
model = My_Net(linear_Class).to(device_id[0]) # 初始化模型

def extract_rgb_feature(rgb_data):
    with torch.no_grad():
        data = rgb_data.to(device_id[0]) # [40, 40, 3]
        data = data.permute(2, 0, 1).unsqueeze(0) # [1, 3, 40, 40]
        data = F.interpolate(data, size = (224, 224), mode='nearest').float() #[1, 3, 224, 224]
        data = model(data) # [1, linear_Class]
        return data
                    
     
if __name__ == "__main__":

    CSub_train_txt_path = '../statistics/CSub_train.txt'
    CSub_test_txt_path = '../statistics/CSub_test.txt'
    
    CSub_train_data_path = './2J_rgb_patch_npy_file_40x40/CSub/train/'
    CSub_test_data_path = './2J_rgb_patch_npy_file_40x40/CSub/test/'
    
    CSub_train_txt = np.loadtxt(CSub_train_txt_path, dtype = str)
    CSub_test_txt = np.loadtxt(CSub_test_txt_path, dtype = str)
    
    CSub_train_save_path = './pre_vgg_feature/2J/CSub/train.npy'
    CSub_test_save_path = './pre_vgg_feature/2J/CSub/test.npy'
    
    save_data = []
    
    for (idx, name) in enumerate(CSub_test_txt):
        data_path = CSub_test_data_path + name + '.npy' 
        rgb_data = np.load(data_path) # T, M, N, H, W, C
        rgb_data = torch.from_numpy(rgb_data)#.to(device = device_id[0])
        
        T, M, N, H, W, C = rgb_data.shape
        Output = torch.zeros(T, M, N, 1, linear_Class)
        
        for t in range(T):
            for m in range(M):
                for n in range(N):
                    data = extract_rgb_feature(rgb_data[t, m, n])
                    Output[t, m, n] = data.cpu()
                    
        save_data.append(Output) 
        print("Processing " + name + ", Done !")
        
    np.save(CSub_test_save_path, save_data)
        
    print("All done!")

到了这里,关于深度学习笔记--解决GPU显存使用量不断增加的问题的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包