😄 无聊整理下torch里的张量的各种乘法相关操作。
0、简单提一下广播法则的定义:
- 1、让所有输入张量都向其中shape最长的矩阵看齐,shape不足的部分在前面加1补齐。
- 2、两个张量的维度要么在某一个维度一致,若不一致其中一个维度为1也可广播。否则不能广播。【如两个维度:(4, 1, 4)和(2, 1)可以广播,因为他们不相等的维度其中一个为1就可以广播了。】
1、torch.mm()
- 只适合于二维张量的矩阵乘法。
- m x n, n x p -> m x p
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)
out = torch.mm(mat1, mat2)
out.shape
# torch.Size([2, 4])
2、torch.bmm()
- 只适合于三维张量的矩阵乘法,与torch.mm类似,但多了一个batch_size维度。
- b x m x n, b x n x p -> b x m x p
mat1 = torch.randn(8, 2, 3)
mat2 = torch.randn(8, 3, 4)
out = torch.bmm(mat1, mat2)
out.shape
# torch.Size([8, 2, 4])
3、torch.mul()和*
-
- ⭐ torch.mul()和*等价。
- 张量对应位置元素相乘。
- 将输入张量input的每个元素与另一个向量or标量other相乘,返回一个新的张量out,两者维度需满足广播规则
# 方式1:张量 和 标量相乘
input = torch.randn(3)
other = 100
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([3])
# 方式2:张量 和 张量(需满足广播规则)
input = torch.randn(4, 1, 4)
other = torch.randn(2, 1)
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([4, 2, 4])
# 方式3:元素对应项相乘
input = torch.randn(3, 2)
other = torch.randn(3, 2)
out = torch.mul(input, other)
# 等价 out = input*other
out.shape
# torch.Size([3, 2])
4、torch.dot()
向量点积:两向量对应位置相乘然后全部相加。只能支持两个一维向量。
5、torch.mv()
矩阵和向量的乘法
- 第一个参数只能是二维的,第二个参数是一维的,则在其维数末尾追加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(2,6)和(6)运算过程:(2,6)和(6,1) -> (2,1) -> (2)
mat1 = torch.randn(6,8)
mat2 = torch.randn(8)
out = torch.mv(mat1, mat2)
out.shape
# torch.Size([6])
6、@
torch中的@操作是可以实现前面的某几个函数,是一种强大的操作。
- 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
- 若mat1是二维向量,mat2是一维向量,那么对应操作就是torch.mv()
- 若mat1和mat2都是两个二维向量,那么对应操作就是torch.mm()
7、torch.matmul()
torch.matmul()与@操作类似,但是torch.matmul()不止局限于一维和二维,可以进行高维张量的乘法。两个张量的矩阵乘积。其行为取决于张量的维数如下:
-
1、如果两个张量都是一维的,则返回点积(标量)。
-
2、如果两个参数都是二维的,则返回矩阵-矩阵乘积。
-
3、如果第一个参数是二维的,第二个参数是一维的,则在其维数末尾追加一个1,以实现矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(2,6)和(6)运算过程:(2,6)和(6,1) -> (2,1) -> (2)
-
4、如果第一个参数是一维的(则在其维数前加一个1,),第二个参数是二维的,则返回矩阵乘法。在矩阵相乘之后,附加的维度被删除。如shape为:(6)和(6,2)运算过程:(1,6)和(6,2) -> (1,2) -> (2)
-
5、对3和4的总结。如果两个参数至少是一个参数是一维的,且至少一个参数是N维的(其中N > 2),则返回一个批处理矩阵乘法。如果第一个参数是一维的,则在其维数前加上1,以便批处理矩阵相乘,然后删除。如果第二个参数是一维的,则为批处理矩阵倍数的目的,将在其维上追加一个1,然后删除它。非矩阵(即批处理)维度是广播的(因此必须是可广播的)文章来源:https://www.toymoban.com/news/detail-440081.html
-
两个参数都是N维(>2),只有非矩阵的维度才是可以广播的,最后两维需满足矩阵乘法即
m x n, n x p -> m x p
。如bx1xnxm, kxmxp -> jxkxnxp
文章来源地址https://www.toymoban.com/news/detail-440081.html
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
到了这里,关于pytorch中的矩阵乘法操作:torch.mm(), torch.bmm(), torch.mul()和*, torch.dot(), torch.mv(), @, torch.matmul()的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!