torch.roll(input, shifts, dims=None)
这个函数是用来移位的,是顺移。input是咱们要移动的tensor向量,shifts是要移动到的位置,要移动去哪儿,dims是值在什么方向上(维度)去移动。比如2维的数据,那就两个方向,横着或者竖着。最关键的一句话,所有操作针对的是第一行或者第一列,下面举例子给大家做解释,自己慢慢体会
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print("")
print(x)
y = torch.roll(x, 1, 0)
print("")
print(y)
输出:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[7, 8, 9],
[1, 2, 3],
[4, 5, 6]])
torch.roll(x, 1, 0) 这行代码的意思就是把x的第一行(0维度)移到1这个位置上,其他位置的数据顺移。
x——咱们要移动的向量
1——第一行向量要移动到的最终位置
0——从行的角度去移动
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print("")
print(x)
y = torch.roll(x, -1, 1)
print("")
print(y)
输出:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[2, 3, 1],
[5, 6, 4],
[8, 9, 7]])
torch.roll(x, -1, 1) 意思就是把x的第一列(1维度)移到-1这个位置(最后一个位置)上,其他位置的数据顺移。
shifts和dims可以是元组,其实就是分步骤去移动
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print("")
print(x)
y = torch.roll(x, (0,1), (1,1))
print("")
print(y)
输出:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[3, 1, 2],
[6, 4, 5],
[9, 7, 8]])
torch.roll(x, (0,1), (1,1)) :
首先,针对元组第一个元素,把x的第一列(1维度)移到0这个位置(已经在0这个位置,因此原地不动)上,其他位置的数据顺移。(所有数据原地不动)文章来源:https://www.toymoban.com/news/detail-856848.html
然后,针对元组第二个元素,把a的第一列(1维度)移到1这个位置上,其他位置的数据顺移。文章来源地址https://www.toymoban.com/news/detail-856848.html
到了这里,关于pytorch中torch.roll用法说明的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!