3D-ResNet-50 Medical Image Classification (Binary Task) in PyTorch, Condensed Version - NIfTI Image Format

This article presents a condensed PyTorch implementation of 3D-ResNet-50 for a binary classification task on medical images stored in NIfTI format. I hope it is useful; if you spot errors or omissions, corrections are welcome.

Prerequisites:

1. A GPU (8 GB of VRAM or more recommended) with CUDA configured; see 基于win10深度学习环境配置(conda,python,cuda11.7,torch1.13.0)_dr_yingli的博客-CSDN博客. The script also requires the torch, numpy, nibabel, and scipy packages.
2. Image files in the common .nii (NIfTI) format.

Each line of the img_list file holds three space-separated fields: the image path, the mask path, and the class label. Paths are resolved against root_dir (absolute paths also work):

E:\...\3.nii E:\...\3.nii 0
E:\...\4.nii E:\...\4.nii 1
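
For reference, below is a minimal sketch of a helper that writes such a list file (the case paths and labels are placeholders, not from the original article; the Dataset class below joins relative paths onto root_dir):

# hypothetical helper for writing ./data/train.txt; paths and labels are placeholders
cases = [
    ('images/3.nii', 'masks/3.nii', 0),
    ('images/4.nii', 'masks/4.nii', 1),
]
with open('./data/train.txt', 'w') as f:
    for img_path, mask_path, label in cases:
        # one sample per line: <image path> <mask path> <class label>
        f.write('{} {} {}\n'.format(img_path, mask_path, label))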

Training code:

import torch
from torch import nn
import os
import numpy as np
from torch.utils.data import Dataset
import nibabel
from scipy import ndimage
from torch import optim
from torch.utils.data import DataLoader
import time
import logging
root_dir = './data'  # root directory of the data
img_list = './data/train.txt'  # path to the image list file
pretrain_path = 'pretrain/resnet_50.pth'  # path to the pretrained model weights
save_folder = "./trails/models/Resnet50"  # prefix for saved checkpoints
total_epochs = 20  # number of epochs to run
save_intervals = 10  # iteration interval for saving the model
learning_rate = 0.001  # initial learning rate; 0.001 works well when fine-tuning (decayed by the lr scheduler)
new_layer_names = ['conv_cls']  # newly added layers, i.e. everything except the backbone
batch_size = 1  # batch size
input_D = 56  # input depth
input_H = 448  # input height
input_W = 448  # input width
torch.manual_seed(1)  # fix the random seed for reproducibility
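# Architecture note: ResNet-50 stacks Bottleneck blocks (1x1 -> 3x3 -> 1x1
# convolutions with a 4x channel expansion) in four stages of [3, 4, 6, 3]
# blocks; here the usual 2D layers appear as their 3D counterparts (Conv3d,
# BatchNorm3d, MaxPool3d), so the network consumes whole volumes shaped
# (batch, 1, D, H, W) and ends in a single sigmoid output for the binary task.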
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    def __init__(self, block, layers, input_D, input_H, input_W):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)  # matches conv1's output channels
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)  # spatial size halved, channels unchanged
        self.layer1 = self._make_layer(block, 64, layers[0])  # spatial size unchanged; shortcut via downsample, out_channels = 64*4 = 256
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # spatial size halved; out_channels = 128*4 = 512
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)  # stride 1 + dilation 2 keeps spatial size; out_channels = 256*4 = 1024
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)  # stride 1 + dilation 4 keeps spatial size; out_channels = 512*4 = 2048
        self.conv_cls = nn.Sequential(  # classification head: global max-pool -> flatten -> dropout -> single output
            nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
            nn.Flatten(start_dim=1),
            nn.Dropout(0.1),
            nn.Linear(512 * block.expansion, 1)
        )
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
           downsample = nn.Sequential(
            nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm3d(planes * block.expansion))
        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion  # subsequent blocks in this stage receive planes * 4 input channels
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)  # '*' unpacks the list into positional arguments
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.conv_cls(x)
        x = torch.sigmoid(x)  # probability for the binary task, paired with BCELoss during training
        return x
def generate_model(input_D, input_H, input_W, pretrain_path):
    model = ResNet(Bottleneck, [3, 4, 6, 3], input_W=input_W, input_H=input_H, input_D=input_D)
    model = model.cuda()
    net_dict = model.state_dict()
    print('loading pretrained model {}'.format(pretrain_path))
    pretrain = torch.load(pretrain_path)
    pretrain_dict = {k.replace("module.", ""): v for k, v in pretrain['state_dict'].items() if k.replace("module.", "") in net_dict.keys()}
    net_dict.update(pretrain_dict)  # merge the matching pretrained weights into the model's state dict
    model.load_state_dict(net_dict)  # load the merged weights back into the model
    new_parameters = []
    for pname, p in model.named_parameters():
        for layer_name in new_layer_names:
            if pname.find(layer_name) >= 0:
                new_parameters.append(p)
                break
    new_parameters_id = list(map(id, new_parameters))
    base_parameters = list(filter(lambda p: id(p) not in new_parameters_id, model.parameters()))
    parameters = {'base_parameters': base_parameters, 'new_parameters': new_parameters}
    return model, parameters
model, parameters = generate_model(input_D, input_H, input_W, pretrain_path)
# the pretrained backbone is fine-tuned at the base LR; the freshly initialized head gets 100x that LR
params = [{'params': parameters['base_parameters'], 'lr': learning_rate},
          {'params': parameters['new_parameters'], 'lr': learning_rate * 100}]
optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
class NiftiDataset(Dataset):
    def __init__(self, root_dir, img_list, input_D, input_H, input_W):
        with open(img_list, 'r') as f:
            self.img_list = [line.strip() for line in f]
        print("Processing {} datas".format(len(self.img_list)))
        self.root_dir = root_dir
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W
    def __nii2tensorarray__(self, data):
        [z, y, x] = data.shape
        new_data = np.reshape(data, [1, z, y, x])
        new_data = new_data.astype("float32")
        return new_data
    def __len__(self):
        return len(self.img_list)
    def __getitem__(self, idx):
        # read image and labels
        ith_info = self.img_list[idx].split(" ")
        img_name = os.path.join(self.root_dir, ith_info[0])
        label_name = os.path.join(self.root_dir, ith_info[1])
        class_array = np.zeros(1)
        class_array[0] = ith_info[2]
        class_array = torch.tensor(class_array, dtype=torch.float32)  # binary label as float for BCELoss
        assert os.path.isfile(img_name)
        assert os.path.isfile(label_name)
        img = nibabel.load(img_name)  # We have transposed the data from WHD format to DHW
        assert img is not None
        mask = nibabel.load(label_name)
        assert mask is not None
        # data processing
        img_array, mask_array = self.__training_data_process__(img, mask)
        # 2 tensor array
        img_array = self.__nii2tensorarray__(img_array)
        mask_array = self.__nii2tensorarray__(mask_array)
        assert img_array.shape == mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
        return img_array, mask_array, class_array
    def __drop_invalid_range__(self, volume, label=None):
        """
        Cut off the invalid area
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)
        [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1)
        [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1)
        if label is not None:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w], label[min_z:max_z, min_h:max_h, min_w:max_w]
        else:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w]
    def __random_center_crop__(self, data, label):
        """
        Random crop centered on the labeled region
        """
        from random import random
        target_indexs = np.where(label > 0)
        [img_d, img_h, img_w] = data.shape
        [max_D, max_H, max_W] = np.max(np.array(target_indexs), axis=1)
        [min_D, min_H, min_W] = np.min(np.array(target_indexs), axis=1)
        [target_depth, target_height, target_width] = np.array([max_D, max_H, max_W]) - np.array([min_D, min_H, min_W])
        Z_min = round((min_D - target_depth * 1.0 / 2) * random())
        Y_min = round((min_H - target_height * 1.0 / 2) * random())
        X_min = round((min_W - target_width * 1.0 / 2) * random())
        Z_max = round(img_d - ((img_d - (max_D + target_depth * 1.0 / 2)) * random()))
        Y_max = round(img_h - ((img_h - (max_H + target_height * 1.0 / 2)) * random()))
        X_max = round(img_w - ((img_w - (max_W + target_width * 1.0 / 2)) * random()))
        Z_min = np.max([0, Z_min])
        Y_min = np.max([0, Y_min])
        X_min = np.max([0, X_min])
        Z_max = np.min([img_d, Z_max])
        Y_max = np.min([img_h, Y_max])
        X_max = np.min([img_w, X_max])
        return data[Z_min: Z_max, Y_min: Y_max, X_min: X_max], label[Z_min: Z_max, Y_min: Y_max, X_min: X_max]
    def __intensity_normalize_one_volume__(self, volume):
        """
        normalize the intensity of an nd volume based on the mean and std of the nonzero region
        inputs:
            volume: the input nd volume
        outputs:
            out: the normalized nd volume
        """
        pixels = volume[volume > 0]
        mean = pixels.mean()
        std = pixels.std()
        out = (volume - mean) / std
        out_random = np.random.normal(0, 1, size=volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out
    def __resize_data__(self, data):
        """
        Resize the data to the input size
        """
        [depth, height, width] = data.shape
        scale = [self.input_D * 1.0 / depth, self.input_H * 1.0 / height, self.input_W * 1.0 / width]
        data = ndimage.zoom(data, scale, order=0)  # order=0: nearest-neighbour interpolation
        return data
    def __crop_data__(self, data, label):
        """
        Random crop with different methods:
        """
        # random center crop
        data, label = self.__random_center_crop__(data, label)
        return data, label
    def __training_data_process__(self, data, label):
        # crop data according net input size
        data = data.get_fdata()
        label = label.get_fdata()
        # drop out the invalid range
        data, label = self.__drop_invalid_range__(data, label)
        # crop data
        data, label = self.__crop_data__(data, label)
        # resize data
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)
        # normalize the data
        data = self.__intensity_normalize_one_volume__(data)
        return data, label
training_dataset = NiftiDataset(root_dir=root_dir, img_list=img_list, input_D=input_D, input_H=input_H, input_W=input_W)
data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
logging.basicConfig(format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)
log = logging.getLogger()
def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder):
    batches_per_epoch = len(data_loader)
    log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))
    loss_seg = nn.BCELoss()  # binary cross-entropy on the sigmoid model output
    loss_seg = loss_seg.cuda()
    model.train()
    train_time_sp = time.time()
    for epoch in range(total_epochs):
        log.info('Start epoch {}'.format(epoch))
        log.info('lr = {}'.format(scheduler.get_last_lr()))
        for batch_id, batch_data in enumerate(data_loader):
            # getting data batch
            batch_id_sp = epoch * batches_per_epoch
            volumes, label_masks, class_array = batch_data  #####
            volumes = volumes.cuda()
            class_array = class_array.cuda()  #####
            optimizer.zero_grad()
            out_masks = model(volumes)
            # print(volumes.shape)  # debug: inspect the input volume shape
            # calculate the loss; out_masks and class_array both have shape (batch, 1)
            loss_value_seg = loss_seg(out_masks, class_array)
            loss = loss_value_seg
            loss.backward()
            optimizer.step()
            scheduler.step()
            avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
            log.info('Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}' \
                     .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time))
            # save model
            if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0:
                model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
                model_save_dir = os.path.dirname(model_save_path)
                if not os.path.exists(model_save_dir):
                    os.makedirs(model_save_dir)
                log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
                torch.save({'epoch': epoch,
                            'batch_id': batch_id,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                           model_save_path)
    print('Finished training')
train(data_loader=data_loader, model=model, optimizer=optimizer, scheduler=scheduler, total_epochs=total_epochs, save_interval=save_intervals, save_folder=save_folder)
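
The article stops at training. For completeness, here is a minimal inference sketch that reuses the model and data_loader defined above; the checkpoint filename is an assumption derived from the save pattern in train(), not a file the original article names:

# minimal inference sketch; the checkpoint name below is assumed, not from the article
checkpoint = torch.load('./trails/models/Resnet50_epoch_10_batch_0.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
with torch.no_grad():
    for volumes, label_masks, class_array in data_loader:
        probs = model(volumes.cuda())   # sigmoid probabilities in (0, 1), shape (batch, 1)
        preds = (probs > 0.5).long()    # threshold at 0.5 for the binary label
        print(preds.squeeze(1).cpu().tolist(), class_array.squeeze(1).tolist())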
