1. bool矩阵当做索引(类型是:BoolTensor)
结果为一维向量(因为bool矩阵二维的,根据bool矩阵中True对应位置,把tensor数据中相应位置中的值取出来,组成一个新的一维tensor向量)
#布尔索引 用布尔索引总是会返回一份新创建的数据,原本的数据不会被改变。
a2 = np.arange(15).reshape(3,5)
print('a2===',a2)
mask = a2<5
b2 = a2[mask]
print('b2===',b2)
b2[0] = 17
print('a2===',a2) #修改b2中的数据,会发现原数据a2中的值没有发生改变。
输出结果:
a2=== tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
b2=== tensor([0, 1, 2, 3, 4])
a2=== tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
2. 一维bool向量当做索引(类型是:BoolTensor)
第一个例子:结果为一维向量(target是一维的)
target = torch.Tensor([1,0,0,2,0,0,3])
mask = (target > 0)
masked_target = target[mask]
print(target)
print(mask)
print(masked_target)
输出结果如下:
tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False, True, False, False, True])
tensor([1., 2., 3.])
第二个例子:结果为一个二维的向量(x是二维的)
x = torch.randn(4,3)
y = torch.tensor([True,False,False,True])
print('x===',x)
print('y===',y)
print('x[y]===',x[y])
print('y.sum()===',y.sum())
输出结果:
x===tensor([[ 0.4563, 0.3963, 0.4101],
[-0.4360, -0.2968, 1.0010],
[ 0.2851, 0.0890, 0.5452],
[ 0.8384, -1.1912, 0.2131]])
y===tensor([ True, False, False, True])
x[y]===tensor([[ 0.4563, 0.3963, 0.4101],
[ 0.8384, -1.1912, 0.2131]])
y.sum()===tensor(2)
3. torch.masked_select() 根据bool矩阵取出tensor中对应位置元素(类型是:BoolTensor)。
torch.masked_select(input, mask, *, out=None) → Tensor
此方法中mask是一个bool矩阵,在input中取出mask中True对应的值。
首先介绍参数:
- input(tensor):需要进行处理的tensor。
- mask(BoolTensor):包含了二进制掩码,要进行索引的tensor。
- out:输出的结果tensor(结果为一维tensor)。
使用方法如下:
第一个例子:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
[-1.2035, 1.2252, 0.5002, 0.6248],
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
[False, True, True, True],
[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
第二个例子:
target = torch.Tensor([1,0,0,2,0,0,3])
mask = target.ge(0)
masked_target = torch.masked_select(target, mask)
print(target)
print(mask)
print(masked_target)
输出结果为:
tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False, True, False, False, True])
tensor([1., 2., 3.])
使用方法其实挺明显的,就是把input与mask相对应起来,取出mask中True所对应位置的数据,组成一维的tensor。文章来源:https://www.toymoban.com/news/detail-409593.html
注意:mask和input的形状可以不相同,但是它们必须是可以广播的。并且返回tensor和原tensor使用不同的内存,相互独立。文章来源地址https://www.toymoban.com/news/detail-409593.html
4. 参考链接
- Python基础 切片索引、布尔索引、花式索引
- Pytorch中Tensor的索引,切片以及花式索引(fancy indexing)
- pytorch中bool类型的张量作为索引
- pytorch中几种tensor掩码的获取方法(含代码)
- pytorch每日一学35(torch.masked_select())根据bool矩阵取出tensor中对应位置元素
到了这里,关于pytorch 根据bool矩阵取出tensor中对应位置元素的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!