PyTorch中的选择函数
gather()函数
import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12],
[13, 14, 15]]
"""
# 定义两个index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])
# axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
"""
# axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""
文章来源:https://www.toymoban.com/news/detail-758464.html
参考文献
1、理解pytorch几个高级选择函数(如gather)
2、图解PyTorch的torch.gather函数文章来源地址https://www.toymoban.com/news/detail-758464.html
到了这里,关于『PyTorch』张量和函数之gather()函数的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!