INFOBATCH: LOSSLESS TRAINING SPEED UP BY UNBIASED DYNAMIC DATA PRUNING
即插即用的动态数据裁剪,加速网络训练.
ICLR 2024 Oral | InfoBatch,三行代码,无损加速,即插即用!
论文题目:
InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning
论文地址:https://arxiv.org/abs/2303.04947
代码地址:https://github.com/henryqin1997/InfoBatch
1.概述
加速训练一个比较直接的方法是降低数据集规模。如何降低数据集规避,应该剔除哪些数据,一般认为剔除那些 loss(或者其他一些指标score)较小的,因为这样的样本数据对模型训练梯度下降影响较小。
一种方式是 static prune,就是训练一些epoch后,根据损失值或者其他指标 设定 阈值,裁剪 那些小于 阈值的样本。
一种方式是 dynamic prune, 就是每隔一些epoch 根据指标排序,然后进行裁剪,整个训练过程中多次裁剪。
Meanwhile, directly pruning data may lead to a biased gradient estimation as illustrated in Fig. 1a, which affects the convergence result. This is a crucial factor that limits their performance, especially under a high pruning ratio
就是无论是静态和动态prune 数据都有一个问题,就是确实可以加速训练,但是直接剪枝数据可能导致梯度估计偏倚,如下图a所示,影响收敛结果。这是限制其性能的关键因素,特别是在高剪枝比下。
因此作者提出infobatch方法,如图1b所示。
2.原理
主要包括 soft pruning , experctation rescaling 两个步骤。
详细步骤:
每个epoch或者若干个epoch训练后需要重新裁剪数据的时候
0. 第一次裁剪 计算整个数据集的 平均损失 loss_mean,按照直接 裁剪掉 loss较小的一部分进行裁剪。(硬裁剪)
-
之后的裁剪, 每次也要计算整个数据集的 平均损失 loss_mean(被裁剪的样本(未被训练的样本)用之前的loss,未被裁剪的sample用训练更新后的loss)。这一步讲了如何更新 每个样本的损失并计算loss_mean, loss_mean其实就是自适应阈值。
-
然后小于 loss_mean的数据 按照一定的概率 r 进行prune
-
样本减少,整个数据集的 梯度会发生变化,造成与原数据集 梯度期望不一致。解决这个问题,就是 将 小于loss_mean的数据样本 梯度进行rescale, 1/(1-r) times.
-
在最后15%的epoch采用full dataset进行训练
3.实验结果
对于infobatch方法,prune比例约等于节省的时间比例,因为求loss_mean的时间开销很小:
4.三行代码
https://github.com/NUS-HPC-AI-Lab/InfoBatch
Masked Image Training for Generalizable Deep Image Denoising
1.概述
这篇文章很有意思,假如你开发一个denoise model, 但是你的数据集只是一些特定场景的特定的noise type, noise level, 你是希望model能够处理更多的场景还是能够兼顾更多的noise type(noise lvel)。
兼顾更多的场景意味着,即使你的训练集和实际使用的数据 场景差别很大,model也能有效。
兼顾更多的noise type或noise level意味着,即使你的训练集和实际使用的数据 noise type差别很大,model也能有效。
作者认为现在的denoise model是对 特定的noise type, noise level有效,即使换了场景,noise type, noise level只要不变,model仍然有效。
作者提出的denoise model是对 希望对特定的场景有效,无论什么样的noise type, noise level, 只要训练的场景和 实际使用的场景一致,那么model仍然有效。
下图可以很好的说明:
换个噪声type, swir就失效了
下图说明,换个场景(训练场景和实际使用场景有差别),本文提出的方法会失效:
作者提出的方法是利用mask, 为了更好的训练model,使model能够学习到图像语义内容,而不是noise.
2.原理
作者根据基于swir transformer结构 引入 input mask 和 attention mask
文章来源:https://www.toymoban.com/news/detail-830152.html
最后训练的model可以兼容更多的 noise type, noise level.
本人不了解swintransformer的结构,关于网络结构这里请参考:
https://github.com/haoyuc/MaskedDenoising
https://zhuanlan.zhihu.com/p/658523907文章来源地址https://www.toymoban.com/news/detail-830152.html
到了这里,关于INFOBATCH: LOSSLESS TRAINING SPEED UP BY UNBIASED DYNAMIC DATA PRUNING 和Masked Image denoised的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!