前言
这个本来是没打算写的,因为看了官方的解释以及在网上看了好几个教程都没理解什么意思,所以把自己理解的东西整理分享一下。
官方的解释
官网链接:torch.gather()
给个截图如下
常用的参数有3个,第一个input
表示要从中选取元素,第二个dim
表示操作的维度,第三个index
表示选取元素的索引。
按照官方的解释我是没看懂的,后面去找教程也一知半解,所以自己琢磨了一下,终于悟了。
使用详解
结合着例子,直接看代码把:
import torch
a = torch.arange(3, 12).view(3, 3)
print(a)
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
# tensor([[9, 7, 5]])
# 1、将index中的各个元素的索引明确,获得具体坐标:
# index = torch.tensor([[2, 1, 0]])中,
# 2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)
# 2、将具体坐标中对应的维度替换成index中的值:
# 2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
# 1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
# 0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0
# 3、按照新的坐标取输入中的值:
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].
index = torch.tensor([[2, 1, 0]])
c = torch.gather(a, dim=1, index=index)
print(c) # tensor([[5, 4, 3]])
# 1、获取具体坐标:(0,0),(0,1),(0,2)
# 2、第1维度替换坐标:(0,2),(0,1),(0,0)
# 3、找元素:[5,4,3]
# 二维的情况也一样
index = torch.tensor([[0, 2],
[1, 2]])
d = torch.gather(a, dim=1, index=index)
print(d)
# tensor([[3, 5],
# [7, 8]])
# 1、获取具体坐标:(0,0),(0,1),(1,0),(1,1)
# 2、第1维度替换坐标:(0,0),(0,2),(1,1),(1,2)
# 3、找元素:[[3, 5],[7, 8]]
怕在代码里面太暗了看不清楚,在这里再贴一次:
以第一个为例:文章来源:https://www.toymoban.com/news/detail-483805.html
创建张量
a = torch.arange文章来源地址https://www.toymoban.com/news/detail-483805.html
到了这里,关于torch.gather()使用解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!