手把手教你用pytorch实现k折交叉验证,解决类别不平衡

这篇具有很好参考价值的文章主要介绍了手把手教你用pytorch实现k折交叉验证,解决类别不平衡。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

在用深度学习做分类的时候,常常需要进行交叉验证,目前pytorch没有通用的一套代码来实现这个功能。可以借助 sklearn中的 StratifiedKFold,KFold来实现,其中StratifiedKFold可以根据类别的样本量,进行数据划分。以5折为例,它可以实现每个类别的样本都是4:1划分。

代码简单的示例如下:

from sklearn.model_selection import  StratifiedKFold
skf = StratifiedKFold(n_splits=5)
for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
    trainset, valset = np.array(imgs)[[train_idx]],np.array(imgs)[[val_idx]]
    traintag, valtag = np.array(labels)[[train_idx]],np.array(labels)[[val_idx]]

以上示例是将所有imgs列表与对应的labels列表进行split,得到train_idx代表训练集的下标,val_idx代表验证集的下标。后续代码只需要将split完成的trainset与valset输入dataset即可。

接下来用我自己数据集的实例来完整地实现整个过程,即从读取数据,到开始训练。如果你的数据集存储方式和我不同,改一下数据读取代码即可。关键是如何获取到imgs和对应的labels。

我的数据存储方式是这样的(类别为文件夹名,属于该类别的图像在该文件夹下):

"""A generic data loader where the images are arranged in this way: ::

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png

 以下代码是获取imgs与labels的过程:

import os
import numpy as np

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def is_image_file(filename):
    return filename.lower().endswith(IMG_EXTENSIONS)

def find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

if __name__ == "__main__":
    dir = 'your root path'
    classes, class_to_idx = find_classes(dir)
    imgs = []
    labels = []
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(dir, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_image_file(path):
                    imgs.append(path)
                    labels.append(class_index)

上述代码只需要把dir改为自己的root路径即可。接下来对所有数据进行5折split。其中我自己写了MyDataset类,可以直接照搬用。

from sklearn.model_selection import  StratifiedKFold
    skf = StratifiedKFold(n_splits=5) #5折
    for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
        trainset, valset = np.array(imgs)[[train_idx]],np.array(imgs)[[val_idx]]
        traintag, valtag = np.array(labels)[[train_idx]],np.array(labels)[[val_idx]]
        train_dataset = MyDataset(trainset, traintag, data_transforms['train'] )
        val_dataset = MyDataset(valset, valtag, data_transforms['val'])
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):

    def __init__(self, imgs, labels, transform=None,target_transform=None):

        self.imgs = imgs
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        path = self.imgs[idx]
        target = self.labels[idx]

        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

有了数据集之后,就可以创建dataloader了,后面就是正常的训练代码:

from sklearn.model_selection import  StratifiedKFold
    skf = StratifiedKFold(n_splits=5) #5折
    for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
        trainset, valset = np.array(imgs)[[train_idx]],np.array(imgs)[[val_idx]]
        traintag, valtag = np.array(labels)[[train_idx]],np.array(labels)[[val_idx]]
        train_dataset = MyDataset(trainset, traintag, data_transforms['train'] )
        val_dataset = MyDataset(valset, valtag, data_transforms['val'])
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                  shuffle=True, num_workers=args.workers)
        test_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                                  shuffle=True, num_workers=args.workers)

        # define model
        model = resnet18().cuda()
        # define criterion
        criterion = torch.nn.CrossEntropyLoss()
        # Observe that all parameters are being optimized.
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        for epoch in range(args.epoch):
            train_acc, train_loss = train(train_dataloader, model, criterion, args)
            test_acc, tect_acc_top5, test_loss = validate(test_dataloader, model, criterion, args)

为了保证每次跑的时候分的数据都是一致的,注意shuffle=False(默认)

StratifiedKFold(n_splits=5,shuffle=False)

以上就是实现的基本代码,之所以在代码层面实现k折而不是在数据层面做,比如预先把数据等分为5份。是因为这个代码可以支持数据样本的随意增减,不需要人为地再去分数据,十分方便。 文章来源地址https://www.toymoban.com/news/detail-447854.html

到了这里,关于手把手教你用pytorch实现k折交叉验证,解决类别不平衡的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 手把手教你用AirtestIDE无线连接手机

    一直以来,我们发现同学们都挺喜欢用无线的方式连接手机,正好安卓11出了个无线连接的新姿势,我们今天就一起来看看,如何用AirtestIDE无线连接你的Android设备~ 当 手机与电脑处在同一个wifi 下,即可尝试无线连接手机了,但是这种方式受限于网络连接的稳定性,可能会出

    2023年04月18日
    浏览(53)
  • 手把手教你用Python编写邮箱脚本引擎

    版权声明:原创不易,本文禁止抄袭、转载需附上链接,侵权必究! 邮箱是传输信息方式之一,个人,企业等都在使用,朋友之间发消息,注册/登录信息验证,订阅邮箱,企业招聘,向客户发送消息等都是邮箱的使用场景;邮箱有两个较重要的协议:SMTP和POP3,均位于OSI7层

    2024年02月06日
    浏览(51)
  • 手把手教你用jmeter做压力测试(详图)

    压力测试是每一个Web应用程序上线之前都需要做的一个测试,他可以帮助我们发现系统中的瓶颈问题,减少发布到生产环境后出问题的几率;预估系统的承载能力,使我们能根据其做出一些应对措施。所以压力测试是一个非常重要的步骤,下面我带大家来使用一款压力测试工

    2024年02月02日
    浏览(46)
  • 手把手教你用 Jenkins 自动部署 SpringBoot

    CI/CD 是一种通过在应用开发阶段引入自动化来频繁向客户交付应用的方法。 CI/CD 的核心概念可以总结为三点: 持续集成 持续交付 持续部署 CI/CD 主要针对在集成新代码时所引发的问题(俗称\\\"集成地狱\\\")。 为什么会有集成地狱这个“雅称”呢?大家想想我们一个项目部署的

    2024年02月02日
    浏览(48)
  • 手把手教你用UNet做医学图像分割系统

    兄弟们好呀,这里是肆十二,这转眼间寒假就要过完了,相信大家的毕设也要准备动手了吧,作为一名大作业区的UP主,也该蹭波热度了,之前关于图像分类和目标检测我们都出了相应的教程,所以这期内容我们搞波新的,我们用Unet来做医学图像分割。我们将会以皮肤病的数

    2024年02月03日
    浏览(71)
  • 手把手教你用git上传项目到GitHub

    github的官方网址:https://github.com ,如果没有账号,赶紧注册一个。 点击Sign in进入登录界面,输入账号和密码登入github。 创建成功可以看到自己的仓库地址,如此,我的远程免费的仓库就创建了。它还介绍了github仓库的常用指令。这个指令需要在本地安装git客户端。 Git是目

    2024年01月18日
    浏览(49)
  • 手把手教你用Git——详解git merge

    关于本教程的编写环境 本文基于 Windows10系统 , Mac 系统的小伙伴可以尝试 Homebrew 。由于本人手里并没有搭载 MacOS 的电脑,因此 Homebrew 相关的使用请自行尝试。 对于使用 Windows11系统 的小伙伴,本文的教程是通用的,不过一些细节可能略有不同,这点希望小伙伴们注意一下

    2024年02月05日
    浏览(48)
  • 手把手教你用MindSpore训练一个AI模型!

    首先我们要先了解深度学习的概念和AI计算框架的角色( https://zhuanlan.zhihu.com/p/463019160 ),本篇文章将演示怎么利用MindSpore来训练一个AI模型。和上一章的场景一致,我们要训练的模型是用来对手写数字图片进行分类的LeNet5模型 请参考( http://yann.lecun.com/exdb/lenet/ )。 图1 M

    2024年02月04日
    浏览(56)
  • 爬虫实战|手把手教你用Python爬虫(附详细源码)

    实践来源于理论,做爬虫前肯定要先了解相关的规则和原理,要知道互联网可不是法外之地,你一顿爬虫骚操作搞不好哪天就…  首先,咱先看下爬虫的定义:网络爬虫(又称为网页蜘蛛,网络机器人,在FOAF社区中间,更经常的称为网页追逐者),是一种按照一定的规则,自

    2024年02月02日
    浏览(79)
  • 手把手教你用Python编写配置脚本引擎(福利篇)

    版权声明:原创不易,本文禁止抄袭、转载需附上链接,侵权必究! 配置信息初始化 定义配置引擎类和初始化方法,其中有两个属性,配置实例对象及配置文件路径: 将配置信息写入到配置文件中,该方法有三个形参,category(配置信息类别),name(配置字段名称),value(配置字

    2024年02月06日
    浏览(69)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包