第八章 模型篇:transfer learning for computer vision

这篇具有很好参考价值的文章主要介绍了第八章 模型篇:transfer learning for computer vision。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

参考教程:
transfer-learning
transfer-learning tutorial

transfer learning

很少会有人从头开始训练一个卷积神经网络,因为并不是所有人都有机会接触到大量的数据。常用的选择是在一个非常大的模型上预训练一个模型,然后用这个模型为基础,或者固定它的参数用作特征提取,来完成特定的任务。

对卷积网络进行finetune

进行transfer-learning的一个方法是在基于大数据训练的模型上进行fine-tune。可以选择对模型的每一个层都进行fine-tune,也可以选择freeze特定的层(一般是比较浅的层)而只对模型的较深的层进行fine-tune。理论支持是,模型的浅层通常是一些通用的特征,比如edge或者colo blob,这些特征可以应用于多种类型的任务,而高层的特征则会更倾向于用于训练的原始数据集中的数据特点,因为不太能泛化到新数据上去。

把卷积网络作为特征提取器

将ConvNet作为一个特征提取器,通常是去掉它最后一个用于分类的全连接层,把剩余的层用来提取新数据的特征。你可以在该特征提取器后加上你自己的head,比如分类head或者回归head,用于完成你自己的任务。

何时、如何进行fine tune

使用哪种方法有多种因素决定,最主要的因素是你的新数据集的大小和它与原始数据集的相似度。

  • 当你的新数据集很小,并和原始数据集比较相似时。
    因为你的数据集很小,所以从过拟合的角度出发,不推荐在卷积网络上进行fine-tune。又因为你的数据和原始数据比较相似,所以卷积网络提取的高层特征和你的数据也是相关的。因此你可以直接卷积网络当作特征提取器,在此基础上训练一个线性分类器。
  • 当你的新数据集很大,并和原始数据集比较相似时。
    新数据集很大时,我们可以对整个网络进行fine-tune,因为我们不太会有过拟合的风险。
  • 当你的新数据集很小,并和原始数据集不太相似时。
    因为你的数据集很小,我们还是推荐只训练一个线性的分类器。但是新数据和原始数据又不相似,所以不建议在网络顶端接上新的分类器,因为网络顶端包含很多的dataset-specific的特征,所以更推荐的是从浅层网络的一个位置出发接上一个分类器。
  • 当你的新数据集很大,并和原始数据集不太相似时。
    因为你的数据集很大,我们仍然选择对整个网络进行fine-tune。因为通常情况下以一个pretrained-model对模型进行初始化的效果比随机初始化要好。

代码示例

我们使用与第四章 模型篇:模型训练与示例一样的流程进行模型训练。

加载数据集

首先是加载数据集,方便起见我们直接使用torchvision中的cifar10数据进行训练。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform
)


test_data = datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

train_dataloader = DataLoader(training_data, batch_size = 64)
test_dataloader = DataLoader(test_data, batch_size = 64)

使用官方提供的代码对我们的dataset进行可视化,注意训练时使用的batchsize为64,这里可视化时为了方便暂时使用了batchsize=4。
第八章 模型篇:transfer learning for computer vision

构建模型

在第四章中我们用了自定义的model。在这里我们使用预训练好的模型,并对模型结构进行修改。

transfer-learning对模型的处理有两种,一种是fine-tune整个模型,一种是将模型作为feature-extractor。第二种和第一种的区别是,模型中的部分层被freeze,不在训练过程中更新。

fine-tune 模型

model_ft = models.resnet18(weights = 'IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 10) # 因为cifar10是十分类,所以输出这里为10

模型作为feature extractor

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False  # requires_grad 设为False,不随训练更新

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 10)

定义train_loop和test_loop

这两个部分直接参考第四章的代码就可以,复制过来直接使用。

# 训练过程的每个epoch的操作,代码来自pytorch_tutorial
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        optimizer.zero_grad() # 重置梯度计算
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward() # 反向传播计算梯度
        optimizer.step() # 调整模型参数
        

        if batch % 10 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

定义超参数,开始训练

全都准备好以后,我们定义一下要使用的优化器和loss,和一些别的超参数,就可以开始训练了。

learning_rate = 1e-3
momentum=0.9
epochs = 20

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,momentum=momentum)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model_ft, loss_fn, optimizer)
    test_loop(test_dataloader, model_ft, loss_fn)
print("Done!")

因为是在个人pc跑的,所以就随便放一个效果。。。。。
第八章 模型篇:transfer learning for computer vision

结果可视化

第八章 模型篇:transfer learning for computer vision文章来源地址https://www.toymoban.com/news/detail-514893.html

到了这里,关于第八章 模型篇:transfer learning for computer vision的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 第八章:AI大模型的安全与伦理 8.2 模型安全

    随着人工智能技术的发展,AI大模型已经成为了我们生活中不可或缺的一部分。这些模型在处理大规模数据和复杂任务方面表现出色,但同时也带来了一系列安全和伦理问题。在本章中,我们将深入探讨AI大模型的安全和伦理问题,并提出一些解决方案。 AI大模型的安全问题主

    2024年02月01日
    浏览(52)
  • 统计学习导论(ISLR) 第八章树模型课后习题

    🌸个人主页:JOJO数据科学 📝个人介绍: 统计学top3 高校统计学硕士在读 💌如果文章对你有帮助,欢迎✌ 关注 、👍 点赞 、✌ 收藏 、👍 订阅 专栏 ✨本文收录于【R语言数据科学】 本系列主要介绍R语言在数据科学领域的应用包括: R语言编程基础、R语言可视化、R语言进

    2024年02月12日
    浏览(37)
  • 第八章:AI大模型的安全与伦理问题8.3 AI伦理问题

    随着人工智能(AI)技术的发展,人类社会正面临着一系列新的挑战。这些挑战不仅仅是技术上的,更多的是人类价值观、道德和伦理的面临。在这一章节中,我们将深入探讨AI伦理问题,以期帮助读者更好地理解这一领域的关键问题和挑战。 AI技术的发展为人类带来了巨大的

    2024年02月03日
    浏览(47)
  • 论文解读:(UPL)Unsupervised Prompt Learning for Vision-Language Models

    存在的问题 之前的来自目标数据集的标记数据(有监督学习)可能会限制可伸缩性。 动机 通过无监督提示学习(UPL)方法,以避免提示工程,同时提高类clip视觉语言模型的迁移性能。 主张top-k而不是top-p 注:top-k是指挑选概率最大的k个,top-p是指挑选预测概率大于p的那些数据 看

    2024年04月23日
    浏览(59)
  • 第八章:AI大模型的部署与优化8.1 模型压缩与加速8.1.2 量化与剪枝

    作者:禅与计算机程序设计艺术 8.1.1 背景介绍 随着深度学习技术的不断发展,人工智能模型的规模越来越庞大。然而,这也带来了新的问题:大模型需要更多的计算资源和存储空间,同时在移动设备上运行效率较低。因此,模型压缩与加速成为了当前研究的热点。 8.1.2 核心

    2024年03月08日
    浏览(49)
  • (数字图像处理MATLAB+Python)第八章图像复原-第一、二节:图像复原概述和图像退化模型

    图像复原 :在图像生成、记录、传输过程中,由于成像系统、设备或外在的干扰,会导致图像质量下降,称为 图像退化 ,如大气扰动效应、光学系统的像差、物体运动造成的模糊、几何失真等。图像复原是指通过使用图像处理技术来恢复受损图像的原始信息,使其尽可能接

    2024年02月12日
    浏览(72)
  • A Blockchain-Enabled Federated Learning System with Edge Computing for Vehicular Networks边缘计算和区块链

    摘要:在大多数现有的联网和自动驾驶汽车(CAV)中,从多辆车收集的大量驾驶数据被发送到中央服务器进行统一训练。然而,在数据共享过程中,数据隐私和安全没有得到很好的保护。此外,集中式体系结构还存在一些固有问题,如单点故障、过载请求、无法容忍的延迟等

    2024年02月05日
    浏览(40)
  • 【COMP9517】Computer Vision

    COMP9517: Computer Vision Objectives: This lab revisits important concepts covered in the Week 1 and Week 2 lectures and aims to make you familiar with implementing specific algorithms. Preliminaries: As mentioned in the first lecture, we assume you are familiar with programming in Python or are willing to learn it independently. You do not need to be an exp

    2024年02月02日
    浏览(36)
  • 第八章:Linux信号

    linux信号是OS的重要功能。 使用kill -l查看所有信号。使用信号时,可使用信号编号或它的宏。 1、Linux中信号共有61个,没有0、32、33号信号。 2、【1,31】号信号称为普通信号,【34,64】号信号称为实时信号。 每个信号都有一个编号和一个宏定义名称,这些宏定义可以在signal.h中

    2024年02月13日
    浏览(51)
  • 第八章 图像压缩

    数据冗余R为 R = 1 − 1 C R=1-cfrac1C R = 1 − C 1 ​ C为压缩率,定义为 C = b b ′ C=cfrac{b}{b\\\'} C = b ′ b ​ 二维灰度阵列受如下可被识别和利用的三种主要类型的数据冗余的影响: 编码冗余。编码是用于表示信息实体或事件集合的符号系统(字母、数字、比特和类似的符号等)。每个信

    2024年02月10日
    浏览(50)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包