Pytorch 中 expand和repeat

这篇具有很好参考价值的文章主要介绍了Pytorch 中 expand和repeat。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

       在torch中,如果要改变某一个tensor的维度,可以利用view、expand、repeat、transpose和permute等方法,这里对这些方法的一些容易混淆的地方做个总结。

​expand和repeat函数是pytorch中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1 torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展,无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存是不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上。
参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1
expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。
 

import torch

a = torch.tensor([1, 0, 2])     # a -> torch.Size([3])
b1 = a.expand(2, -1)            # 第一个维度为升维,第二个维度保持原样
'''
b1为 -> torch.Size([3, 2])
tensor([[1, 0, 2],
        [1, 0, 2]])
'''

a = torch.tensor([[1], [0], [2]])   # a -> torch.Size([3, 1])
b2 = a.expand(-1, 2)                 # 保持第一个维度,第二个维度只有一个元素,可扩展
'''
b2 -> torch.Size([3, 2])
b2为  tensor([[1, 1],
             [0, 0],
             [2, 2]])
'''

a = torch.Tensor([[1, 2, 3]])   # a -> torch.Size([1, 3])
b3 = a.expand(4, 3)              # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,
                                # 可以在该维度上进行tensor的复制,若大于1则不行
'''
b3 -> torch.Size([4, 3])
tensor(
	[[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.]]
)
'''

a = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # a -> torch.Size([2, 3])
b4 = a.expand(4, 6)  # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''

b5 = a.expand(1, 2, 3)  # 可以在tensor的低维增加更多维度
'''
b5 -> torch.Size([1,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
b6 = a.expand(2, 2, 3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
b5 -> torch.Size([2,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''

b7 = a.expand(2, 3, 2)  # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''

b8 = a.expand(2, -1, -1)  # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
b8 -> torch.Size([2,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''

# expand返回的张量与原版张量具有相同内存地址
print(b8.storage())  # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

 可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变。

前文提及expand仅能作用于单数维,那对于非单数维的拓展,那就需要借助于repeat函数了。

tensor.repeat(*sizes)

参数*sizes指定了原始张量在各维度上复制的次数。整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,而更接近于tile函数的效果。

与expand不同,repeat函数会真正的复制数据并存放于内存中。repeat开辟了新的内存空间,torch.repeat返回的张量在内存中是连续的

import torch
a = torch.tensor([1, 0, 2])
b = a.repeat(3,2)  # 在轴0上复制3份,在轴1上复制2份
# b为 tensor([[1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2]])


import torch
a = torch.Tensor([[1,2,3]])
'''
tensor(
	[[1.,2.,3.]]
)
'''

aa = a.repeat(4, 3) # 维度不变,在各个维度上进行数据复制
'''
tensor(
	[[1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.]]
)
'''

a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor(
	[[1.,2.,3.],
	 [4.,5.,6.]]
)
'''
aa = a.repeat(4,6) # 维度不变,在各个维度上进行数据复制
'''
tensor(
	[[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.]]
)
'''

aaa = a.repeat(1,2,3) # 可以在tensor的低维增加更多维度,并在各维度上复制数据
'''
tensor(
	[[[1.,2.,3.,1.,2.,3.,1.,2.,3.],
	  [4.,5.,6.,4.,5.,6.,4.,5.,6.],
	  [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	  [4.,5.,6.,4.,5.,6.,4.,5.,6.]]]
)
'''
aaaa = a.repeat(2,3,1) # 可以在tensor的高维增加更多维度,并在各维度上复制数据
'''
tensor(
	[[[1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.]],
	 [[1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.]]]
)
'''

aaaaa = a.repeat(2, 3, -1) 
'''
RuntimeError: Trying to create tensor with negative dimension -3: [2,6,-3]
'''

print(aaaa.storage()) # 存储区的数据,说明repeat后的a,aa,aaa,aaaa是有各自独立的storage的
'''
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
'''

2.1 repeat_intertile

Pytorch中,与Numpyrepeat函数相类似的函数为torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None)

参数input为原始张量,repeats为指定轴上的复制次数,而dim为复制的操作轴,若取值为None则默认将所有元素进行复制,并会返回一个flatten之后一维张量。与repeat将整个原始张量作为整体不同,repeat_interleave操作是逐元素的。

a = torch.tensor([[1], [0], [2]])
b = torch.repeat_interleave(a, repeats=3)   # 结果flatten
# b为tensor([1, 1, 1, 0, 0, 0, 2, 2, 2])

c = torch.repeat_interleave(a, repeats=3, dim=1)  # 沿着axis=1逐元素复制
# c为tensor([[1, 1, 1],
#        [0, 0, 0],
#        [2, 2, 2]])

总结
相同:
(1)都可以扩展维度,或在某个维度上进行tensor的复制

区别:
(1)参数意义不同,repeat的参数表示沿某维度的数据复制倍数,可为大于0的任何整数值;expand的参数表示tensor对应的维度上的值,且只有增加新的低维度时表示沿该低维度的数据复制倍数,其他参数必须和原始tensor保持一致
(2)返回的结果的存储区不同,repeat返回的tensor会重新拥有一个独立存储区,而expand返回的tensor则与原始tensor共享存储区文章来源地址https://www.toymoban.com/news/detail-768721.html

到了这里,关于Pytorch 中 expand和repeat的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习笔记(kaggle课程《Intro to Deep Learning》)

    深度学习是一种机器学习方法,通过构建和训练深层神经网络来处理和理解数据。它模仿人脑神经系统的工作方式,通过多层次的神经网络结构来学习和提取数据的特征。深度学习在图像识别、语音识别、自然语言处理等领域取得了重大突破,并被广泛应用于人工智能技术中

    2024年02月13日
    浏览(47)
  • 解锁深度表格学习(Deep Tabular Learning)的关键:算术特征交互

    近日,阿里云人工智能平台PAI与浙江大学吴健、应豪超老师团队合作论文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在国际人工智能顶会AAAI-2024上发表。本项工作聚焦于深度表格学习中的一个核心问题:在处理结构化表格数据(tabular data)时,深度模型是否

    2024年04月17日
    浏览(34)
  • 残差网络(ResNet) -深度学习(Residual Networks (ResNet) – Deep Learning)

    在第一个基于cnn的架构(AlexNet)赢得ImageNet 2012比赛之后,每个随后的获胜架构都在深度神经网络中使用更多的层来降低错误率。这适用于较少的层数,但当我们增加层数时,深度学习中会出现一个常见的问题,称为消失/爆炸梯度。这会导致梯度变为0或太大。因此,当我们增加

    2024年02月15日
    浏览(41)
  • Deep Learning Tuning Playbook(深度学习调参手册中译版)

    由五名研究人员和工程师组成的团队发布了《Deep Learning Tuning Playbook》,来自他们自己训练神经网络的实验结果以及工程师的一些实践建议,目前在Github上已有1.5k星。原项目地址 本文为《Deep Learning Tuning Playbook》中文翻译版本,全程手打,非机翻。因为本人知识水平有限,翻

    2023年04月27日
    浏览(67)
  • 基于深度学习的语音识别(Deep Learning-based Speech Recognition)

    随着科技的快速发展,人工智能领域取得了巨大的进步。其中,深度学习算法以其强大的自学能力,逐渐应用于各个领域,并取得了显著的成果。在语音识别领域,基于深度学习的技术也已经成为了一种主流方法,极大地推动了语音识别技术的发展。本文将从深度学习算法的

    2024年02月04日
    浏览(50)
  • Pytorch 中 expand和repeat

           在torch中,如果要改变某一个tensor的维度,可以利用view、expand、repeat、transpose和permute等方法,这里对这些方法的一些容易混淆的地方做个总结。 ​expand和repeat函数是pytorch中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进

    2024年02月03日
    浏览(35)
  • 深度强化学习的变道策略:Harmonious Lane Changing via Deep Reinforcement Learning

    偏理论,假设情况不易发生 多智能体强化学习的换道策略,不同的智能体在每一轮学习后交换策略,达到零和博弈。 和谐驾驶仅依赖于单个车辆有限的感知结果来平衡整体和个体效率,奖励机制结合个人效率和整体效率的和谐。 自动驾驶不能过分要求速度性能, 考虑单个车

    2024年01月17日
    浏览(41)
  • 基于深度学习的目标检测的介绍(Introduction to object detection with deep learning)

    物体检测的应用已经深入到我们的日常生活中,包括安全、自动车辆系统等。对象检测模型输入视觉效果(图像或视频),并在每个相应对象周围输出带有标记的版本。这说起来容易做起来难,因为目标检测模型需要考虑复杂的算法和数据集,这些算法和数据集在我们说话的时

    2024年02月11日
    浏览(35)
  • 第二章:Learning Deep Features for Discriminative Localization ——学习用于判别定位的深度特征

            在这项工作中,我们重新审视了在[13]中提出的全局平均池化层,并阐明了它如何明确地使卷积神经网络(CNN)具有出色的定位能力,尽管它是在图像级别标签上进行训练的。虽然这个技术之前被提出作为一种训练规范化的手段, 但我们发现它实际上构建了一个通

    2024年02月15日
    浏览(34)
  • 基于深度学习的手写数字识别项目GUI(Deep Learning Project – Handwritten Digit Recognition using Python)

    一步一步教你建立手写数字识别项目,需要源文件的请可直接跳转下边的链接:All project 在本文中,我们将使用MNIST数据集实现一个手写数字识别应用程序。我们将使用一种特殊类型的深度神经网络,即卷积神经网络。最后,我们将构建一个GUI,您可以在其中绘制数字并立即

    2024年02月11日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包