【transformer】自注意力源码解读和复杂度计算

这篇具有很好参考价值的文章主要介绍了【transformer】自注意力源码解读和复杂度计算。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Self-attention

【transformer】自注意力源码解读和复杂度计算,深度学习,transformer,深度学习,人工智能

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。这里的注意力机制是通过计算查询向量 Q Q Q和键向量 K K K之间的相似性,来为值向量 V V V分配不同的权重。如果两个向量越相似,则它们之间的权重应该越大,反之则越小。

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)  # 获取文本嵌入维度大小
    # 按照注意力机制的公式计算注意力分数
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    # 是否使用掩码
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # 使用softmax对最后一个维度获得注意力张量
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # 注意力张量与value相乘得到query的注意力表示
    return torch.matmul(p_attn, value), p_attn

一个形状为 N × M N\times M N×M 的矩阵,与另一个形状为 M × P M\times P M×P的矩阵相乘,其运算复杂度来源于乘法操作的次数,时间复杂度为 O ( N M P ) O(NMP) O(NMP)

Self-attention的公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V其中, Q Q Q为查询向量, K K K V V V为键向量和值向量, d k d_k dk为向量的维度。 Q Q Q K K K V V V在一般情况下是相同的。公式中的softmax函数将分数归一化为概率,得到加权的值向量。
Self-Attention的计算复杂度主要来自三个方面:查询矩阵、键矩阵和值矩阵的乘积、softmax 的计算、以及输出向量和值的加权平均。
对于一个由n个单词组成的输入序列,假设有d个维度的特征,那么查询矩阵、键矩阵和值矩阵的维度都将是 n × d。

  • 对于查询矩阵 Q 和键矩阵 K 的点积, n × d n\times d n×d d × n d\times n d×n计算复杂度是 O ( n 2 d ) O(n^2d) O(n2d)
  • 每行 softmax 的计算,计算复杂度为 O ( n ) O(n) O(n),对n行做softmax,复杂度为 O ( n 2 ) O(n^2) O(n2)
  • 对于值矩阵 V (维度 n × d n\times d n×d)和 softmax 后的结果(维度 n × n n\times n n×n)进行点积,得到每个查询向量的加权平均值,复杂度是 O ( n 2 d ) O(n^2d) O(n2d)

因此,总的计算复杂度是 O ( n 2 d ) + O ( n 2 ) + O ( n 2 d ) ≃ O ( n 2 d ) O(n^2d) + O(n^2) + O(n^2d) \simeq O(n^2d) O(n2d)+O(n2)+O(n2d)O(n2d)
由于这个复杂度是关于输入序列长度n的平方级别,因此Self-Attention在处理长序列时可能会面临计算上的挑战。

多头注意力

【transformer】自注意力源码解读和复杂度计算,深度学习,transformer,深度学习,人工智能
多头注意力的计算公式如下:
MultiHead ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , … ,  head  h ) W O  where   head  i = A ( Q W i Q , K W i K , V W i V ) \begin{aligned} \operatorname{MultiHead}(Q, K, V) & =\operatorname{Concat}\left(\operatorname{head}_1, \ldots, \text { head }_{\mathrm{h}}\right) W^O \\ \text { where } \text { head }_{\mathrm{i}} & =A\left(Q W_i^Q, K W_i^K, V W_i^V\right) \end{aligned} MultiHead(Q,K,V) where  head i=Concat(head1,, head h)WO=A(QWiQ,KWiK,VWiV)其中, Q , K , V Q,K,V Q,K,V 分别表示查询、键和值, h h h 表示头数, h e a d i head_i headi 表示第 i i i 个注意力头, W O W^O WO 表示输出层的权重矩阵。

# 用于深度拷贝的copy工具包
import copy

# 首先需要定义克隆函数, 因为在多头注意力机制的实现中, 用到多个结构相同的线性层.
# 我们将使用clone函数将他们一同初始化在一个网络层列表对象中. 之后的结构中也会用到该函数.
def clones(module, N):
    """用于生成相同网络层的克隆函数, 它的参数module表示要克隆的目标网络层, N代表需要克隆的数量"""
    # 在函数中, 我们通过for循环对module进行N次深度拷贝, 使其每个module成为独立的层,
    # 然后将其放在nn.ModuleList类型的列表中存放.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# 我们使用一个类来实现多头注意力机制的处理
class MultiHeadedAttention(nn.Module):
    def __init__(self, head, embedding_dim, dropout=0.1):
        """在类的初始化时, 会传入三个参数,head代表头数,embedding_dim代表词嵌入的维度, 
           dropout代表进行dropout操作时置0比率,默认是0.1."""
        super(MultiHeadedAttention, self).__init__()

        # 在函数中,首先使用了一个测试中常用的assert语句,判断h是否能被d_model整除,
        # 这是因为我们之后要给每个头分配等量的词特征.也就是embedding_dim/head个.
        assert embedding_dim % head == 0

        # 得到每个头获得的分割词向量维度d_k
        self.d_k = embedding_dim // head

        # 传入头数h
        self.head = head

        # 然后获得线性层对象,通过nn的Linear实例化,它的内部变换矩阵是embedding_dim x embedding_dim,然后使用clones函数克隆四个,
        # 为什么是四个呢,这是因为在多头注意力中,Q,K,V各需要一个,最后拼接的矩阵还需要一个,因此一共是四个.
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)

        # self.attn为None,它代表最后得到的注意力张量,现在还没有结果所以为None.
        self.attn = None

        # 最后就是一个self.dropout对象,它通过nn中的Dropout实例化而来,置0比率为传进来的参数dropout.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """前向逻辑函数, 它的输入参数有四个,前三个就是注意力机制需要的Q, K, V,
           最后一个是注意力机制中可能需要的mask掩码张量,默认是None. """

        # 如果存在掩码张量mask
        if mask is not None:
            # 使用unsqueeze拓展维度
            mask = mask.unsqueeze(0)

        # 接着,我们获得一个batch_size的变量,他是query尺寸的第1个数字,代表有多少条样本.
        batch_size = query.size(0)

        # 之后就进入多头处理环节
        # 首先利用zip将输入QKV与三个线性层组到一起,然后使用for循环,将输入QKV分别传到线性层中,
        # 做完线性变换后,开始为每个头分割输入,这里使用view方法对线性变换的结果进行维度重塑,多加了一个维度h,代表头数,
        # 这样就意味着每个头可以获得一部分词特征组成的句子,其中的-1代表自适应维度,
        # 计算机会根据这种变换自动计算这里的值.然后对第二维和第三维进行转置操作,
        # 为了让代表句子长度维度和词向量维度能够相邻,这样注意力机制才能找到词义与句子位置的关系,
        # 从attention函数中可以看到,利用的是原始输入的倒数第一和第二维.这样我们就得到了每个头的输入.
        query, key, value = \
           [model(x).view(batch_size, -1, self.head, self.d_k).transpose(1, 2)
            for model, x in zip(self.linears, (query, key, value))]

        # 得到每个头的输入后,接下来就是将他们传入到attention中,
        # 这里直接调用我们之前实现的attention函数.同时也将mask和dropout传入其中.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 通过多头注意力计算后,我们就得到了每个头计算结果组成的4维张量,我们需要将其转换为输入的形状以方便后续的计算,
        # 因此这里开始进行第一步处理环节的逆操作,先对第二和第三维进行转置,然后使用contiguous方法,
        # 这个方法的作用就是能够让转置后的张量应用view方法,否则将无法直接使用,
        # 所以,下一步就是使用view重塑形状,变成和输入形状相同.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head * self.d_k)

        # 最后使用线性层列表中的最后一个线性层对输入进行线性变换得到最终的多头注意力结构的输出.
        return self.linears[-1](x)

在多头注意力中,假设有 h h h 个头,每个头的查询、键和值的维度是 d k d_k dk d k d_k dk d v d_v dv,一般情况 d q = d k = d v = d h d_q=d_k=d_v=\frac{d}{h} dq=dk=dv=hd, 输入序列的长度为 N N N

  • 输入线性映射的复杂度: n × d n\times d n×d d × d h d \times \frac{d}{h} d×hd,计算复杂度是 O ( n d 2 h ) O(\frac{nd^2 }{h}) O(hnd2)
  • 注意力计算:输入线性映射后的维度 n × d h n \times \frac{d}{h} n×hd n × d h n \times \frac{d}{h} n×hd d h × n \frac{d}{h}\times n hd×n计算复杂度是 O ( n 2 d h ) O(n^2\frac{d}{h}) O(n2hd)
  • 输出线性映射: 多个头的结果concat成一个 n × d n\times d n×d矩阵, n × d n\times d n×d d × d d \times d d×d,计算复杂度是 O ( n d 2 ) O(nd^2) O(nd2)

总时间复杂度 O ( n d 2 h + n 2 d h + n d 2 ) O(\frac{nd^2}{h}+n^2\frac{d}{h}+nd^2) O(hnd2+n2hd+nd2)


参考:
传智博客-Transformer文章来源地址https://www.toymoban.com/news/detail-695057.html

到了这里,关于【transformer】自注意力源码解读和复杂度计算的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 注意力机制和Transformer

    机器翻译是NLP领域中最重要的问题之一,也是Google翻译等工具的基础。传统的RNN方法使用两个循环网络实现序列到序列的转换,其中一个网络(编码器)将输入序列转换为隐藏状态,而另一个网络(解码器)则将该隐藏状态解码为翻译结果。但是,这种方法存在两个问题:

    2024年02月09日
    浏览(38)
  • 简单理解Transformer注意力机制

    这篇文章是对《动手深度学习》注意力机制部分的简单理解。 生物学中的注意力 生物学上的注意力有两种,一种是无意识的,零一种是有意识的。如下图1,由于红色的杯子比较突出,因此注意力不由自主指向了它。如下图2,由于有意识的线索是想要读书,即使红色杯子比较

    2024年02月03日
    浏览(27)
  • Transformer中的注意力机制及代码

    最近在学习transformer,首先学习了多头注意力机制,这里积累一下自己最近的学习内容。本文有大量参考内容,包括但不限于: ① 注意力,多注意力,自注意力及Pytorch实现 ② Attention 机制超详细讲解(附代码) ③ Transformer 鲁老师机器学习笔记 ④ transformer中: self-attention部分是否需

    2023年04月11日
    浏览(35)
  • 图解transformer中的自注意力机制

    本文将将介绍注意力的概念从何而来,它是如何工作的以及它的简单的实现。 在整个注意力过程中,模型会学习了三个权重:查询、键和值。查询、键和值的思想来源于信息检索系统。所以我们先理解数据库查询的思想。 假设有一个数据库,里面有所有一些作家和他们的书籍

    2024年02月09日
    浏览(35)
  • Transformer仅有自注意力还不够?微软联合巴斯大学提出频域混合注意力SpectFormer

    本文介绍一篇来自 英国巴斯大学(University of Bath)与微软合作完成的工作, 研究者从频率域角度入手探究视觉Transformer结构中的频域注意力和多头注意力在视觉任务中各自扮演的作用。 论文链接: https://arxiv.org/abs/2304.06446 项目主页: https://badripatro.github.io/SpectFormers/ 代码链

    2024年02月07日
    浏览(34)
  • 大模型基础之注意力机制和Transformer

    核心思想:在decoder的每一步,把encoder端所有的向量提供给decoder,这样decoder根据当前自身状态,来自动选择需要使用的向量和信息. decoder在每次生成时可以关注到encoder端所有位置的信息。 通过注意力地图可以发现decoder所关注的点。 注意力使网络可以对齐语义相关的词汇。

    2024年02月11日
    浏览(30)
  • 注意力机制——Spatial Transformer Networks(STN)

    Spatial Transformer Networks(STN)是一种空间注意力模型,可以通过学习对输入数据进行空间变换,从而增强网络的对图像变形、旋转等几何变换的鲁棒性。STN 可以在端到端的训练过程中自适应地学习变换参数,无需人为设置变换方式和参数。 STN 的基本结构包括三个部分:定位网

    2024年02月07日
    浏览(36)
  • 【】理解ChatGPT之注意力机制和Transformer入门

    作者:黑夜路人 时间:2023年4月27日 想要连贯学习本内容请阅读之前文章: 【原创】理解ChatGPT之GPT工作原理 【原创】理解ChatGPT之机器学习入门 【原创】AIGC之 ChatGPT 高级使用技巧 GPT是什么意思 GPT 的全称是 Generative Pre-trained Transformer(生成型预训练变换模型),它是基于大

    2024年02月16日
    浏览(32)
  • 深入理解Transformer,兼谈MHSA(多头自注意力)、Cross-Attention(交叉注意力)、LayerNorm、FFN、位置编码

    Transformer其实不是完全的Self-Attention(SA,自注意力)结构,还带有Cross-Attention(CA,交叉注意力)、残差连接、LayerNorm、类似1维卷积的Position-wise Feed-Forward Networks(FFN)、MLP和Positional Encoding(位置编码)等 本文涵盖Transformer所采用的MHSA(多头自注意力)、LayerNorm、FFN、位置编

    2024年04月12日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包