Datawhale-AI夏令营:脑PET图像分析和疾病预测挑战赛baseline解读

这篇具有很好参考价值的文章主要介绍了Datawhale-AI夏令营:脑PET图像分析和疾病预测挑战赛baseline解读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

这段代码是一个完整的深度学习模型训练和预测的流程。下面我会逐步解释每个步骤的作用。

首先,这段代码导入了必要的库,包括PyTorch、numpy、pandas等。接着,打印出CUDA版本和是否可用GPU,并将模型部署到GPU上(如果可用)。

接下来是数据预处理的部分。通过glob.glob函数获取训练和测试图像的路径,并对其进行随机化。然后定义了一个自定义的Dataset类XunFeiDataset,用于读取和处理图像数据。在__getitem__方法中,首先检查数据是否已经加载过,如果已经加载过则直接使用,否则通过nibabel库读取图像数据,并进行一些预处理操作(例如随机选择通道、图像增强等)。最后返回处理后的图像数据及其对应的标签。__len__方法返回数据集的大小。

接下来是数据集的处理,将数据集分为训练集、验证集和测试集。分别创建了train_loader、val_loader和test_loader三个DataLoader对象,用于加载训练、验证和测试数据。其中,在train_loader和val_loader中使用了不同的数据增强操作。

然后定义了一个自定义的CNN网络XunFeiNet,它基于预训练的ResNet34模型,并修改了输入和输出层的形状以适应特定的任务。在forward方法中,将输入数据传入ResNet模型并输出结果。

接下来是模型训练与验证的部分。train函数定义了模型的训练过程,包括前向传播、计算损失、反向传播和参数更新,并返回训练数据集上的平均损失。validate函数定义了模型的验证过程,包括前向传播和计算准确率,并返回验证数据集上的平均准确率。然后使用这两个函数分别在训练集和验证集上进行模型的训练和验证,并打印出每次迭代的损失和准确率。

最后是模型预测与提交的部分。定义了predict函数用于在测试集上进行模型预测,并返回预测结果。通过循环调用predict函数多次,然后对预测结果进行求和,得到最终的预测结果。最后将预测结果保存为CSV文件。

整个代码流程包括数据预处理、模型定义、训练与验证、模型预测与提交,实现了一个完整的深度学习模型训练和预测的流程。
baseline解读

step0:引入必要库

import torch
import os, sys, glob, argparse
import pandas as pd
import numpy as np
import albumentations as A
from tqdm import tqdm
import cv2 as cv
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import nibabel as nib
from nibabel.viewers import OrthoSlicer3D

step1:检验cuda是否安装成功,如果成功则部署到GPU上

print(‘CUDA版本:’, torch.cuda_version)
print(‘torch能否使用GPU:’, torch.cuda.is_available())
device = torch.device(‘cuda:0’ if torch.cuda.is_available() else ‘cpu’)

step2:数据预处理

train_path = glob.glob(‘./脑PET图像分析和疾病预测挑战赛公开数据/Train//’)
test_path = glob.glob(‘./脑PET图像分析和疾病预测挑战赛公开数据/Test/*’)

np.random.shuffle(train_path)
np.random.shuffle(test_path)

DATA_LOADER = {}

class XunFeiDataset(Dataset):
def init(self, img_path, transform=None):
self.img_path = img_path
if transform is not None:
self.transform = transform
else:
self.transform = None

def __getitem__(self, index):
    if self.img_path[index] in DATA_LOADER:
        img = DATA_LOADER[self.img_path[index]]
    else:
        img = nib.load(self.img_path[index])
        img = img.dataobj[:, :, :, 0]
        DATA_LOADER[self.img_path[index]] = img

    # 随机选择一些通道
    idx = np.random.choice(range(img.shape[-1]), 50)
    img = img[:, :, idx]
    img = img.astype(np.float32)

    if self.transform is not None:
        img = self.transform(image=img)['image']

    img = img.transpose([2, 0, 1])
    return img, torch.from_numpy(np.array(int('NC' in self.img_path[index])))

def __len__(self):
    return len(self.img_path)

step3:数据集处理(train、test、val)

train_loader = torch.utils.data.DataLoader(
XunFeiDataset(train_path[:-10],
A.Compose([
A.RandomRotate90(),
A.RandomCrop(120, 120),
A.HorizontalFlip(p=0.5),
A.RandomContrast(p=0.5),
A.RandomBrightnessContrast(p=0.5),
])
), batch_size=2, shuffle=True, num_workers=0, pin_memory=False
)

val_loader = torch.utils.data.DataLoader(
XunFeiDataset(train_path[-10:],
A.Compose([
A.RandomCrop(120, 120),
])
), batch_size=2, shuffle=False, num_workers=0, pin_memory=False
)

test_loader = torch.utils.data.DataLoader(
XunFeiDataset(test_path,
A.Compose([
A.RandomCrop(128, 128),
A.HorizontalFlip(p=0.5),
A.RandomContrast(p=0.5),
])
), batch_size=2, shuffle=False, num_workers=0, pin_memory=False
)

step4:自定义CNN网络

class XunFeiNet(nn.Module):
def init(self):
super(XunFeiNet, self).init()
model = models.resnet34(True)
model.conv1 = torch.nn.Conv2d(50, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.avgpool = nn.AdaptiveAvgPool2d(1)
model.fc = nn.Linear(512, 2)
self.resnet = model

def forward(self, img):
    out = self.resnet(img)
    return out

部署到GPU上

model = XunFeiNet().to(device)
Loss_Function = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

step5:模型训练与验证

def train(train_loader, model, Loss_Function, optimizer):
“”"

:param train_loader: 脑PET图像分析和疾病预测中Train部分数据
:param model: Resnet34
:param Loss_Function:交叉熵损失函数
:param optimizer: SGD优化算法

:return:train_loss
"""
model.train()
train_loss = 0.0
for i, (input, target) in enumerate(train_loader):
    input = input.to(device)
    target = target.to(device)

    output = model(input)
    loss = Loss_Function(output, target.long())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # if i % 20 == 0:
    #     print(loss.item())

    train_loss += loss.item()

return train_loss / len(train_loader)

def validate(val_loader, model, Loss_Function):
“”"

:param val_loader: 脑PET图像分析和疾病预测中Train部分数据
:param model: Resnet34
:param Loss_Function:交叉熵损失函数

:return: val_acc
"""
model.eval()
val_acc = 0.0

with torch.no_grad():
    for i, (input, target) in enumerate(val_loader):
        input = input.to(device)
        target = target.to(device)

        # compute output
        output = model(input)
        loss = Loss_Function(output, target.long())

        val_acc += (output.argmax(1) == target).sum().item()

return val_acc / len(val_loader.dataset)

迭代次数

num = 0
for _ in range(30):
num += 1
train_loss = train(train_loader, model, Loss_Function, optimizer)
val_acc = validate(val_loader, model, Loss_Function)
train_acc = validate(train_loader, model, Loss_Function)
print(f’第{num}次\n训练模型的损失:{train_loss}\t训练正确率:{train_acc}\t验证正确率{val_acc}\n’)

step6:模型预测与提交

def predict(test_loader, model, Loss_Function):
model.eval()
val_acc = 0.0

test_pred = []
with torch.no_grad():
    for i, (input, target) in enumerate(test_loader):
        input = input.to(device)
        target = target.to(device)
        output = model(input)
        test_pred.append(output.data.cpu().numpy())

return np.vstack(test_pred)

pred = None
for i in range(50):
if pred is None:
pred = predict(test_loader, model, Loss_Function)
else:
pred += predict(test_loader, model, Loss_Function)

print(‘预测成功,正在生成.csv文件’)
submit = pd.DataFrame(
{
‘uuid’: [int(x.split(‘\’)[-1][:-4]) for x in test_path],
‘label’: pred.argmax(1)
})
submit[‘label’] = submit[‘label’].map({1: ‘NC’, 0: ‘MCI’})
submit = submit.sort_values(by=‘uuid’)
submit.to_csv(‘submit.csv’, index=None)文章来源地址https://www.toymoban.com/news/detail-605262.html

到了这里,关于Datawhale-AI夏令营:脑PET图像分析和疾病预测挑战赛baseline解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • DataWhale 机器学习夏令营第二期——AI量化模型预测挑战赛 学习记录

    DataWhale 机器学习夏令营第二期 ——AI量化模型预测挑战赛 已跑通baseline,线上得分 0.51138 , 跑通修改后进阶代码,线上得分 0.34497 按照鱼佬直播分享按照以下常见思路分析机器学习竞赛: 1.1 赛事数据 数据集情况 给定数据集 : 给定训练集(含验证集), 包括10只(不公开)

    2024年02月11日
    浏览(41)
  • 【Datawhale夏令营】任务二学习笔记

    目录 一:python语法回顾 1.1  print() 1.2  列表与字典 1.3自定义函数与return 1.4火车类(面向对象)  实例化总结: 二:LightGBM 代码精读 2.1导入库 2.2数据准备与参数设置  2.3时间特征函数   2.4优化  2.5训练与预测 三:优化讲解 3.1: 3.2优化建议: 一:python语法回顾 1.1  print

    2024年02月14日
    浏览(50)
  • DataWhale 机器学习夏令营第三期

    DataWhale 机器学习夏令营第三期 ——用户新增预测挑战赛 已跑通baseline,换为lightgbm基线,不加任何特征线上得分 0.52214 ; 添加baseline特征,线上得分 0.78176 ; 暴力衍生特征并微调模型参数,线上得分 0.86068 赛题数据由约62万条训练集、20万条测试集数据组成,共包含13个字段

    2024年02月12日
    浏览(47)
  • DataWhale 机器学习夏令营第三期——任务二:可视化分析

    DataWhale 机器学习夏令营第三期 ——用户新增预测挑战赛 2023.08.17 已跑通baseline,换为lightgbm基线,不加任何特征线上得分 0.52214 ; 添加baseline特征,线上得分 0.78176 ; 暴力衍生特征并微调模型参数,线上得分 0.86068 2023.08.23 数据分析、衍生特征: 0.87488 衍生特征、模型调参:

    2024年02月11日
    浏览(42)
  • AI夏令营笔记——任务2

    任务要求与任务1一样: 从论文标题、摘要作者等信息,判断该论文是否属于医学领域的文献。 可以将任务看作是一个文本二分类任务。机器需要根据对论文摘要等信息的理解,将论文划分为医学领域的文献和非医学领域的文献两个类别之一。 使用预训练的大语言模型进行建

    2024年02月11日
    浏览(36)
  • AI夏令营第三期用户新增挑战赛学习笔记

    通过pd库的df.info()方法查看数据框属性,发现只有udmap字段为类别类型,其余皆为数值类型。 相关性热力图颜色越深代表相关性越强,所以x7和x8变量之间的关系更加密切,还有common_ts与x6也是。即存在很强的多重共线性,进行特征工程时可以考虑剔除二者中的一个变量,以免

    2024年02月11日
    浏览(38)
  • acm夏令营课后题(持续更新)

                                            米有程序题就懒得写哩           acm夏令营贪心算法选题_李卓航哇哇咔~的博客-CSDN博客  上面这个自己写的不知道为什么错哩,在网上找了下面这个。   这道题写的很顺      (这道题答案来源于网上)  网上答案写的很详细

    2024年02月13日
    浏览(43)
  • 考研保研、夏令营推免的简历模板

      本文介绍在保研夏令营、考研复试等场景中, 个人简历 的制作模板与撰写注意事项。   这里就将当初我自己的简历分享一下,供大家参考。其实我的简历是那种比较简单、质朴的,通篇就一个颜色,没有太多花里胡哨的部分。我个人感觉,对于读研、升学而言,其实

    2024年02月04日
    浏览(45)
  • 【NVIDIA CUDA】2023 CUDA夏令营编程模型(二)

    博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客内容主要围绕:        5G/6G协议

    2024年02月10日
    浏览(37)
  • 北京大学2014计算机学科夏令营上机考试

    暴力必超时  利用栈的思想,利用一个(模仿栈)的数组,遇到男孩则入栈(即加入数组),记录当前位置(更新相对下标、绝对下表); 而遇到女孩,则出栈(男孩相对下标--),输出女孩与男孩的绝对位置。 2014计算机学科夏令营上机考试 B:排队游戏 找规律……#¥%……

    2024年02月12日
    浏览(53)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包