torch.argmax(input, dim=None, keepdim=False)
argmax函数:返回指定维度最大值的索引,dim指定某一维度,那么这一维度就会消失,返回的所有维度会少这个dim指定的维度,根据这个返回的维度,确定对哪个维度采取argmax操作
例如输入是token_output的维度是(62,320,523):target_len:62【序列最大长度】, 320【batch-size】, 523【词表大小】
output_all_token_id = torch.argmax(token_output, -1).tolist()
这段话的意思就是在让最后一维消失(取每个批次生成概率最大的token),那么就变成(62,320)维度了,意思就是320条生成的文本
简单例子:
假如是二维矩阵:
dim=0意思就是“行”这一维度消失,只剩下列,也就是求每一列中最大值的索引
dim=1意思就是“列”这一维度消失,只剩下行,也就是求每一行中最大值的索引
import torch
a = torch.randn(2, 3)
print(a)
tensor([[-0.3018, 0.3350, 0.8318],
[ 0.2485, 0.5349, -1.2342]])
# 求所有值中最大值的索引
print(torch.argmax(a))
# dim=0意思就是“行”这一维度消失,只剩下列,也就是求每一列中最大值的索引
print(torch.argmax(a, dim=0))
# dim=1意思就是“列”这一维度消失,只剩下行,也就是求每一行中最大值的索引
print(torch.argmax(a, dim=1))
tensor(2)
tensor([2, 1])
tensor([1, 1, 0])文章来源:https://www.toymoban.com/news/detail-548980.html
torch.argmax函数说明_Egozjuer的博客-CSDN博客文章来源地址https://www.toymoban.com/news/detail-548980.html
到了这里,关于torch.argmax()函数【求最大值的索引,并让指定维度消失】的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!