关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

这篇具有很好参考价值的文章主要介绍了关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

1. 交叉熵损失 CrossEntropyLoss

2. ignore_index 参数

3. weight 参数

4. 例子


1. 交叉熵损失 CrossEntropyLoss

CrossEntropyLoss 交叉熵损失可函数以用于分类或者分割任务中,这里主要介绍分割任务

建立如下的数据,pred是预测样本,label是真实标签

分割中,使用交叉熵损失的话,需要保证label的维度比pred维度少1,也就是没有channel维度。并且,label的类型是int

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

正常计算损失结果为:

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

手动计算一下,pred的softmax为

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

所以,loss = -(ln0.69+ln0.3543+ln0.5987)/3 = -(ln0.1464) / 3 = 0.6406 

后面的是计算产生的误差,这里用数学方法简化计算了

one-hot 编码,只计算label的 ln 预测值

2. ignore_index 参数

在分割任务中,经常有像素点是认为不感兴趣的,所以这里ignore_index可以将那些不感兴趣的像素点排除

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]])     # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0])                         # 真实值 size = 3 , dtype = torch.int64
loss = nn.CrossEntropyLoss(ignore_index=1)
out = loss(pred,label)
print(out)      # tensor(0.4421)

这里将label = 1的像素点排除,手动计算一下

loss = (-ln0.69-ln0.5987) / 2 = 0.4421 

这里将label = 1的忽略了,下面是pred的softmax值

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

3. weight 参数

当涉及到样本的个数不平衡的时候,可以将样本少的label,w加大点

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]])     # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0])                         # 真实值 size = 3 , dtype = torch.int64
w = torch.FloatTensor([1,2])
loss = nn.CrossEntropyLoss(weight=w)
out = loss(pred,label)
print(out)      # tensor(0.7398)

计算方法是:

loss =- ( 1*ln0.69 + 2*ln0.3543+1*ln0.5987) / 4 = (0.3711 + 2.0741+ 0.5130) / 4= 0.7396

可以发现答案是类似的,这里保留了四位小数进行计算,所以有误差

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

因为,label = 1有一个,label = 0 有两个,所以1的样本较少,这里就对label = 1设置权重大点。可以发现,计算出来的loss确实比不加loss的大,下图为不加w的

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

如果将w改成[2,1]的话,loss会更低,不利于loss的下降

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

 文章来源地址https://www.toymoban.com/news/detail-416033.html

所以,在样本不均衡的情况下,加label少的样本,w加大,可以将loss变大,从而梯度下降的时候可以更好的弥补样本不平衡的问题

注意:w的类型是float

4. 例子

测试代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1,0.2],[0.8, 0.2,0.1],[0.7, 0.3,0.5],[0.1,0.5,0.6]])
label = torch.LongTensor([2, 1, 0,1])

s = F.softmax(pred,dim=1)
print(s)

w = torch.FloatTensor([2,1,2])
loss = nn.CrossEntropyLoss(weight=w,ignore_index=2)
out = loss(pred,label)
print(out)      # tensor(1.0401)

其中,pred的softmax如下:

label 为:2 1 0 1

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

可以发现,label 是 0 1 2 三类,这里将label = 2的忽略,并且对0 1 2施加的权重为 2 1 2

所以手动计算的公式为,这里精确到六位小数

label = 0 的损失 = - ln0.4018 = 0.911801

label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234

label = 2 的损失 = - ln0.2552 = 1.365708

这里忽略了label = 2,所以还剩:

label = 0 的损失 = - ln0.4018 = 0.911801

label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234

并且对0 1 进行加权2 1

所以总的loss = (0.911801 *2 + 1.315650*1+1.020818*1) /(2+1+1) = 4.16007/4=1.0400175

可以发现结果是一样的,这里最后是精度问题

关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

 

到了这里,关于关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • loss = nn.CrossEntropyLoss(reduction=‘none‘)

    nn.CrossEntropyLoss() 函数是 PyTorch 中用于计算交叉熵损失的函数。 其中 reduction 参数用于 控制输出损失的形式 。 当 reduction=\\\'none\\\' 时,函数会输出一个形状为 (batch_size, num_classes) 的矩阵,表示 每个样本的每个类别的损失 。 当 reduction=\\\'sum\\\' 时,函数会对 矩阵求和 ,输出一个标量

    2024年02月14日
    浏览(29)
  • 深度学习之PyTorch实战(5)——对CrossEntropyLoss损失函数的理解与学习

      其实这个笔记起源于一个报错,报错内容也很简单,希望传入一个三维的tensor,但是得到了一个四维。 查看代码报错点,是出现在pytorch计算交叉熵损失的代码。其实在自己手写写语义分割的代码之前,我一直以为自己是对交叉熵损失完全了解的。但是实际上还是有一些些

    2023年04月09日
    浏览(30)
  • nn.BCEWithLogitsLoss中weight参数和pos_weight参数的作用及用法

    上式是nn.BCEWithLogitsLoss损失函数的计算公式,其中w_n对应weight参数。 如果我们在做多分类任务,有些类比较重要,有些类不太重要,想要模型更加关注重要的类别,那么只需将比较重要的类所对应的w权重设置大一点,不太重要的类所对应的w权重设置小一点。 下面是一个代码

    2024年01月23日
    浏览(24)
  • 【人工智能与深度学习】均方损失,交叉墒损失,vgg损失,三元组损失

    均方损失,交叉墒损失,vgg损失,三元组损失的应用场景有哪些 均方损失(Mean Squared Error, MSE),交叉熵损失(Cross-Entropy Loss),和三元组损失(Triplet Loss)是机器学习和深度学习中常用的损失函数,每个都适用于不同的应用场景: 1. 均方损失(MSE) 应用场景 :主要用于回

    2024年01月22日
    浏览(87)
  • 损失函数——交叉熵损失(Cross-entropy loss)

    交叉熵损失(Cross-entropy loss) 是深度学习中常用的一种损失函数,通常用于分类问题。它衡量了模型预测结果与实际结果之间的差距,是优化模型参数的关键指标之一。以下是交叉熵损失的详细介绍。 假设我们有一个分类问题,需要将输入数据x分为C个不同的类别。对于每个

    2024年02月02日
    浏览(37)
  • 交叉熵--损失函数

    目录 交叉熵(Cross Entropy) 【预备知识】 【信息量】 【信息熵】 【相对熵】 【交叉熵】 是Shannon信息论中一个重要概念, 主要用于度量两个概率分布间的差异性信息。 语言模型的性能通常用交叉熵和复杂度(perplexity)来衡量。交叉熵的意义是用该模型对文本识别的难度,

    2024年02月12日
    浏览(24)
  • DDPM交叉熵损失函数推导

    K L rm KL K L 散度 由于以下推导需要用到 K L rm KL K L 散度,这里先简单介绍一下。 K L rm KL K L 散度一般用于度量两个概率分布函数之间的“距离”,其定义如下: K L [ P ( X ) ∣ ∣ Q ( X ) ] = ∑ x ∈ X [ P ( x ) log ⁡ P ( x ) Q ( x ) ] = E x ∼ P ( x ) [ log ⁡ P ( x ) Q ( x ) ] KLbig[P(X)||Q(X)

    2024年02月10日
    浏览(29)
  • 深度学习——常见损失函数Loss:L1 ,L2 ,MSE ,Binary Cross ,Categorical Cross ,Charbonnier ,Weighted TV ,PSNR

    在深度学习中,损失函数是一个核心组件,它度量模型的预测结果与真实值之间的差异。通过最小化损失函数的值,模型能够在训练过程中逐渐改善其性能。损失函数为神经网络提供了一个明确的优化目标,是连接数据和模型性能的重要桥梁。 选择合适的损失函数是非常重要

    2024年01月24日
    浏览(45)
  • 交叉熵(Cross Entropy)损失函数

    交叉熵(Cross Entropy)损失函数是一种常用的损失函数,广泛应用于分类问题中,尤其是二分类问题和多分类问题。 假设有 N N N 个样本,每个样本有 C C C 个类别, y i ∈ { 0 , 1 } C y_i in {0,1}^C y i ​ ∈ { 0 , 1 } C 表示第 i i i 个样本的真实标签(one-hot编码), y i ^ ∈ [ 0 , 1 ]

    2024年02月09日
    浏览(31)
  • pytorch——损失函数之nn.L1Loss()和nn.SmoothL1Loss()

    今天讨论下:对称损失函数:symmetric regression function such as L1 or L2 norm,注意说说L1 1.1 数学定义 平均绝对误差(MAE)是一种用于回归模型的损失函数。MAE 是目标变量和预测变量之间绝对差值之和,因此它衡量的是一组预测值中的平均误差大小,而不考虑它们的方向,范围为

    2024年02月04日
    浏览(34)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包