权重衰减weight_decay参数从入门到精通

这篇具有很好参考价值的文章主要介绍了权重衰减weight_decay参数从入门到精通。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

本文内容

Weight Decay是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性。

目前网上对于Weight Decay的讲解都比较泛,都是短短的几句话,但对于其原理、实现方式大多就没有讲解清楚,本文将会逐步解释weight decay机制。

1. 什么是权重衰减(Weight Decay)

Weight Decay是一个正则化技术,作用是抑制模型的过拟合,以此来提高模型的泛化性。

它是通过给损失函数增加模型权重L2范数的惩罚(penalty)来让模型权重不要太大,以此来减小模型的复杂度,从而抑制模型的过拟合。

看完上面那句话,可能很多人已经蒙圈了,这是在说啥。后面我会逐步进行解释,将会逐步回答以下问题:

  1. 什么是正则化?
  2. Weight Decay的减小模型参数的思想
  3. L1范数惩罚项和L2范数惩罚项是什么?
  4. 为什么Weight Decay参数是在优化器上,而不是在Loss上。
  5. weight decay的调参技巧

2. 什么是正则化?

正则化的目标是减小方差或是说减小数据扰动所造成的影响。 我们来看下图来理解一下这句话:

权重衰减weight_decay参数从入门到精通

这幅图是随着训练次数,训练Loss和验证Loss的变化曲线。上面那条线是验证集的。很明显,这个模型出现了过拟合,因为随着训练次数的增加,训练Loss在下降,但是验证Loss却在上升。这里我们会引出三个概念:

  1. 方差(Variance):刻画数据扰动所造成的影响。
  2. 偏差(Bias):刻画学习算法本身的拟合能力。
  3. 噪声(Noise):当前任务任何学习算法能达到的期望泛化误差的下界。也就是数据的噪声导致一定会出现的那部分误差。

通常不考虑噪声,所以偏差和噪声合并称为偏差。

2.1 什么数据扰动

上面说方差是“刻画数据扰动所造成的影响”,我们可以通过下面例子来理解这句话。

假设我们要预测一个 y = x y=x y=x 的模型:

权重衰减weight_decay参数从入门到精通

绿色的线是真正的模型 y = x y=x y=x,蓝色的点是训练数据,红色的线是预测出的模型。这个训练数据点距离真实模型的偏离程度就是数据扰动

如果我们使用数据扰动较小的数据,那么预测模型结果就会和真正模型的差距较小,例如:

权重衰减weight_decay参数从入门到精通

当我们数据扰动越大,预测模型距离实际模型的差距就会越大。因此,我们减小过拟合就是让预测模型和真实模型尽可能的一致。通常有两种做法:

  1. 增加数据量和使用更好的数据。这也是最推荐的做法
  2. 然而,通常我们很难收集到更多的数据,所以此时就需要一些正则化技术来减小“数据扰动”对模型预测带来的影响

3. 减小模型权重

权重衰减(Weight Decay)就是减小模型的权重大小,而减小模型的权重大小就可以降低模型的复杂度,使模型变得平滑,进而减小过拟合。

假设我们的模型为: y = w 0 + w 1 x + w 2 x 2 + w 2 x 2 + ⋯ + w n x n y = w_0 + w_1 x + w_2x^2 + w_2x^2 + \cdots +w_nx^n y=w0+w1x+w2x2+w2x2++wnxn,模型的参数为 W = ( w 0 , w 1 , w 2 , ⋯   , w n ) W=(w_0, w_1, w_2, \cdots, w_n) W=(w0,w1,w2,,wn)

我们使用该模型根据一些训练数据点可能会学到如下的两种曲线:

权重衰减weight_decay参数从入门到精通

很明显,蓝色的曲线显然过拟合了。如果我们观察 W W W 的话会发现,蓝色曲线的参数通常都比较大,而绿色曲线的参数通常都比较小。

上面只是直观的说一下。结论就是:模型权重数值越小,模型的复杂度越低

该结论可以通过实验观察出来,也可以通过数学证明。(李沐说可以证明,感兴趣的同学可以搜一下)

4. 为Loss增加惩罚项

上面说了Weight Decay目的是要让模型权重小一点(控制在某一个范围内),以此来减小模型的复杂性,从而抑制过拟合。

而Weight Decay的具体做法就是在Loss后面增加一个权重的L2范数惩罚项。

4.1 通过公式理解Weight Decay

Weight Decay的具体公式就是:

L = L 0 + λ 2 ∣ ∣ W ∣ ∣ 2 L = L_0 + \frac{\lambda}{2}||W||^2 L=L0+2λ∣∣W2

其中 L 0 L_0 L0 是原本的Loss, λ \lambda λ 是一个超参,负责控制权重衰减的强弱。 ∣ ∣ W ∣ ∣ 2 ||W||^2 ∣∣W2 为模型参数的2范数的平方。

具体的,假设我们的模型有 n n n 个参数,即 W = [ w 1 , w 2 , ⋯   , w n ] W=[w_1, w_2, \cdots, w_n] W=[w1,w2,,wn],则 L L L 为:

L = L 0 + λ 2 ( w 1 2 + w 2 2 + ⋯ + w n 2 ) 2 = L 0 + λ 2 ( w 1 2 + w 2 2 + ⋯ + w n 2 ) \begin{aligned} L &= L_0 + \frac{\lambda}{2}\left( \sqrt{w_1^2+w_2^2+\cdots+w_n^2} \right) ^2 \\\\ &= L_0 + \frac{\lambda}{2}(w_1^2+w_2^2+\cdots+w_n^2) \end{aligned} L=L0+2λ(w12+w22++wn2 )2=L0+2λ(w12+w22++wn2)

从上面的公式,我们可以很明显的得到如下结论:

  1. 模型的权重越大,Loss就会越大。
  2. λ \lambda λ 越大,权重衰减的就越厉害
  3. λ \lambda λ 过大,那么原本Loss的占比就会较低,最后模型就光顾着让模型权重变小了,最终模型效果就会变差。

4.2 通过图像理解Weight Decay

接下来我们用图像来感受一下Weight Decay。假设我们的模型只有两个参数W1和W2,W1和W2与Loss=2有如下关系:

权重衰减weight_decay参数从入门到精通

这个绿色的椭圆表示,当W1和W2取绿色椭圆上的点时,Loss都是2。所以,当我们没有惩罚项时,对于Loss=2,取椭圆上的这些点都可以。若取到右上角的点,那么 W1和W2 的值就会比较大,所以我们希望W1和W2尽量往左下靠。

因为我们的惩罚项是 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22,我们将其图像画出来( w 1 2 + w 2 2 = X w_1^2 + w_2^2= X w12+w22=X)。

权重衰减weight_decay参数从入门到精通

上图我绘制了三条橘色图像,分别为

w 1 2 + w 2 2 = X 1 w_1^2 + w_2^2= X_1 w12+w22=X1,与椭圆无焦点。
w 1 2 + w 2 2 = X 2 w_1^2 + w_2^2= X_2 w12+w22=X2,与椭圆交于A点
w 1 2 + w 2 2 = X 3 w_1^2 + w_2^2= X_3 w12+w22=X3,与椭圆交于B,C两点

从上图可以看到,在不改变原Loss的情况下,(W1, W2)落在A点时,惩罚项最小,即 w 1 2 + w 2 2 w_1^2 + w_2^2 w12+w22 最小。

所以,我们增加2范数的惩罚,会让模型参数变小。

为什么1范数不好

可能有些同学比较好奇,为什么不取1范数,我们同样用图可以表示出来。我们将上述的2范数图像变成1范数图像(即 ∣ w 1 ∣ + ∣ w 2 ∣ = X |w_1| +|w_2|=X w1+w2=X):

权重衰减weight_decay参数从入门到精通
上图我绘制了三条橘色图像,分别为

∣ w 1 ∣ + ∣ w 2 ∣ = X 1 |w_1| + |w_2|= X_1 w1+w2=X1,与椭圆无焦点。
∣ w 1 ∣ + ∣ w 2 ∣ = X 2 |w_1| + |w_2|= X_2 w1+w2=X2,与椭圆交于A点
∣ w 1 ∣ + ∣ w 2 ∣ = X 3 |w_1| + |w_2|= X_3 w1+w2=X3,与椭圆交于B,C两点

与2范数同理,在不改变原Loss的情况下,(W1, W2)落在A点时,惩罚项最小,即 ∣ w 1 ∣ + ∣ w 2 ∣ |w_1| + |w_2| w1+w2 最小。

但这里有个问题,我们发现此时 w 1 w_1 w1 变成 0 了。这就是为什么我们通常不用1范数,因为1范数会倾向于让一部分权重变成0。

更高的范数同理,可以参考“什么是范数(Norm),其具有哪些性质”这篇博客来感受一下每个范数不同的图像,然后将其套到上面的图中,感受一下其他范数。

5. Weight Decay的实现

通常我们在使用Weight Decay是在优化器(Optimizer)上,这就很奇怪了,上面明明都是在说Loss,为什么weight decay参数是在优化器上呢?

这是因为它们是等价的。这个很容易推导,我们用SGD来举例,SGD的更新参数的过程为:

w i ← w i − γ ∂ L ∂ w i w_i \gets w_i - \gamma \frac{\partial L}{\partial w_i} wiwiγwiL

其中 γ \gamma γ 是学习率。

我们将 L = L 0 + λ 2 ∣ ∣ W ∣ ∣ 2 L = L_0 + \frac{\lambda}{2}||W||^2 L=L0+2λ∣∣W2 带进来求一下可得:

w i − γ ∂ L ∂ w i = w i − γ ( ∂ L 0 ∂ w i + λ w i ) \begin{aligned} & w_i - \gamma \frac{\partial L}{\partial w_i} \\\\ = & w_i - \gamma (\frac{\partial L_0}{\partial w_i} + \lambda w_i) \end{aligned} =wiγwiLwiγ(wiL0+λwi)

其中 ∂ L 0 ∂ w i \frac{\partial L_0}{\partial w_i} wiL0 就是原本的梯度,所以我们为Loss增加L2正则项只需要在更新参数时,给模型的梯度加一个 λ w i \lambda w_i λwi 即可。

对应Pytorch的实现如下图:

权重衰减weight_decay参数从入门到精通

6. weight_decay的一些trick

  1. weight_decay并没有你想想中的那么好,它的效果可能只有一点点,不要太指望它。尤其是当你的模型很复杂时,权重衰退的效果可能会更小了。
  2. 通常取1e-3,如果要尝试的话,一般也就是1e-2, 1e-3, 1e-4 这些选项。
  3. 权重衰退通常不对bias做。但通常bias做不做权重衰退其实效果差不多,不过最好不要做。
  4. weight_decay取值越大,对抑制模型的强度越大。但这并不说明越大越好,太大的话,可能会导致模型欠拟合。

针对第三点:对于一个二维曲线,bias只是让曲线整体上下移动,并不能减小模型的复杂度,所以通常不需要对bias做正则化。

参考资料

正则化之weight_decay(深度之眼): https://www.bilibili.com/video/BV1HB4y1i7Fn

权重衰退(李沐): https://www.bilibili.com/video/BV1UK4y1o7dy

从拉格朗日乘数法角度理解L1L2正则: https://www.bilibili.com/video/BV1Z44y147xA文章来源地址https://www.toymoban.com/news/detail-433618.html

到了这里,关于权重衰减weight_decay参数从入门到精通的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【Python】解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias

    目录 1.问题描述 2.问题原因 3.问题解决 3.1思路1——忽视最后一层权重 额外说明:假如载入权重不写strict=False, 直接是model.load_state_dict(pre_weights, strict=False),会报错找不到key? 解决办法是:加上strict=False,这个语句就是指忽略掉模型和参数文件中不匹配的参数 3.2思路2——更

    2023年04月14日
    浏览(25)
  • 【PyTorch】权重衰减

    通过对模型过拟合的思考,人们希望能通过某种工具 调整模型复杂度 ,使其达到一个合适的平衡位置。 权重衰减(又称 L 2 L_2 L 2 ​ 正则化)通过为损失函数 添加惩罚项 ,用来惩罚权重的 L 2 L_2 L 2 ​ 范数,从而限制模型参数值,促使模型参数更加稀疏或更加集中,进而调

    2024年02月04日
    浏览(34)
  • 深度学习学习笔记——解决过拟合问题的方法:权重衰减和暂退法,与正则化之间的关系

    解决过拟合问题是机器学习和深度学习中关键的任务之一,因为它会导致模型在训练数据上表现良好,但在未见数据上表现不佳。以下是一些解决过拟合问题的常见方法: 增加训练数据 : 增加更多的训练数据可以帮助模型更好地捕捉数据的真实分布,减少过拟合的可能性。

    2024年02月09日
    浏览(42)
  • Verilog权重轮询仲裁器设计——Weighted Round Robin Arbiter

    前两篇讲了固定优先级仲裁器的设计、轮询仲裁器的设计 Verilog固定优先级仲裁器——Fixed Priority Arbiter_weixin_42330305的博客-CSDN博客 Verilog轮询仲裁器设计——Round Robin Arbiter_weixin_42330305的博客-CSDN博客 权重轮询仲裁器就是在轮询仲裁器的基础上,当grant次数等于weight时,再切换

    2024年02月14日
    浏览(30)
  • nn.BCEWithLogitsLoss中weight参数和pos_weight参数的作用及用法

    上式是nn.BCEWithLogitsLoss损失函数的计算公式,其中w_n对应weight参数。 如果我们在做多分类任务,有些类比较重要,有些类不太重要,想要模型更加关注重要的类别,那么只需将比较重要的类所对应的w权重设置大一点,不太重要的类所对应的w权重设置小一点。 下面是一个代码

    2024年01月23日
    浏览(23)
  • C++从入门到精通——缺省参数

    缺省参数是在函数定义时指定的默认值,当调用函数时未提供该参数的值时,将使用缺省值。使用缺省参数可以简化函数调用,提高代码可读性。但需注意,过多使用缺省参数可能导致代码难以理解和维护。 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。在调用

    2024年04月10日
    浏览(44)
  • Ceph入门到精通-sysctl参数优化

    sysctl.conf  是一个文件,通常用于在 Linux 操作系统中配置内核参数。这些参数可以控制网络、文件系统、内存管理等各方面的行为。 99-xx.yml  可能是一个文件名,其中  99-  是一个特定的命名约定。在  sysctl.conf  文件中,通常会有一个特定的顺序来加载配置项。通常,以 

    2024年02月10日
    浏览(27)
  • binary_cross_entropy_with_logits中的weight参数与pos_weight参数

    根据官方给出的binary_cross_entropy_with_logits函数的二分类交叉熵损失计算公式: 其中, N代表batch大小。 可以看到,weight参数代表每个样本的权重。 根据官方对pos_weight参数的解释:a weight of positive examples to be broadcasted with target. Must be a tensor with equal size along the class dimension to the

    2024年04月09日
    浏览(21)
  • Ceph入门到精通-Nginx超时参数分析设置

    nginx中有些超时设置,本文汇总了nginx中几个超时设置 Nginx 中的超时设置包括: “client_body_timeout”:设置客户端向服务器发送请求体的超时时间,单位为秒。 “client_header_timeout”:设置客户端向服务器发送请求头的超时时间,单位为秒。 “send_timeout”:设置服务器向客户端

    2024年02月07日
    浏览(32)
  • 『Linux从入门到精通』第 ⑮ 期 - main函数的三个参数你见过吗?

    🌸作者简介: 花想云 ,在读本科生一枚,C/C++领域新星创作者,新星计划导师,阿里云专家博主,CSDN内容合伙人…致力于 C/C++、Linux 学习。 🌸 专栏简介:本文收录于 Linux从入门到精通 ,本专栏主要内容为本专栏主要内容为Linux的系统性学习,专为小白打造的文章专栏。

    2024年02月16日
    浏览(28)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包