torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
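For reference, a minimal sketch of the typical call pattern: raw logits plus integer class indices. The weight and label_smoothing values here are illustrative choices, not from the original post, and label_smoothing requires PyTorch >= 1.10.

import torch
import torch.nn as nn

# Illustrative parameter values (not from the original post).
loss_func = nn.CrossEntropyLoss(
    weight=torch.tensor([1.0, 2.0, 1.0]),  # optional per-class weights, one entry per class
    ignore_index=-100,                     # targets equal to this value are skipped
    label_smoothing=0.1,                   # mix the one-hot target with a uniform distribution
)

logits = torch.randn(4, 3)                 # (batch_size=4, num_classes=3), raw scores, no softmax applied
targets = torch.tensor([0, 2, 1, 2])       # class indices, shape (batch_size,)

print(loss_func(logits, targets))          # scalar loss, averaged over the batch (reduction='mean')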
Computation process
nn.CrossEntropyLoss() is equivalent to composing nn.LogSoftmax() with nn.NLLLoss(); the example below reproduces this step by step, and a direct comparison of the two modules follows after it.
import torch
import torch.nn as nn

loss_func = nn.CrossEntropyLoss()
# Raw logits for four classes and a one-hot (probability) target;
# float probability targets are supported by CrossEntropyLoss since PyTorch 1.10.
pre = torch.tensor([0.8, 0.5, 0.2, 0.5], dtype=torch.float)
tgt = torch.tensor([1, 0, 0, 0], dtype=torch.float)

print("Manual computation:")
print("1. softmax")
print(torch.softmax(pre, dim=-1))
print("2. take the logarithm")
print(torch.log(torch.softmax(pre, dim=-1)))
print("3. multiply by the target, sum, and negate")
print(-torch.sum(torch.mul(torch.log(torch.softmax(pre, dim=-1)), tgt), dim=-1))
print()
print("Calling the loss function:")
print(loss_func(pre, tgt))
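As a complementary check of the identity nn.CrossEntropyLoss() = nn.LogSoftmax() + nn.NLLLoss(), the sketch below feeds the same logits through the two modules explicitly, this time with an integer class index instead of the one-hot target; the added batch dimension and the index 0 (which corresponds to the one-hot target [1, 0, 0, 0] above) are assumptions made for the comparison.

import torch
import torch.nn as nn

logits = torch.tensor([[0.8, 0.5, 0.2, 0.5]])   # shape (1, 4): one sample, four classes
target = torch.tensor([0])                       # class index, shape (1,)

log_softmax = nn.LogSoftmax(dim=-1)
nll_loss = nn.NLLLoss()
cross_entropy = nn.CrossEntropyLoss()

# Both lines should print the same value.
print(nll_loss(log_softmax(logits), target))
print(cross_entropy(logits, target))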