Pytorch函数——torch.gather详解

这篇具有很好参考价值的文章主要介绍了Pytorch函数——torch.gather详解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

在学习强化学习时,顺便复习复习pytorch的基本内容,遇到了 torch.gather()函数,参考图解PyTorch中的torch.gather函数 - 知乎 (zhihu.com)进行解释。

pytorch官网对函数给出的解释:

Pytorch函数——torch.gather详解,深度学习,pytorch,人工智能,python

即input是一个矩阵,根据dim的值,将index的值替换到不同的维度的索引,当dim为0时,index替代i的值,成为第0维度的索引。

输入和输出的矩阵形式相同。

例子:首先我们生成3×3的矩阵,明确行索引的概念,第0行指的是[3,4,5],第0列指的是[[3] [6] [9]]

import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

index为行向量且dim=0时

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[9, 7, 5]])

当dim=0时,替换第0维度。由于input为二维列表,因此第0维度指的是选择第几行的维度,即行索引所在的维度,替换了i的索引,为input[index[i][j]] [j]

那么我们会输出tensor([[ input[2][j] input[1][j] input[0][j] ]]),那么j如何获得呢?从index of index中拿到,index每一个元素的索引为(0,0) (0,1) (0,2),取j,则为0,1,2,那么输出则为tensor([[ input[2][0] input[1][1] input[0][2] ]]),即

tensor([[9, 7, 5]])

输入行向量index,且dim=1

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5, 4, 3]])

维度为1,则替换列索引的值,那么输出为tensor([[ input[i][2] input[i][1] input[i][0] ]]),index每一个元素的索引为(0,0) (0,1) (0,2),i均为1,那么tensor([[ input[0][2] input[0][1] input[0][0] ]])

输入为行向量,dim=0

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[5],
        [7],
        [9]])

维度为0,则替换行索引,且输出与输入的格式相同,为

tensor([input[2][j],
        input[1][j],
        input[0][j]])

index每一个元素的索引为(0,0) (1,0) (2,0),j对应的值为0,0,0,则

tensor([input[2][0],
        input[1][0],
        input[0][0]])

输入为行向量,dim=1

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[5],
        [7],
        [9]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([input[i][2],
        input[i][1],
        input[i][0]])

index每一个元素的索引为(0,0) (1,0) (2,0),i对应的值为0,1,2,则

tensor([input[0][2],
        input[1][1],
        input[2][0]])

输入为二维矩阵,且dim=1

index = torch.tensor([[0, 2],
                      [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
tensor([[3, 5],
        [7, 8]])

维度为1,则替换列索引,且输出与输入的格式相同,为

tensor([[input[i][0], input[i][2]],
        [input[i][1], input[i][2]]])

替换为行索引后,可得:

tensor([[input[0][0], input[0][2]],
        [input[1][1], input[1][2]]])

在强化学习中的应用

在PyTorch官网DQN页面的代码中,i是state,j是a

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

我们使用dim=1action_batch 将获得的动作列表替换为列索引,即可获得每个state下该动作的动作价值。文章来源地址https://www.toymoban.com/news/detail-799848.html

到了这里,关于Pytorch函数——torch.gather详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【Torch API】pytorch 中bincount()函数详解

    torch.bincount 是 PyTorch 中的函数,用于计算给定整数张量中每个值的出现次数。它返回一个张量,其中的每个元素表示输入张量中对应索引值出现的次数。 具体而言, torch.bincount 函数的语法如下: 其中: input 是输入的整数张量,可以是一维或多维的。 weights 是可选的权重张量

    2024年02月11日
    浏览(87)
  • 【Torch API】pytorch 中repeat_interleave函数详解

    torch. repeat_interleave ( input ,  repeats ,  dim=None ) → Tensor Repeat elements of a tensor. Parameters input  (Tensor) – the input tensor. repeats  (Tensor  or  int) – The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis. dim  (int ,  optional ) – The dimension along which to repeat values

    2024年02月11日
    浏览(42)
  • 【深度学习】torch.utils.data.DataLoader相关用法 | dataloader数据加载器 | pytorch

    dataloader数据加载器属于是深度学习里面非常基础的一个概念了,基本所有的图像项目都会用上,这篇博客就把它的相关用法总结一下。 之所以要写这篇,是想分清楚len(data_loader)和len(data_loader.dataset) 这里加载的数据我们以Mnist手写数据集为例子 torchvision.datasets.MNIST是用来加载

    2024年02月16日
    浏览(54)
  • 深度学习:Pytorch安装的torch与torchvision的cuda版本冲突问题与解决历程记录

    今天不小心将conda环境中的一个pytorch环境中的torch包给搞混了,将其更新了一下,发生了一些问题: 当时运行了一下这个代码:  pip install torchvision --upgrade 导致了环境中包的混乱: 只能说欲哭无泪,当时这个 pytorch环境中我是安装的CUDA11.8的版本应该,后来安装了cpu版本的将

    2024年02月20日
    浏览(44)
  • 人工智能学习07--pytorch14--ResNet网络/BN/迁移学习详解+pytorch搭建

    亮点:网络结构特别深 (突变点是因为学习率除0.1?) 梯度消失 :假设每一层的误差梯度是一个小于1的数,则在反向传播过程中,每向前传播一层,都要乘以一个小于1的误差梯度。当网络越来越深的时候,相乘的这些小于1的系数越多,就越趋近于0,这样梯度就会越来越小

    2023年04月11日
    浏览(132)
  • Pytorch学习笔记(5):torch.nn---网络层介绍(卷积层、池化层、线性层、激活函数层)

     一、卷积层—Convolution Layers  1.1 1d / 2d / 3d卷积 1.2 卷积—nn.Conv2d() nn.Conv2d 1.3 转置卷积—nn.ConvTranspose nn.ConvTranspose2d  二、池化层—Pooling Layer (1)nn.MaxPool2d (2)nn.AvgPool2d (3)nn.MaxUnpool2d  三、线性层—Linear Layer  nn.Linear  四、激活函数层—Activate Layer (1)nn.Sigmoid  (

    2024年01月20日
    浏览(41)
  • 人工智能学习07--pytorch15(前接pytorch10)--目标检测:FPN结构详解

    backbone:骨干网络,例如cnn的一系列。(特征提取) (a)特征图像金字塔 检测不同尺寸目标。 首先将图片缩放到不同尺度,针对每个尺度图片都一次通过算法进行预测。 但是这样一来,生成多少个尺度就要预测多少次,训练效率很低。 (b)单一特征图 faster rcnn所采用的一种方式

    2023年04月12日
    浏览(65)
  • 【深度学习框架-torch】torch.norm函数详解用法

    torch版本 1.6 dim是matrix norm 如果 input 是 matrix norm ,也就是维度大于等于2维,则 P值默认为 fro , Frobenius norm 可认为是与计算向量的欧氏距离类似 有时候为了比较真实的矩阵和估计的矩阵值之间的误差 或者说比较真实矩阵和估计矩阵之间的相似性,我们可以采用 Frobenius 范数。

    2024年02月10日
    浏览(47)
  • 人工智能概论报告-基于PyTorch的深度学习手写数字识别模型研究与实践

    本文是我人工智能概论的课程大作业实践应用报告,可供各位同学参考,内容写的及其水,部分也借助了gpt自动生成,排版等也基本做好,大家可以参照。如果有需要word版的可以私信我,或者在评论区留下邮箱,我会逐个发给。word版是我最后提交的,已经调整统一了全文格

    2024年02月05日
    浏览(69)
  • 人工智能深度学习100种网络模型,精心整理,全网最全,PyTorch框架逐一搭建

    大家好,我是微学AI,今天给大家介绍一下人工智能深度学习100种网络模型,这些模型可以用PyTorch深度学习框架搭建。模型按照个人学习顺序进行排序: 深度学习模型 ANN (Artificial Neural Network) - 人工神经网络:基本的神经网络结构,包括输入层、隐藏层和输出层。 学习点击地

    2024年02月14日
    浏览(50)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包