注意力机制——Spatial Transformer Networks(STN)

这篇具有很好参考价值的文章主要介绍了注意力机制——Spatial Transformer Networks(STN)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形、旋转等几何变换的鲁棒性。STN 可以在端到端的训练过程中自适应地学习变换参数,无需人为设置变换方式和参数。

STN 的基本结构包括三个部分:定位网络(Localization Network)、网格生成器(Grid Generator)和采样器(Sampler)。定位网络通常由卷积层、全连接层和激活函数构成,用于学习输入数据的空间变换参数。网格生成器用于生成采样网格,采样器则根据采样网格对输入数据进行采样。整个 STN 模块可以插入到任意位置,用于提高网络的对图像变形、旋转等几何变换的鲁棒性。

在 STN 中,定位网络通常由一个多层感知器(MLP)和一些辅助层(如卷积层、全连接层和激活函数)构成。MLP 的输出用于计算变换参数(如平移、旋转和缩放等),从而生成采样网格。采样器通常由双线性插值、最近邻插值和反卷积等方法实现,用于对输入数据进行采样。

STN 的优点在于,它可以学习对输入数据进行任意复杂的空间变换,从而提高网络的对图像变形、旋转等几何变换的鲁棒性。此外,STN 可以与其他深度学习模型结合使用,从而提高整个系统的性能。例如,在图像分类任务中,可以将 STN 插入到卷积神经网络中,用于对输入图像进行空间变换,增强网络对图像变形、旋转等几何变换的鲁棒性。

STN注意力模块pytorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class STN(nn.Module):
    def __init__(self):
        super(STN, self).__init__()
        # 定义本地化网络,用于估计空间变换的参数
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7), # 输入通道数为 1,输出通道数为 8,卷积核大小为 7
            nn.MaxPool2d(2, stride=2), # 最大池化层,核大小为 2,步长为 2
            nn.ReLU(True), # ReLU 激活函数
            nn.Conv2d(8, 10, kernel_size=5), # 输入通道数为 8,输出通道数为 10,卷积核大小为 5
            nn.MaxPool2d(2, stride=2), # 最大池化层,核大小为 2,步长为 2
            nn.ReLU(True) # ReLU 激活函数
        )
        # 定义空间变换网络,用于预测空间变换的参数
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32), # 全连接层,输入维度为 10 * 3 * 3,输出维度为 32
            nn.ReLU(True), # ReLU 激活函数
            nn.Linear(32, 3 * 2) # 全连接层,输入维度为 32,输出维度为 3 * 2
        )
        # 初始化空间变换网络的权重和偏置
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        # 使用本地化网络对输入图像进行特征提取
        xs = self.localization(x)
        # 将特征张量展开成一维张量
        xs = xs.view(-1, 10 * 3 * 3)
        # 使用空间变换网络预测空间变换的参数
        theta = self.fc_loc(xs)
        # 将一维张量转换成二维张量,用于执行仿射变换
        theta = theta.view(-1, 2, 3)
        # 使用仿射变换对输入图像进行空间变换
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

以上代码中,STN 类继承自 PyTorch 的 nn.Module 类,是一个包含了本地化网络和空间变换网络的模块。具体来说,STN 模块包含以下组件:

  • self.localization:本地化网络,用于对输入图像进行特征提取,提取出用于估计空间变换参数的特征向量。
  • self.fc_loc:空间变换网络,用于根据本地化网络提取的特征向量预测空间变换的参数。
  • self.fc_loc[2].weight.data.zero_() 和 self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)):用于初始化空间变换网络的权重和偏置,其中权重矩阵初始化为零矩阵,偏置向量初始化为一个 torch.tensor 对象,其元素为 [1,0,0,0,1,0][1,0,0,0,1,0],表示初始的空间变换为一个单位矩阵。
  • forward 方法:模块的前向传播过程。首先使用本地化网络对输入图像进行特征提取,然后将特征张量展开成一维张量,使用空间变换网络预测空间变换的参数。接着将一维张量转换成二维张量,用于执行仿射变换,并使用仿射变换对输入图像进行空间变换,最后返回变换后的图像张量。

STN模块在模型中添加:文章来源地址https://www.toymoban.com/news/detail-726295.html

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.stn = STN()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 使用 STN 对输入图像进行空间变换
        x = self.stn(x)
        # 经过卷积和池化层处理
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

到了这里,关于注意力机制——Spatial Transformer Networks(STN)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Transformer中的注意力机制及代码

    最近在学习transformer,首先学习了多头注意力机制,这里积累一下自己最近的学习内容。本文有大量参考内容,包括但不限于: ① 注意力,多注意力,自注意力及Pytorch实现 ② Attention 机制超详细讲解(附代码) ③ Transformer 鲁老师机器学习笔记 ④ transformer中: self-attention部分是否需

    2023年04月11日
    浏览(35)
  • 图解transformer中的自注意力机制

    本文将将介绍注意力的概念从何而来,它是如何工作的以及它的简单的实现。 在整个注意力过程中,模型会学习了三个权重:查询、键和值。查询、键和值的思想来源于信息检索系统。所以我们先理解数据库查询的思想。 假设有一个数据库,里面有所有一些作家和他们的书籍

    2024年02月09日
    浏览(35)
  • 大模型基础之注意力机制和Transformer

    核心思想:在decoder的每一步,把encoder端所有的向量提供给decoder,这样decoder根据当前自身状态,来自动选择需要使用的向量和信息. decoder在每次生成时可以关注到encoder端所有位置的信息。 通过注意力地图可以发现decoder所关注的点。 注意力使网络可以对齐语义相关的词汇。

    2024年02月11日
    浏览(29)
  • 【】理解ChatGPT之注意力机制和Transformer入门

    作者:黑夜路人 时间:2023年4月27日 想要连贯学习本内容请阅读之前文章: 【原创】理解ChatGPT之GPT工作原理 【原创】理解ChatGPT之机器学习入门 【原创】AIGC之 ChatGPT 高级使用技巧 GPT是什么意思 GPT 的全称是 Generative Pre-trained Transformer(生成型预训练变换模型),它是基于大

    2024年02月16日
    浏览(32)
  • 【计算机视觉 | 注意力机制】13种即插即用涨点模块分享!含注意力机制、卷积变体、Transformer变体等

    用即插即用的模块“缝合”,加入自己的想法快速搭积木炼丹。 这种方法可以简化模型设计,减少冗余工作,帮助我们快速搭建模型结构,不需要从零开始实现所有组件。除此以外,这些即插即用的模块都具有标准接口,意味着我们可以很方便地替换不同的模块进行比较,加

    2024年02月04日
    浏览(37)
  • 【Transformer】自注意力机制Self-Attention

    \\\"Transformer\\\"是一种深度学习模型,首次在\\\"Attention is All You Need\\\"这篇论文中被提出,已经成为自然语言处理(NLP)领域的重要基石。这是因为Transformer模型有几个显著的优点: 自注意力机制(Self-Attention) :这是Transformer最核心的概念,也是其最大的特点。 通过自注意力机制,模

    2024年02月13日
    浏览(24)
  • Transformer(一)简述(注意力机制,NLP,CV通用模型)

    目录 1.Encoder 1.1简单理解Attention 1.2.什么是self-attention 1.3.怎么计算self-attention 1.4.multi-headed(q,k,v不区分大小写) 1.5.位置信息表达  2.Decoder(待补充)  3.BERT 参考文献 比方说,下图中的热度图中我们希望专注于小鸟,而不关注背景信息。那么如何关注文本和图像中的重点呢

    2024年02月13日
    浏览(27)
  • 解码Transformer:自注意力机制与编解码器机制详述与代码实现

    本文全面探讨了Transformer及其衍生模型,深入分析了自注意力机制、编码器和解码器结构,并列举了其编码实现加深理解,最后列出基于Transformer的各类模型如BERT、GPT等。文章旨在深入解释Transformer的工作原理,并展示其在人工智能领域的广泛影响。 作者 TechLead,拥有10+年互

    2024年02月13日
    浏览(32)
  • 图解Vit 2:Vision Transformer——视觉问题中的注意力机制

    上节回顾 在Transformer之前的RNN,其实已经用到了注意力机制。Seq2Seq。 对于Original RNN,每个RNN的输入,都是对应一个输出。对于original RNN,他的输入和输出必须是一样的。 在处理不是一对一的问题时,提出了RNN Seq2Seq。也就是在前面先输入整体,然后再依次把对应的输出出来

    2024年02月17日
    浏览(32)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包