生动理解深度学习精度提升利器——测试时增强(TTA)

这篇具有很好参考价值的文章主要介绍了生动理解深度学习精度提升利器——测试时增强(TTA)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

测试时增强(Test-Time Augmentation,TTA)是一种在深度学习模型的测试阶段应用数据增强的技术手段。它是通过对测试样本进行多次随机变换或扰动,产生多个增强的样本,并使用这些样本进行预测的多数投票或平均来得出最终预测结果。

为了直观理解TTA执行的过程,这里我绘制了流程示意图如下所示:

生动理解深度学习精度提升利器——测试时增强(TTA),深度学习,人工智能

TTA的过程如下:

  1. 数据增强:

    • 在测试时,对每个测试样本应用随机的变换或扰动操作,生成多个增强样本。
    • 常用的数据增强操作包括随机翻转、随机旋转、随机裁剪、随机缩放等。这些操作可以增加样本的多样性,模拟真实世界中的不确定性和变化。
  2. 多次预测:

    • 使用训练好的模型对生成的增强样本进行多次预测。
    • 对于每个增强样本,都会得到一个预测结果。
  3. 预测结果集成:

    • 对多次预测的结果进行集成,常用的集成方式有多数投票和平均。
    • 对于分类任务,多数投票即选择预测结果中出现次数最多的类别作为最终的预测类别。对于回归任务,平均即将多次预测结果进行平均。

接下来针对性地对比分析下使用TTA带来的优点和缺点:

优点:

  • 提高鲁棒性:通过应用数据增强,TTA可以增加样本的多样性和泛化能力,提高模型在面对未见过的输入分布和未知变化时的鲁棒性。
  • 提高准确性:通过多次预测和集成,TTA可以减少预测结果的随机性和偶然误差,提高最终预测结果的稳定性和准确性。
  • 模型评估和排名:TTA可以改变模型预测的不确定性,使得模型评估更可靠,能够更好地对不同模型进行性能排名。

缺点:

  • 计算开销:生成和预测多个增强样本会增加计算量。特别是在大型模型和复杂任务中,可能导致推理时间的显著增加,限制了TTA的实际应用。
  • 可能造成过拟合:对于已包含在训练数据中的变换或扰动,如果在测试时反复应用,可能会导致模型对这些特定样本的过拟合,从而影响模型的泛化能力。

TTA是一种常用的技术手段,通过应用数据增强和集成预测结果,可以提高深度学习模型在测试阶段的性能和鲁棒性。然而,TTA的应用需要平衡计算开销和预测准确性,并谨慎处理可能导致模型过拟合的问题。根据具体任务和需求,可以灵活选择合适的增强操作和集成策略来使用TTA。

下面是demo代码实现,如下所示:

import numpy as np
import torch
import torchvision.transforms as transforms

def test_time_augmentation(model, image, n_augmentations):
    # 定义数据增强的变换
    transform = transforms.Compose([
        transforms.ToTensor(),
        # 在此添加你需要的任何其他数据增强操作
    ])

    # 存储多次预测结果的列表
    predictions = []

    # 对图像应用多次增强和预测
    for _ in range(n_augmentations):
        augmented_image = transform(image)
        augmented_image = augmented_image.unsqueeze(0)  # 增加一个维度作为批次
        with torch.no_grad():
            # 切换模型为评估模式,确保不执行梯度计算
            model.eval()
            # 使用增强的图像进行预测
            output = model(augmented_image)
            _, predicted = torch.max(output.data, 1)
            predictions.append(predicted.item())

    # 执行多数投票并返回最终预测结果
    final_prediction = np.bincount(predictions).argmax()

    return final_prediction

在前文鸟类细粒度识别项目实验中测试发现,应用TTA技术后,对应的评估指标上有明显的涨点,但是很明显地可以发现:在整个测试过程中资源消耗增加明显,且耗时显著增长,这也是TTA无法避免的劣势,在对精度要求较高的场景下可以有限考虑引入TTA,但是对于计算时耗要求较高的场景则不推荐使用TTA。

开源社区里面也有一些优秀的实现,这里推荐一个,地址在这里,如下所示:

生动理解深度学习精度提升利器——测试时增强(TTA),深度学习,人工智能

目前有将近1k的star量,还是蛮不错的。

安装方法如下所示:

pip安装:
pip install ttach


源码安装:
pip install git+https://github.com/qubvel/ttach
        Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output

目前支持分割、分类、关键点检测三种任务,实例使用如下所示:

Segmentation model wrapping [docstring]:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')


Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())


Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
data transforms 实例实现如下所示:
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),        
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)

Custom model (multi-input / multi-output)实现如下所示:

# Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)

for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 
    
    # augment image
    augmented_image = transformer.augment_image(image)
    
    # pass to model
    model_output = model(augmented_image, another_input_data)
    
    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])
    
    # save results
    labels.append(deaug_mask)
    masks.append(deaug_label)
    
# reduce results as you want, e.g mean/max/min
label = mean(labels)
mask = mean(masks)

Transforms详情如下所示:

Transform Parameters Values
HorizontalFlip - -
VerticalFlip - -
Rotate90 angles List[0, 90, 180, 270]
Scale scales
interpolation
List[float]
"nearest"/"linear"
Resize sizes
original_size
interpolation
List[Tuple[int, int]]
Tuple[int,int]
"nearest"/"linear"
Add values List[float]
Multiply factors List[float]
FiveCrops crop_height
crop_width
int
int

支持的结果融合方法如下:文章来源地址https://www.toymoban.com/news/detail-701975.html

mean
gmean (geometric mean)
sum
max
min
tsharpen (temperature sharpen with t=0.5)

到了这里,关于生动理解深度学习精度提升利器——测试时增强(TTA)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【回归利器】提升至少50%效率的自动化测试工具

    ​开篇立意,先简单介绍一下这个工具是啥,然后说说他的特点,感兴趣的小伙伴可以直接去体验,毕竟体验也不要钱不是。 传送链接:龙测AI-TestOps云平台 首先,这款工具叫龙测 AI-TestOps 云平台,是一个专门对着 UI 自动化测试使劲的测试工具,比较创造性的提出 AI+ 机器人

    2024年02月12日
    浏览(38)
  • 提升图像分割精度:学习UNet++算法

    由于工作需要对 UNet++ 算法进行调参,对规则做较大的修改,初次涉及,有误的地方,请各位大佬指教哈。 1.1 什么是 UNet++ 算法 UNet++ 算法是基于 UNet 算法的改进版本,旨在提高图像分割的性能和效果。它由 Zhou et al. 在论文 “ UNet++: A Nested U-Net Architecture for Medical Image Segment

    2024年02月03日
    浏览(45)
  • 超级利器!Postman自动化接口测试让你提升测试效率,节省宝贵时间!

    Postman自动化接口测试 该篇文章针对已经掌握 Postman 基本用法的读者,即对接口相关概念有一定了解、已经会使用 Postman 进行模拟请求的操作。 当前环境: Window 7 - 64 Postman 版本(免费版):Chrome App v5.5.3 不同版本页面 UI 和部分功能位置会有点不同,不过影响不大。 我们先思

    2024年01月20日
    浏览(64)
  • 模型精度再被提升,统一跨任务小样本学习算法 UPT 给出解法!

    近日,阿里云机器学习平台PAI与华东师范大学高明教授团队、达摩院机器智能技术NLP团队合作在自然语言处理顶级会议EMNLP2022上发表统一多NLP任务的预训练增强小样本学习算法UPT(Unified Prompt Tuning)。这是一种面向多种NLP任务的小样本学习算法,致力于利用多任务学习和预训

    2024年02月12日
    浏览(39)
  • 探索学习和入门使用GitHub Copilot:提升代码开发的新利器

    在最近的开发工作中,发现了一个比较实用的工具,github copilot,这是一款基于人工智能的代码助手工具,旨在提供智能的代码补全和生成功能。在开发过程中能够有效减少我们在繁琐代码上所花费的时间,例如打日志。也可以帮助我们刷题。 本文将介绍如何入门使用GitHub

    2024年02月04日
    浏览(79)
  • 深度学习-双精度

    浮点数据类型主要分为双精度(Fp64)、单精度(Fp32)、半精度(FP16)。 首先来看看为什么需要混合精度。使用FP16训练神经网络,相对比使用FP32带来的优点有: 减少内存占用 :FP16的位宽是FP32的一半,因此权重等参数所占用的内存也是原来的一半,节省下来的内存可以放更

    2024年02月03日
    浏览(52)
  • 深度学习模型优化:提高训练效率和精度的技巧

    🎉欢迎来到AIGC人工智能专栏~探索Java中的静态变量与实例变量 ☆* o(≧▽≦)o *☆嗨~我是IT·陈寒🍹 ✨博客主页:IT·陈寒的博客 🎈该系列文章专栏:AIGC人工智能 📜其他专栏:Java学习路线 Java面试技巧 Java实战项目 AIGC人工智能 数据结构学习 🍹文章作者技术和水平有限,如

    2024年02月11日
    浏览(49)
  • 深度学习进行数据增强(实战篇)

    本文章是我在进行深度学习时做的数据增强,接着我们上期的划分测试集和训练集来做. 文章目录 前言 数据增强有什么好处? 一、构造数据增强函数 二、数据增强 总结 很多人在深度学习的时候在对数据的处理时一般采用先数据增强在进行对训练集和测试集的划分,其实我感觉

    2024年01月23日
    浏览(54)
  • 【深度学习:数据增强】计算机视觉中数据增强的完整指南

    可能面临的一个常见挑战是模型的过拟合。这种情况发生在模型记住了训练样本的特征,但却无法将其预测能力应用到新的、未见过的图像上。过拟合在计算机视觉中尤为重要,在计算机视觉中,我们处理高维图像输入和大型、过度参数化的深度网络。有许多现代建模技术可

    2024年02月03日
    浏览(50)
  • 要利用Java编程提升人们对安全教育的兴趣,可以开发一些互动性强、内容生动有趣的教育软件或游戏

    要利用Java编程提升人们对安全教育的兴趣,可以开发一些互动性强、内容生动有趣的教育软件或游戏。以下是一些建议: 开发安全教育游戏:使用Java编程语言,可以开发一些有关于安全教育的小游戏,如模拟火灾逃生、地震自救等场景,让玩家在游戏中学习到安全知识。

    2024年04月27日
    浏览(62)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包