RuntimeError: mat1 and mat2 shapes cannot be multiplied (5760x6 and 128x4)

这篇具有很好参考价值的文章主要介绍了RuntimeError: mat1 and mat2 shapes cannot be multiplied (5760x6 and 128x4)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

在使用pytorch框架定义子类网络结构时,有时可能会出现mat1和mat2的形状不匹配的这种问题。如下,定义了一个7层的cnn网络:

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32,64,2,1,1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.out = nn.Linear(128,4)
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        output = self.out(x)
        return output

此时就会出现如下的错误 :

mat1 and mat2 shapes cannot be multiplied,深度学习,python,pytorch

这种问题源于所定义的最后一层池化层输出的形状和全连接层输入的形状不一样。我们通过在前向传播函数中打印池化层的输出形状可知:

def forward(self,x):
    x = self.conv1(x)
    x = self.conv2(x)
    print(x.shape)
    output = self.out(x)
    return output


>> torch.Size([30, 32, 6, 6])  #池化层输出形状

([30, 32, 6, 6]) 其中的30是设置的batch_size,后三维才是其真正的形状,而全连接层的输入是一维特征,因此需要添加一个flatten层进行压平操作。压平后如下:

torch.Size([30, 1152])

鉴于pytorch框架的特点, 需要再添加一个全连接层来衔接压平层和最后一层全连接层,其输入形状为1152,输出为128。(即在以上代码conv3和out再封装一个层):

    def __init__(self):
        super(CNN,self).__init__()
      ...... 
         self.conv3 = nn.Sequential(
            nn.Conv2d(32,64,2,1,1),
            nn.ReLU(),    
            nn.MaxPool2d(2),
        )
        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1152,128),
            nn.Linear(128,4),
        )

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        output = self.out(x)
        return output

代码不再报错,训练网络成功。mat1 and mat2 shapes cannot be multiplied,深度学习,python,pytorch 

 文章来源地址https://www.toymoban.com/news/detail-598805.html

到了这里,关于RuntimeError: mat1 and mat2 shapes cannot be multiplied (5760x6 and 128x4)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • RuntimeError: mat1 dim 1 must match mat2 dim 0 解决方法

    RuntimeError: mat1 dim 1 must match mat2 dim 0 这个错误的大概意思是:矩阵mat1 的第二维度要与mat2的第一维度不匹配 在新增别的数据集进行训练时报当前错误,原因是输入的图像大小与之前不一样,这是新手在学习时常会遇到的问题。 先看报错信息,确定报错位置 我的这个代码是

    2024年02月15日
    浏览(73)
  • RuntimeError: Input type (unsigned char) and bias type (float) should be the same错误

    这个错误通常是由于输入数据类型与模型参数的类型不匹配导致的。在PyTorch中,当输入的张量类型与模型的参数类型不匹配时,PyTorch会尝试将它们转换为相同的类型,但是当它们的类型不可转换时,就会出现这个错误。 解决办法是确保输入的张量类型与模型的参数类型相同

    2024年02月15日
    浏览(38)
  • RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the

    问题描述: mobilenetv3在残差块中加入了注意力机制 用GPU 进行训练时报的错 解决方法1: 1,不用GPU 用CPU 就可以 CUDA 设置为False,确实可以解决,但是不用GPU 好像意义不大 解决方法2 : 用仍然用GPU ,看下面的的解决方案: 报错的原因:2 1,我直接在倒残差块的前向传播内对导

    2024年02月16日
    浏览(37)
  • Commit cannot be completed since the group has already rebalanced and assign

    这里是说提交commit失败, 因为这个组已经重新分配了 正常情况下, kafka会有一个配置用于设置一条消息的过期时间, 在规定时间内, 如果消费者提交了消费完成的信息, 那么就可以正常的分配下一条记录给消费者, 并且将当前记录的状态记为\\\"已消费\\\"状态, 对消息队列做一个标识

    2024年02月11日
    浏览(50)
  • ERROR: Could not build wheels for pycocotools which use PEP 517 and cannot be installed directl

    安装yolov5依赖库时,最后pycocotools报错 重点是以下原因: error: Microsoft Visual C++ 14.0 or greater is required. Get it with \\\"Microsoft C++ Build Tools\\\": https://visualstudio.microsoft.com/visual-cpp-build-tools/ 尝试的解决方法如下: 1、直接下载VS2022中,工作负荷里有关C++和Python的(未成功) 结果报错 vs2

    2024年02月02日
    浏览(51)
  • git fatal: ‘xxx‘ is not a commit and a branch ‘xxx‘ ‘ cannot be created from it

    当拉取一个git远程仓库分支时报错: 命令:git checkout -b 本地分支名 远程分支名 报错:fatal: \\\'origin/dev_v2.8.4_v10.74.1\\\' is not a commit and a branch \\\'dev_v2.8.4_v10.74.1\\\' cannot be created from it 远程新建的分支没有更新到本地。实际上,git仓库分为本地仓库和远程仓库,我们用 checkout 命令是从本

    2024年02月10日
    浏览(41)
  • Could not build wheels for opencv-python which use PEP 517 and cannot be installed directly

    当我们运行代码要运用到cv2库时,提示我们没有安装cv2,而直接用pip install opencv-python下载却显示下载失败: Could not build wheels for opencv-python which use PEP 517 and cannot be installed directly 直接运用conda安装: 随后完成cv2的安装。

    2024年02月10日
    浏览(38)
  • 【Git报错】fatal: ‘origin/XXX‘ is not a commit and a branch ‘XXX‘ cannot be created from it

    发现问题 远程已有分支,本地需要新建对应分支,于是执行命令: git checkout --track origin/XXX 时报错。 原因: 远程真的没有这个分支,所以失败 这个情况没什么好说的 远程有这个分支,但是本地认为远程没有这个分支 执行如下命令,查看本地缓存的所有远程分支,看看你要

    2024年02月16日
    浏览(49)
  • Could not build wheels for opencv-python-headless which use PEP 517 and cannot be installed directly

    笔者是python环境下安装 albumentations 出现的,该库经常用于图像增强,在cv领域有很大的知名度。在使用下边的命令进行安装后 就报了 ERROR:Could not build wheels for opencv-python-headless which use PEP 517 and cannot be installed directly 。 albumentations库依赖opencv,在直接使用pip命令安装时,albumen

    2024年02月15日
    浏览(64)
  • ERROR: Could not build wheels for opencv-python which use PEP 517 and cannot be installed directly

    pip install --upgrade -r requirements.txt -i https://mirror.baidu.com/pypi/simple Looking in indexes: https://mirror.baidu.com/pypi/simple Collecting prettytable Downloading https://mirror.baidu.com/pypi/packages/5f/ab/64371af206988d7b15c8112c9c277b8eb4618397c01471e52b902a17f59c/prettytable-3.3.0-py3-none-any.whl (26 kB) Collecting ujson Downloading https://

    2024年01月22日
    浏览(68)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包