cross attention输入不同维度的矩阵

这篇具有很好参考价值的文章主要介绍了cross attention输入不同维度的矩阵。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

一.问题背景

在学习使用cross attention的时候我查阅了很多资料,发现里面说的都是cross attention的输入需要是相同维度的矩阵,但是我所需要的是可以处理不同维度数据的cross attention。
cross attention

二.cross attention的代码

看了关于cross attention的一些介绍和代码,发现大多都是这样

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(CrossAttention, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x, y):
        batch_size = x.shape[0]
        num_queries = x.shape[1]
        num_keys = y.shape[1]
        x = self.query(x)
        y = self.key(y)
        # 计算注意力分数
        attn_scores = torch.matmul(x, y.transpose(-2, -1)) / (self.out_dim ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        # 计算加权和
        V = self.value(y)
        output = torch.bmm(attn_weights, V)
        
        return output

这里的x和y所输入的维度需要一致,那么从代码上看好像不太好分析如何进行改变,我们先看看cross attention的公式:

Cross-Attention ( Q , K , V ) = softmax ( ( W Q S 2 ) ( W K S 1 ) T ) W V S 1 \text{Cross-Attention}(Q,K,V) = \text{softmax}\left((W_{Q}S2)(W_{K}S1)^T\right)W_{V}S1 Cross-Attention(Q,K,V)=softmax((WQS2)WKS1T)WVS1

其中, Q Q Q为查询向量, K K K为编码器的键向量, V V V为编码器的值向量, d k d_k dk为编码器键向量的维度。

所以,当输入的维度不同的时候,我们可以对从W入手进行维度的变换和适配。

那W从哪儿来呢?

注意看代码中,QKV均过了一个线性层,所以,我们将线性层的输出改为我们所需要的输出,就可以完成不同维度的输入了。

代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, in_dim, out_dim,in_q_dim,hid_q_dim):
        super(CrossAttention, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.in_q_dim = in_q_dim #新增
        self.hid_q_dim = hid_q_dim #新增
        # 定义查询、键、值三个线性变换
        self.query = nn.Linear(in_q_dim, hid_q_dim, bias=False) #变化
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)
        
    def forward(self, x, y):
        # 对输入进行维度变换,为了方便后面计算注意力分数
        batch_size = x.shape[0]   # batch size
        num_queries = x.shape[1]  # 查询矩阵中的元素个数
        num_keys = y.shape[1]     # 键值矩阵中的元素个数
        x = self.query(x)  # 查询矩阵
        y = self.key(y)    # 键值矩阵
        # 计算注意力分数
        attn_scores = torch.matmul(x, y.transpose(-2, -1)) / (self.out_dim ** 0.5)  # 计算注意力分数,注意力分数矩阵的大小为 batch_size x num_queries x num_keys x num_keys
        attn_weights = F.softmax(attn_scores, dim=-1)  # 对注意力分数进行 softmax 归一化
        # 计算加权和
        V = self.value(y)  # 通过值变换得到值矩阵 V
        output = torch.bmm(attn_weights, V)  # 计算加权和,output 的大小为 batch_size x num_queries x num_keys x out_dim
       
        return output

例如:输入的两个矩阵分别为x=[batch, 1024, 512] y=[batch, 1024, 1024],其中X作为被用作查询,那么下面是一个实例:文章来源地址https://www.toymoban.com/news/detail-439378.html

# 定义输入矩阵 x 和 y,其大小分别为 1024, 512 和 1024, 1024
x = torch.randn(1, 1024, 512)
y = torch.randn(1, 1024, 1024)

# 创建 CrossAttention 模型,并对输入进行前向传播
cross_attn = CrossAttention(in_dim=1024, out_dim=1024,in_q_dim=512,hid_q_dim=1024)
output = cross_attn(x=x, y=y)

# 输出新的矩阵大小
print(output.shape) # (1, 1024,1024,1024)
print(output)

到了这里,关于cross attention输入不同维度的矩阵的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 论文阅读 | Cross-Attention Transformer for Video Interpolation

    前言:ACCV2022wrokshop用transformer做插帧的文章,q,kv,来自不同的图像 代码:【here】 传统的插帧方法多用光流,但是光流的局限性在于 第一:它中间会算至少两个 cost volumes,它是四维的,计算量非常大 第二:光流不太好处理遮挡(光流空洞)以及运动的边缘(光流不连续)

    2024年02月09日
    浏览(44)
  • What the DAAM: Interpreting Stable Diffusion Using Cross Attention

    论文链接:https://arxiv.org/pdf/2210.04885.pdf Background 在读本篇文章之前先来了解深度学习的可解释性,可解释性方法有类激活映射CAM、基于梯度的方法、反卷积等,在diffusion模型出来之后,本篇文章就对扩散模型中的交叉注意力做了探究,主要做的工作是用交叉注意力来解释扩散

    2024年02月09日
    浏览(42)
  • Cross-Modal Learning with 3D Deformable Attention for Action Recognition

    标题:基于三维可变形注意力的跨模态学习用于动作识别 发表:ICCV2023 在基于视觉的动作识别中,一个重要的挑战是将具有两个或多个异构模态的时空特征嵌入到单个特征中。在这项研究中,我们提出了一种 新的三维变形变压器 ,用于动作识别, 具有自适应时空感受野和跨

    2024年03月24日
    浏览(63)
  • ​目标检测算法——YOLOv5/YOLOv7改进之结合Criss-Cross Attention

    论文题目: CCNet: Criss-Cross Attention for Semantic Segmentation 论文地址: https://arxiv.org/pdf/1811.11721.pdf 代码地址:https://github.com/shanglianlm0525/CvPytorch 本文是ICCV2019的语义分割领域的文章,旨在解决long-range dependencies问题,提出了基于十字交叉注意力机制(Criss-Cross Attention)的模块,利

    2024年02月02日
    浏览(58)
  • 【论文简述】Cross-Attentional Flow Transformer for Robust Optical Flow(CVPR 2022)

    1. 第一作者: Xiuchao Sui、Shaohua Li 2. 发表年份: 2021 3. 发表期刊: arxiv 4. : 光流、Transformer、自注意力、交叉注意力、相关体 5. 探索动机: 由于卷积的局部性和刚性权重,有限的上下文信息被纳入到像素特征中,并且计算出的相关性具有很高的随机性,以至于大多数

    2024年02月03日
    浏览(56)
  • 深入理解Transformer,兼谈MHSA(多头自注意力)、Cross-Attention(交叉注意力)、LayerNorm、FFN、位置编码

    Transformer其实不是完全的Self-Attention(SA,自注意力)结构,还带有Cross-Attention(CA,交叉注意力)、残差连接、LayerNorm、类似1维卷积的Position-wise Feed-Forward Networks(FFN)、MLP和Positional Encoding(位置编码)等 本文涵盖Transformer所采用的MHSA(多头自注意力)、LayerNorm、FFN、位置编

    2024年04月12日
    浏览(65)
  • 第二十一章:CCNet:Criss-Cross Attention for Semantic Segmentation ——用于语义分割的交叉注意力

    原文题目:《CCNet:Criss-Cross Attention for Semantic Segmentation 》 原文引用:Huang Z, Wang X, Huang L, et al. Ccnet: Criss-cross attention for semantic segmentation[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2019: 603-612. 原文链接: https://openaccess.thecvf.com/content_ICCV_2019/papers/Huang_CCNet_Criss

    2024年02月16日
    浏览(43)
  • tensor的不同维度种类

    0维标量(scalar),1维向量(vector),二维矩阵(matrix),3维以上n维张量

    2024年02月08日
    浏览(40)
  • Pytorch模型如何查看每层输入维度输出维度

    在 PyTorch 中,可以使用 torchsummary 库来实现对 PyTorch 模型的结构及参数统计的输出,其可以方便我们查看每层输入、输出的维度以及参数数量等信息。 安装 torchsummary 库: 使用方法如下: 其中, model 是需要查看的模型, (3, 32, 32) 表示模型的输入维度,即 C = 3,H = 32,W = 32。

    2024年02月16日
    浏览(40)
  • pytorch中选取不同维度sum和mean方法理解

    在PyTorch中,sum()函数用于对输入张量的所有元素进行求和操作。该函数的语法如下: 具体而言,sum()函数会对输入张量的所有元素进行求和操作,并返回一个标量值。 如果指定了dim参数,则会沿着指定的维度对输入张量进行求和操作,并返回一个形状与输入张量除了指定维度

    2024年02月13日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包