pytorch里torch.gather()和torch.Tensor.scatter()解析

这篇具有很好参考价值的文章主要介绍了pytorch里torch.gather()和torch.Tensor.scatter()解析。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

torch.Tensor.scatter() 类似 gather 的反向操作(gather是读出数据,scatter是写入数据),所以这里只解析torch.gather()。
gather()这个操作在功能上较为反人类,即使某段时间理解透彻了,过了几个月不碰可能又会变得生疏。官方文档对其描述也是较为简单,有些小伙伴看完可能还是不完全理解,本文从根本上去解析这个操作的功能。
概括地说,gather()是index_select()的延伸操作,比index_select()更加灵活,它的操作不属于块操作,而是元素级别的操作,所以性能上应该较低,我们应该尽可能地避免使用这个操作。

下面开始解析这个操作。

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

这个功能的设计目的是“Gathers values along an axis specified by dim.”,这是官方文档的所有描述,看到这句话做目标检测的小伙伴应该能想到这样一个场景:

目标检测网络输出矩阵,前4列是box的坐标,第5列表示检测到目标的种类标签
print(pred)
tensor([[0.0080, 0.6403, 0.9865, 0.0158, 1.0000],
        [0.2742, 0.7470, 0.3837, 0.6689, 3.0000],
        [0.3260, 0.6683, 0.1888, 0.9525, 0.0000],
        [0.7989, 0.9154, 0.1040, 0.5538, 3.0000],
        [0.6746, 0.6193, 0.0161, 0.5166, 0.0000]])

现在我们要挑选出标签是3的所有检测目标框,

i = pred[:, 4].eq(3).nonzero().repeat(1, 4)
torch.gather(pred, 0, i)
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
        [0.7989, 0.9154, 0.1040, 0.5538]])

gather()可以实现实现这种整行地抽取数据,但不是最优的实现方法,我们有更合适的实现方法,index_select()和下标索引:

i = pred[:, 4].eq(3).nonzero().squeeze()
pred.index_select(0, i)[:, :4]
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
        [0.7989, 0.9154, 0.1040, 0.5538]])
        
# 下标索引方法        
pred[i, :4]
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
        [0.7989, 0.9154, 0.1040, 0.5538]])

现在我们要进行更加复杂的数据抽取,输出张量的要求如下:

  • shape是2*2
  • 第0行的第0列对应原始数据pred的第0行第0列,第1列对应pred的第1行第1列,在图中红色元素
  • 第1行的第0列对应原始数据pred的第3行第0列,第1列对应pred的第2行第1列,在图中蓝色元素
    pytorch里torch.gather()和torch.Tensor.scatter()解析,pytorch笔记,机器学习,pytorch,人工智能,python

这时候index_selsect()无法实现,但gather()可以

index = torch.tensor([[0, 1],
				  	  [3, 2]])
torch.gather(pred, 0, index)
tensor([[0.0080, 0.7470],
        [0.7989, 0.6683]])

这个操作的规则如下:

  • 输出张量的shape和索引张量(index)相同

  • 除了dim指示的那个维度,其他所有的维度满足条件: index.size(d) <= input.size(d)

  • index和输入张量input的每个维度一一对应

  • 除了dim指示的那个维度,其他维度的input和output元素位置对应,当index.size(d) < input.size(d)时候,从最前面截取
    pytorch里torch.gather()和torch.Tensor.scatter()解析,pytorch笔记,机器学习,pytorch,人工智能,python
    pytorch里torch.gather()和torch.Tensor.scatter()解析,pytorch笔记,机器学习,pytorch,人工智能,python

  • dim指示的那个维度上数据根据index里具体元素指示的位置去定位

看起来还是不好理解的,好在这个函数的应用场景不多,到目前为止我还没遇到适合这个函数的应用场景,如果哪位小伙伴遇到了请评论区留言感激不尽。文章来源地址https://www.toymoban.com/news/detail-560037.html

到了这里,关于pytorch里torch.gather()和torch.Tensor.scatter()解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • linux配置miniconda、pytorch、torch_scatter以及cuda. - 叶辰

    在西方的天际,正在云海中下沉的夕阳仿佛被溶化着,太阳的血在云海和太空中弥漫开来,映现出一大片壮丽的血红。“这是人类的落日。” 进入官网miniconda 正常选择最新版Miniconda3 Linux 64-bit, jetson选择Miniconda3 Linux-aarch64 64-bit。 点击下载或者右键复制下载链接,使用命令下

    2024年02月05日
    浏览(54)
  • 【深度学习笔记】彻底理解torch中的tensor与numpy中array区别及用法

    刚接触深度学习的同学,很多开源项目代码中, 张量tensor 与 数组array 都有使用,不清楚两者有什么区别,以及怎么使用,如何相互转换等。博主起初也有类似的疑惑,经过查阅资料以及实践,逐渐有了深入了解,本文将记录并分享自己对两者的理解,可供参考。 提示:以下

    2023年04月08日
    浏览(97)
  • Pytorch数据类型转换(torch.tensor,torch.FloatTensor)

    之前遇到转为tensor转化为浮点型的问题,今天整理下,我只讲几个我常用的,如果有更好的方法,欢迎补充 1.首先讲下torch.tensor,默认整型数据类型为torch.int64,浮点型为torch.float32 2.这是我认为平常最爱用的转数据类型的方法,可以用dtype去定义数据类型 1.这个函数不要乱用

    2024年02月11日
    浏览(51)
  • 深入浅出Pytorch函数——torch.tensor

    分类目录:《深入浅出Pytorch函数》总目录 相关文章: · 深入浅出TensorFlow2函数——tf.constant · 深入浅出Pytorch函数——torch.tensor · 深入浅出Pytorch函数——torch.as_tensor · 深入浅出Pytorch函数——torch.Tensor · 深入浅出PaddlePaddle函数——paddle.to_tensor 基于 data 构建一个没有梯度历史

    2024年02月04日
    浏览(110)
  • 深入浅出Pytorch函数——torch.Tensor.backward

    分类目录:《深入浅出Pytorch函数》总目录 相关文章: · 深入浅出Pytorch函数——torch.Tensor 计算当前张量相对于图的梯度,该函数使用链式法则对图进行微分。如果张量不是一个标量(即其数据具有多个元素)并且需要梯度,则函数还需要指定梯度,指定的梯度应该是一个与

    2024年02月15日
    浏览(58)
  • 【Pytorch基础教程39】torch常用tensor处理函数

    torch.tensor 会复制data,不想复制可以使用 torch.Tensor.detach() 。 如果是获得numpy数组数据,可以使用 torch.from_numpy() ,共享内存 torch.mm : 用于两个矩阵(不包括向量)的乘法。如维度为(l,m)和(m,n)相乘 torch.bmm : 用于带batch的三维向量的乘法。如维度为(b,l,m)和(b,m,n)相乘 torch.mul : 用于

    2024年02月13日
    浏览(76)
  • RDMA Scatter Gather List详解

    1. 前言 在使用RDMA操作之前,我们需要了解一些RDMA API中的一些需要的值。其中在ibv_send_wr我们需要一个sg_list的数组,sg_list是用来存放ibv_sge元素,那么什么是SGL以及什么是sge呢?对于一个使用RDMA进行开发的程序员来说,我们需要了解这一系列细节。 2. SGE简介 在NVMe over PCIe中

    2024年01月20日
    浏览(49)
  • 深入浅出Pytorch函数——torch.as_tensor

    分类目录:《深入浅出Pytorch函数》总目录 相关文章: · 深入浅出TensorFlow2函数——tf.constant · 深入浅出Pytorch函数——torch.tensor · 深入浅出Pytorch函数——torch.as_tensor · 深入浅出Pytorch函数——torch.Tensor · 深入浅出PaddlePaddle函数——paddle.to_tensor 将数据转换为张量,共享数据并

    2024年02月05日
    浏览(53)
  • Pytorch:TypeError: pic should be PIL Image or ndarray. Got <class ‘torch.Tensor‘>

    关键代码 原因 在于 x 本就是 Tensor 类型的,有写了一次ToTensor()转换类型,因此会报错。 解决办法 删除 transforms.ToTensor() 或者 修改x 类型为其他类型

    2024年02月15日
    浏览(56)
  • Pytorch学习笔记(5):torch.nn---网络层介绍(卷积层、池化层、线性层、激活函数层)

     一、卷积层—Convolution Layers  1.1 1d / 2d / 3d卷积 1.2 卷积—nn.Conv2d() nn.Conv2d 1.3 转置卷积—nn.ConvTranspose nn.ConvTranspose2d  二、池化层—Pooling Layer (1)nn.MaxPool2d (2)nn.AvgPool2d (3)nn.MaxUnpool2d  三、线性层—Linear Layer  nn.Linear  四、激活函数层—Activate Layer (1)nn.Sigmoid  (

    2024年01月20日
    浏览(44)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包