PyTorch中的符号索引和函数索引用法

这篇具有很好参考价值的文章主要介绍了PyTorch中的符号索引和函数索引用法。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Pytorch中很多函数都采用的是函数式索引的思路,而且使用函数式索引对代码可读性会有很大提升。

张量的符号索引

张量也是有序序列,我们可以根据每个元素在系统内的顺序位置,来找出特定的元素,也就是索引。

一维张量的索引

一维张量索引与Python中的索引一样是是从左到右,从0开始的,遵循格式为[start: end: step]

>>> data = torch.arange(1, 11)
>>> print(data)
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])
>>> data[0]
tensor(1)

张量索引出的结果是零维张量,而不是单独的数。要转化成数,需要使用item()方法:

>>> data[0].item()
1

批注:构成一维张量的是零维张量,而不是单独的数。

>>> data[3:9:2] # 隔2个数取一个,左闭右开
tensor([4, 6, 8])

在Python中,step可以为负数,例如:

>>> num_list = [1, 2, 3]
>>> num_list[::-1]
[3, 2, 1]

但在张量中,step必须大于0,否则就会报错。

>>> num_tensor = torch.arange(1, 11)
>>> num_tensor[::-1]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: step must be greater than zero

二维张量的索引

二维张量的索引逻辑和一维张量的索引逻辑相同,二维张量可以视为两个一维张量组合而成。

>>> data2 = torch.arange(1, 21).reshape(4, 5)
>>> data2
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20]])
>>> data2[0, 1], data2[0][1]  # 这两种索引方式都可以
(tensor(2), tensor(2))

但是data2[::2, ::2]data2[::2][ ::2]的索引结果就不同:

>>> data2[::2, ::2]
tensor([[ 1,  3,  5],
        [11, 13, 15]])
>>> data2[::2][::2]
tensor([[1, 2, 3, 4, 5]])

解释:

  • t2[::2, ::2]二维索引使用逗号隔开时,可以理解为全局索引,取第一行和第三行的第一列和第三列的元素。
  • t2[::2][::2]二维索引在两个中括号中时,可以理解为先取了第一行和第三行,构成一个新的二维张量,然后在此基础上又间隔2并对所有张量进行索引。
>>> d = data2[::2]
>>> d
tensor([[ 1,  2,  3,  4,  5],
        [11, 12, 13, 14, 15]])
>>> d[::2]
tensor([[1, 2, 3, 4, 5]])

三维张量的索引

设三维张量的shapexyz,则可理解为它是由x个二维张量构成,每个二维张量由y个一维张量构成,每个一维张量由z个元素构成。

>>> data3 = torch.arange(1, 25).reshape(2, 3, 4)
>>> data3
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
>>> data3[1, 1, 1]
tensor(18)
>>> data3[1, ::2, ::2]
tensor([[13, 15],
        [21, 23]])

高维张量的思路与低维一样,就是围绕张量的“形状”进行索引。

张量的函数索引

除了常⽤的索引选择数据之外,PyTorch还提供了⼀些⾼级的选择函数:

  • index_select(input, dim, index):在指定维度dim上选取,⽐如选取某些⾏、某些列
  • masked_select(input, mask):例⼦如上,a[a>0],使⽤ByteTensor进⾏选取
  • non_zero(input):⾮0元素的下标
  • gather(input, dim, index):根据index,在dim维度上选取数据,输出的sizeindex⼀样

index_select()

index_select(dim, index)表示在张量的哪个维度进行索引,索引的位置是多少。

torch.index_select()函数返回的是沿着输入张量的指定维度指定索引号进行索引的张量子集

torch.index_select(input, dim, index, out=None)

其函数参数有:

  • input(Tensor) - 需要进行索引操作的输入张量;
  • dim(int) - 需要对输入张量进行索引的维度;
  • index(LongTensor) - 包含索引号的 1D 张量;

index_select函数指定index来对张量进行索引,index的类型必须为Tensor

由于 index_select 函数只能针对输入张量的其中一个维度的一个或者多个索引号进行索引,因此可以通过 PyTorch 中的高级索引来实现。下面列举三个例子来说明这个函数的用法:

  1. 获取1D张量的第1个维度且索引号为2和3的张量子集
>>> data = torch.arange(9)
>>> sub_data1 = torch.index_select(data, dim=0, index=torch.tensor([2, 3]))
>>> sub_data2 = data[[2, 3]]
>>> data
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> sub_data1
tensor([2, 3])
>>> sub_data2
tensor([2, 3])
  1. 获取2D张量的第2个维度且索引号为0和1的张量子集(第一列和第二列)
>>> data2 = torch.arange(9).view(3, 3)
>>> sub_data1 = torch.index_select(data2, dim=1, index=torch.tensor([0, 1]))
>>> sub_data2 = data2[:, [0, 1]]
>>> data2
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
>>> sub_data1
tensor([[0, 1],
        [3, 4],
        [6, 7]])
>>> sub_data2
tensor([[0, 1],
        [3, 4],
        [6, 7]])
  1. 获取3D张量的第1个维度且索引号为0的张量子集
>>> data3 = torch.arange(18).view(2, 3, 3)
>>> sub_data1 = torch.index_select(data3, dim=0, index=torch.tensor([0]))
>>> sub_data2 = data3[0]
>>> print(data3)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]]])
>>> print(sub_data1)
tensor([[[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]])
>>> print(sub_data2)
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

index_select 函数虽然简单,但是有几点需要注意:

  • index 参数必须是1D长整型张量 (1D-LongTensor);
  • 使用 index_select 函数输出的张量维度和原始的输入张量维度相同。(即是说,原来是三维的张量,就会输出三维的张量)
>>> data_rand3 = torch.rand(3, 4)
>>> sub_data1 = torch.index_select(data_rand3, dim=0, index=torch.tensor([0]))
>>> sub_data2 = data_rand3[[0]]
>>> sub_data3 = data_rand3[0]
>>> print(sub_data1)
tensor([[0.1926, 0.6743, 0.9063, 0.0857]])
>>> print(sub_data2)
tensor([[0.1926, 0.6743, 0.9063, 0.0857]])
>>> print(sub_data3)
tensor([0.1926, 0.6743, 0.9063, 0.0857])
>>> print(sub_data1.size(), sub_data2.size(), sub_data3.size())
torch.Size([1, 4]) torch.Size([1, 4]) torch.Size([4])

上面的代码示例可以说明,三种方式索引出来的张量子集中的元素是一样的,不同的是索引出来的张量子集的形状。所以,前文才说 index_select 函数对输入张量进行索引可以使用高级索引实现。

masked_select()

masked_select()函数返回一个根据布尔掩码 (boolean mask) 索引输入张量的 1D 张量。其用法如下:

torch.masked_select(input, mask, out=None) → Tensor

具体地:

  • input (Tensor) : 输入张量
  • mask (ByteTensor) : 掩码张量,包含了二元索引值
  • out (Tensor, optional) : 目标张量
>>> data = torch.randn(5, 4)
>>> mask_index = data.ge(0)  # 筛选大于0的结果
>>> res = torch.masked_select(data, mask_index)
>>> print(data)
tensor([[ 2.5874, -1.5814, -0.6473, -0.1795],
        [-0.4612,  0.2462,  0.5025,  0.9862],
        [ 1.2485,  0.6655,  1.5536,  0.7446],
        [-1.2433,  1.8842, -0.6330, -0.8245],
        [-0.5634, -1.1724,  1.3369,  0.5930]])
>>> print(res)
tensor([2.5874, 0.2462, 0.5025, 0.9862, 1.2485, 0.6655, 1.5536, 0.7446, 1.8842,
        1.3369, 0.5930])

masked_select 函数最关键的参数就是布尔掩码 mask参数。其通过布尔张量maskTrue或者False来决定输入张量对应位置的元素是否保留,最后返回一维张量。

很明显,这种操作是一一对应的关系(True就保留,False就舍去),这就需要maskinput的形状相同。

  • 两者的形状可以完全相同,也即是input.shape = mask.shape
  • 广播机制,两者的形状可以不完全相同,但是必须要能够通过 PyTorch 中的广播机制广播成相同形状的张量。

广播机制 (Broadcast) 是在科学运算中经常使用的小技巧,它是一种轻量级的张量复制手段,只在逻辑层面扩展和复制张量,并不进行实际的存储复制操作,从而大大的减少了计算代价。但并不是所有形状不一致的张量都能进行广播,需要满足一定的规则。比如对于两个张量来说:

  • 如果两个张量的维度不同,则将维度小的张量进行扩展,直到两个张量的维度一样;
  • 如果两个张量在对应维度上的长度相同或者其中一个张量的长度为 1,那么就说这两个张量在该维度上是相容的;
  • 如果两个张量在所有维度上都是相容的,表示这两个张量能够进行广播,否则会出错;
  • 在任何一个维度上,如果一个张量的长度为 1,另一个张量的长度大于 1,那么在该维度上,就好像是对第一个张量进行了复制;

masked_select 函数中的广播机制比较简单,只需要保证输入张量不变,对布尔张量进行广播,而广播后的形状和输入张量的形状一致就可以了。

>>> data = torch.arange(8).view(2, 4)
>>> mask2 = torch.tensor([True, True, False, True])  # 能广播
>>> print(data)
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])
>>> print(torch.masked_select(data, mask2))
tensor([0, 1, 3, 4, 5, 7])

需要注意两点:

  • 使用 masked_select 函数返回的结果都是 1D 张量,张量中的元素就是被筛选出来的元素值;
  • 传入 input 参数中的输入张量和传入 mask 参数中的布尔张量形状可以不一致,但是布尔张量必须要能够通过广播机制扩展成和输入张量相同的形状;

问题来了,注意看下面的代码示例:

>>> data = torch.randn(3, 4)
>>> mask = data.ge(0)
>>> print(data[mask])
>>> print(torch.masked_select(data, mask))
tensor([[False,  True,  True, False],
        [ True, False,  True,  True],
        [ True, False,  True,  True]])
tensor([0.5393, 0.2735, 0.9606, 0.0107, 0.0654, 0.8304, 0.8467, 0.0034])
tensor([0.5393, 0.2735, 0.9606, 0.0107, 0.0654, 0.8304, 0.8467, 0.0034])

可以发现,masked_select函数其实没太大必要,直接通过data[mask]就可以达到效果了。那这个函数存在的意义在哪呢?这个问题留待后续…

non_zero()

non_zero()函数用于输出数组的非零值的索引,即用来定位数组中非零的元素。其用法如下:

torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors

参数为:

  • input:输入的数组
  • as_tuple:函数返回方式,默认为False

如果设为False,则返回一个二维张量,其中每一行都是非零值的索引,如果输入的数组有 n n n维,则输出的张量维度大小为 z × n z\times n z×n,其中 z z zinput非零元素的总数。

>>> a = torch.randn(3, 5)
>>> a = torch.where(x < 0, x, 0)  # 将非负元素置0
>>> print(a)
tensor([[ 0.0000, -1.0246, -0.2621,  0.0000],
        [ 0.0000, -0.7053, -0.8949, -0.3949],
        [ 0.0000, -0.1732, -0.4669,  0.0000],
        [ 0.0000, -1.0170, -1.1945,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.4522]])
>>> print(torch.nonzero(a))
tensor([[0, 1],
        [0, 2],
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [3, 1],
        [3, 2],
        [4, 3]])

这里nonzero函数输出的结果,就是非零元素的索引。比如第一行[0, 1]代表这源Tensor里面第0行第1列的元素非零。

如果as_tuple设为True,则返回一个由一维张量组成的元组。看下面的输出:

>>> a = torch.randn(3, 5)
>>> a = torch.where(x < 0, x, 0)  # 将非负元素置0
>>> print(a)
tensor([[ 0.0000, -1.0246, -0.2621,  0.0000],
        [ 0.0000, -0.7053, -0.8949, -0.3949],
        [ 0.0000, -0.1732, -0.4669,  0.0000],
        [ 0.0000, -1.0170, -1.1945,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.4522]])
>>> print(torch.nonzero(a, as_tuple=True))
(tensor([0, 0, 1, 1, 1, 2, 2, 3, 3, 4]), tensor([1, 2, 1, 2, 3, 1, 2, 1, 2, 3]))

如果输入数组为 n n n维,则有 n n n个一维张量,每个一维张量对应非零元素特定维度的索引(第一个张量数组储存的是所有非零元素第一维度的索引),并且每个张量里面有 z z z个数,其中 z z z为输入数组非零元素的个数。

这个函数除了找出非零元素外,还可用于特定元素定位,比如:

>>> a = torch.arange(12).view(3, 4)
>>> print(a)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> torch.nonzero(a == 6)  # 输入元素为6的位置
tensor([[1, 2]])

gather()

gather()函数作用:沿给定轴dim,将输入索引张量index指定位置的值进行聚合。(沿着给定的维度dim收集值)

其用法为:

torch.gather(input, dim, index, out=None)

参数解释为:

  • input(Tensor):源张量
  • dim(int):索引的轴
  • index(LongTensor):聚合元素的下标
  • out:目标张量

注意:index的维度要和inputdim所指的维度相同

torch.gather()常用索引多分类中标签所对应的概率。

例子1:按照dim = 0, 取一个二维张量对角线上的数值

>>> a = torch.tensor([[2, 3, 5], [4, 9, 7]])
>>> index = torch.LongTensor([[0, 1, 0]])
>>> b = torch.gather(a, dim=0, index=index)
>>> print(a)
tensor([[2, 3, 5],
        [4, 9, 7]])
>>> print(b)
tensor([[2, 9, 5]])

可以看到dim=0,即行方向的维度和index的维度是匹配的,就是说aindex由行方向从左往右看,有2列,即有2个样本,行方向是匹配的。另外,函数输出的tensorindex大小相同。

上面代码的操作逻辑可以用下图来表示。
PyTorch中的符号索引和函数索引用法
具体地:在a中,由行从左往右看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中从上往下索引分别为0和1;依据index中的索引值,取第0样本的第0个特征2,再取第1个样本的第1个特征7。

例子2:按照dim = 1, 取一个二维张量的对角线上的数值

>>> a = torch.tensor([[2, 3], [4, 9], [6, 10]])
>>> index = torch.LongTensor([[0], [1], [0]])
>>> b = torch.gather(a, dim=1, index=index)
>>> print(a)
tensor([[ 2,  3],
        [ 4,  9],
        [ 6, 10]])
>>> print(b)
tensor([[2],
        [9],
        [6]])

可以看到dim=1,即列方向的维度和index的维度是匹配的,就是说aindex由列方向从上往下看,有3行,即有3个样本,列方向是匹配的。另外,函数输出的tensorindex大小相同。

上面代码的操作逻辑可以用下图来表示。
PyTorch中的符号索引和函数索引用法

具体地:在a中,由列从上往下看,有三个样本,索引分别为0、1和0;每个样本有两个特征,每个特征中从左往右索引分别为0和1;依据index中的索引值,取第0样本的第0个特征2,取第1个样本的第1个特征7,再取第2个样本的第1个特征6。

例子3:模拟多分类问题中的概率取值(跟上面的两个例子一样)

>>> a = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
>>> index = torch.LongTensor([[0], [2]])
>>> a
tensor([[0.1000, 0.3000, 0.6000],
        [0.3000, 0.2000, 0.5000]])
>>> a.gather(dim=1, index=index)
tensor([[0.1000],
        [0.5000]])

总结:根据维度按行或者列根据索引取值文章来源地址https://www.toymoban.com/news/detail-427406.html

  • dim=0:在列上按索引取值
  • dim=1:在行上按索引取值

到了这里,关于PyTorch中的符号索引和函数索引用法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyTorch 各种池化层函数全览与用法演示

    目录 torch.nn.functional子模块Pooling层详解 avg_pool1d 用法与用途 参数 注意事项 示例代码 avg_pool2d 用法与用途 参数 注意事项 示例代码 avg_pool3d 用法与用途 参数 注意事项 示例代码 max_pool1d 用法与用途 参数 注意事项 示例代码 max_pool2d 用法与用途 参数 注意事项 示例代码 max_pool

    2024年02月01日
    浏览(70)
  • 【Pytorch学习】pytorch中的isinstance() 函数

    描述 isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。 isinstance() 与 type() 区别: type() 不会认为子类是一种父类类型,不考虑继承关系。 isinstance() 会认为子类是一种父类类型,考虑继承关系。 如果要判断两个类型是否相同推荐使用 isinstance()。 语法 以下是

    2024年02月15日
    浏览(33)
  • Pytorch中的repeat以及repeat_interleave用法

    repeat和repeat_interleave都是pytorch中用来复制的两种方法,但是二者略有不同,如下所示。 torch.tensor().repeat()里面假设里面有3个参数,即3,2,1,如下所示: 用repeat时,应当从后往前看,即先复制最后一维,依次向前。 ①最后一个数字为1,复制一次,还是[1,2,3]. ②倒数第二个数

    2023年04月12日
    浏览(39)
  • 详解Pytorch中的view函数

    一、函数简介 Pytorch中的view函数主要用于 Tensor维度的重构 ,即返回一个 有相同数据但不同维度的Tensor 。 根据上面的描述可知,view函数的操作对象应该是Tensor类型。如果不是Tensor类型,可以通过tensor = torch.tensor(data)来转换。 二、实例讲解 ▶view(参数a,参数b,…),其中,总的

    2024年02月16日
    浏览(40)
  • PyTorch中的torch.nn.Linear函数解析

    torch.nn是包含了构筑神经网络结构基本元素的包,在这个包中,可以找到任意的神经网络层。这些神经网络层都是nn.Module这个大类的子类。torch.nn.Linear就是神经网络中的线性层,可以实现形如y=Xweight^T+b的加和功能。 nn.Linear():用于设置网络中的全连接层,需要注意的是全连接

    2024年02月16日
    浏览(40)
  • Pytorch计算余弦相似度距离——torch.nn.CosineSimilarity函数中的dim参数使用方法

    前言 一、官方函数用法 二、实验验证 1.计算高维数组中各个像素位置的余弦距离 2.验证高维数组中任意一个像素位置的余弦距离 总结 现在要使用Pytorch中自带的 torch.nn. CosineSimilarity函数计算两个高维特征图(B,C,H,W)中各个像素位置的特征相似度,即特征图中的每个像素位置上

    2024年02月13日
    浏览(42)
  • 【PyTorch】PyTorch之Tensors索引切片篇

    介绍常用的PyTorch之Tensors索引切片等 torch.argwhere(input) → Tensor 返回一个张量,其中包含输入张量中所有非零元素的索引。结果中的每一行都包含输入中一个非零元素的索引。结果按字典序排序,最后一个索引变化最快(C风格)。 如果输入具有n维,则生成的索引张量out的大小

    2024年01月19日
    浏览(41)
  • 【Pytorch】进阶学习:深入解析 sklearn.metrics 中的 classification_report 函数---分类性能评估的利器

    【Pytorch】进阶学习:深入解析 sklearn.metrics 中的 classification_report 函数—分类性能评估的利器 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望得到您的订阅和支持~ 💡 创作高质量博文

    2024年03月11日
    浏览(43)
  • 【PyTorch教程】pytorch入门系列 ——土堆教程的目录及索引

    一、几句题外话 深度学习上手已经很长时间了,还记得最初的入门是跟着 B站up小土堆 的一步步学起来的,从起初的环境配置,到现在调整整个模型的进阶,非常感谢土堆的贡献。 写这个博客的初衷是为了自己 看着方便 ,由于多台电脑多个环境下查看这些内容很麻烦,所以

    2024年03月17日
    浏览(52)
  • Pytorch实用教程:Pytorch中torch.max的用法

    torch.max 在 PyTorch 中是一个非常有用的函数,它可以用于多种场景,包括寻找张量中的最大值、沿指定维度进行最大值操作,并且还可以返回最大值的索引。其用法可以根据你的需求进行不同的调用方式。 基本用法 找到整个张量的最大值 如果直接对一个张量使用 torch.max ,它

    2024年04月13日
    浏览(33)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包