【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)

这篇具有很好参考价值的文章主要介绍了【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一、具体介绍

timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。

timm 的特点如下:

  1. PyTorch 原生实现:timm 的实现方式与 PyTorch 高度契合,开发者可以方便地使用 PyTorchAPI 进行模型训练和部署。
  2. 轻量级的设计:timm 的设计以轻量化为基础,根据不同的计算机视觉任务,提供了多种轻量级的网络结构。
  3. 大量的预训练模型:timm 提供了大量的预训练模型,可以直接用于各种计算机视觉任务。
  4. 多种模型组件:timm 提供了各种模型组件,如注意力模块、正则化模块、激活函数等等,这些模块都可以方便地插入到自己的模型中。
  5. 高效的代码实现:timm 的代码实现高效并且易于使用。

需要注意的是,timm 是一个社区驱动的项目,它由计算机视觉领域的专家共同开发和维护。在使用时需要遵循相关的使用协议。

二、图像分类案例

下面以使用 timm 实现图像分类任务为例,进行简单的介绍。

2.1 安装 timm 包

!pip install timm

2.2 导入相关模块,读取数据集

import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# 数据增强
train_transforms = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# 数据集
train_dataset = CIFAR10(root='data', train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10(root='data', train=False, download=True, transform=test_transforms)

# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

导入相关模块,其中 timmtorchvision.datasets.CIFAR10 需要分别安装 timmtorchvision 包。

定义数据增强的方式,其中训练集和测试集分别使用不同的增强方式,并且对图像进行了归一化处理。transforms.Compose() 可以将各种操作打包成一个 transform 操作流,transforms.ToTensor() 将图像转化为 tensor 格式,transforms.Normalize() 将图像进行标准化处理。

使用自带的 CIFAR10 数据集,设置 train=True 定义训练集,设置 train=False 定义测试集。数据集会自动下载到指定的 root 路径下,并进行数据增强操作。

使用 torch.utils.data.DataLoader 定义数据加载器,将数据集包装成一个高效的可迭代对象,其中 batch_size 定义批次大小,shuffle 定义是否对数据进行随机洗牌,num_workers 定义使用多少个 worker 来加载数据。

python timm,计算机视觉,深度学习笔记,Python3常用到的函数总结,计算机视觉,pytorch,分类,timm

2.3 定义模型

# 加载预训练模型
model = timm.create_model('resnet18', pretrained=True)

# 修改分类器
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))

这里使用 timm.create_model() 函数来创建一个预训练模型,其中参数 resnet18 定义了使用的模型架构,参数 pretrained = True 表示要使用预训练权重。

这里修改了模型的分类器,首先使用 model.fc.in_features 获取模型 fc 层的输入特征数,然后使用 nn.Linear() 重新定义了一个 nn.Linear 层,输入为上一层的输出特征数,输出为类别数(即 len(train_dataset.classes))。这里直接使用了数据集类别数来定义输出层,以适配不同分类任务的需求。

python timm,计算机视觉,深度学习笔记,Python3常用到的函数总结,计算机视觉,pytorch,分类,timm
在这里,我们使用了 timm 中的 ResNet18 模型,并将其修改为我们需要的分类器,同时在创建模型时,设置参数 pretrained=True 来加载预训练权重。

2.4 定义损失函数和优化器

# 损失函数
criterion = nn.CrossEntropyLoss()

# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

在深度学习中,损失函数是评估模型预测结果与真实标签之间差异的一种指标,常用于模型训练过程中。nn.CrossEntropyLoss() 是一个常用的损失函数,适用于多分类问题。

优化器用于更新模型参数以使损失函数最小化。在这里,我们使用了随机梯度下降法(SGD)优化器,以控制模型权重的变化。通过 model.parameters() 指定需要优化的参数,lr 定义了学习率,表示每次迭代时参数必须更新的量的大小,momentum 则是添加上次迭代更新值的一部分到这一次的更新值中,以减小参数更新的方差,稳定训练过程。

2.5 训练模型

num_epochs = 10

for epoch in range(num_epochs):
    # 训练
    model.train()
    for images, labels in train_loader:
        # 前向传播
        outputs = model(images)
        # 计算损失
        loss = criterion(outputs, labels)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 测试
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct/total))

这段代码是模型训练和测试的循环。num_epochs 定义了循环的次数,每次循环表示一个训练周期。

在训练阶段,首先将模型切换到训练模式,然后使用 train_loader 迭代地读取训练集数据,进行前向传播、计算损失、反向传播和优化器更新等操作。

在测试阶段,模型切换到评估模式,然后使用 test_loader 读取测试集数据,进行前向传播和计算模型预测结果,使用预测结果和真实标签进行准确率计算,并输出每个训练周期的准确率。

其中,torch.max() 函数用于返回每行中最大值及其索引,total 记录了总的测试样本数,correct 记录了正确分类的样本数,最后计算准确率并输出。

输出结果为:

python timm,计算机视觉,深度学习笔记,Python3常用到的函数总结,计算机视觉,pytorch,分类,timm文章来源地址https://www.toymoban.com/news/detail-611716.html

到了这里,关于【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 计算机视觉技能干货分享——Pytorch图像分类系列教程

    作者:禅与计算机程序设计艺术 计算机视觉(Computer Vision,CV)是指研究如何使电脑从各种输入(如图像、视频)中捕获、分析和处理信息,并在人类可理解的形式上展示出来。它包括目标检测、图像分割、图像跟踪、图像风格化、人脸识别等多个子领域。它的发展始于20世纪

    2024年02月06日
    浏览(28)
  • 计算机视觉介绍

    计算机视觉是指使用计算机技术对数字图像或视频进行处理与分析,以实现对图像内容的理解、认知、识别、分类、跟踪、测量、重构等功能。以下是计算机视觉基础入门教程的内容: 1.数字图像处理 数字图像处理是计算机视觉的基础,因此需要了解数字图像的基本概念和操

    2024年02月01日
    浏览(28)
  • 计算机视觉常用数据集介绍

    MINIST 数据集应该算是CV里面最早流行的数据了,相当于CV领域的Hello World。该数据包含70000张手写数字图像,其中60000张用于train, 10000张用于test, 并且都有相应的label。图像的尺寸比较小, 为28x28。 数据说明及下载地址: http://yann.lecun.com/exdb/mnist/ 这个数据是由 Yann LeCun 创建

    2024年02月14日
    浏览(32)
  • 【计算机视觉|人脸识别】 facenet-pytorch 项目中文说明文档

    下文搬运自GitHub,很多超链接都是相对路径所以点不了,属正常现象。 点击查看原文档 。转载请注明出处。 Click here to return to the English document 译者注: 本项目 facenet-pytorch 是一个十分方便的人脸识别库,可以通过 pip 直接安装。 库中包含了两个重要功能 人脸检测:使用MT

    2024年02月04日
    浏览(32)
  • 计算机视觉一 —— 介绍与环境安装

    傲不可长 欲不可纵 乐不可极 志不可满 研究理论和应用 - 研究如何使机器“看”的科学 - 让计算机具有人类视觉的所有功能 - 让计算机从图像中,提取有用的信息,并解释 - 重构人眼;重构视觉皮层;重构大脑剩余部分 计算机视觉学习图 学习重点 1. 各种深度神经网络模型(

    2024年02月13日
    浏览(31)
  • 【计算机视觉】干货分享:Segmentation model PyTorch(快速搭建图像分割网络)

    如何快速搭建图像分割网络? 要手写把backbone ,手写decoder 吗? 介绍一个分割神器,分分钟搭建一个分割网络。 仓库的地址: 该库的主要特点是: 高级 API(只需两行即可创建神经网络) 用于二元和多类分割的 9 种模型架构(包括传奇的 Unet) 124 个可用编码器(以及 timm

    2024年02月14日
    浏览(37)
  • 【Pytorch】计算机视觉项目——卷积神经网络CNN模型识别图像分类

    在上一篇笔记《【Pytorch】整体工作流程代码详解(新手入门)》中介绍了Pytorch的整体工作流程,本文继续说明如何使用Pytorch搭建卷积神经网络(CNN模型)来给图像分类。 其他相关文章: 深度学习入门笔记:总结了一些神经网络的基础概念。 TensorFlow专栏:《计算机视觉入门

    2024年02月05日
    浏览(39)
  • OpenCV第 1 课 计算机视觉和 OpenCV 介绍

      我们人类可以通过眼睛看到五颜六色的世界,是因为人眼的视觉细胞中存在分别对红、绿、蓝敏感的 3 种细胞。其中的光感色素根据光线的不同进行不同比例的分解,从而让我们识别到各种颜色。   对人工智能而言,学会“ 看 ”也是非常关键的一步。那么机器人是如

    2024年01月24日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包