LLaMA中ROPE位置编码实现源码解析

这篇具有很好参考价值的文章主要介绍了LLaMA中ROPE位置编码实现源码解析。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1、Attention中q,经下式,生成新的q。m为句长length,d为embedding_dim/head
θ i = 1 1000 0 2 i d \theta_i=\frac{1}{10000^\frac{2i}{d}} θi=10000d2i1
LLaMA中ROPE位置编码实现源码解析,Pytorch使用,NLP,llama,深度学习,pytorch,人工智能

2、LLaMA中RoPE源码

import torch

def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0):
    '''
    计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx
    :param dim: q,k,v的最后一维,一般为emb_dim/head_num
    :param end: 句长length
    :param constant: 这里指10000
    :return:
    复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t))
    '''
    # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta
    # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)]
    freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2]

    # 计算m
    t = torch.arange(end, device=freqs.device)  # [length]
    # 计算m*theta
    freqs = torch.outer(t, freqs).float()  # [length, d/2]
    # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1

    # 计算cos(m*theta)+j*sin(m*theta)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0),  cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))]
    # 其中j为虚数单位, m=0,1,...,length-1
    return freqs_cis # [length, d/2]

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2)
    return freqs_cis.view(*shape) # [1, length, 1, d/2]

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,):
    # 先将xq维度变为[bs, length, head,  d/2, 2], 利用torch.view_as_complex转变为复数
    # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2]
    # 同样的,xk_:[k0+j*k1, k2+j*k3, ..., k(d-2)+j*k(d-1)]
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2]
    # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下
    # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0))
    # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0)
    # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部
    # 最后经flatten函数将维度拉平,即[bs, length, head, d]
    # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d]
    # 即为新生成的q

    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

if __name__=='__main__':
    # (bs, length, head, d)
    q = torch.randn((2, 10, 12, 32))  # q=[q0, q1, .., qd-1]
    k = torch.randn((2, 10, 12, 32))
    v = torch.randn((2, 10, 12, 32))
    freqs_cis= precompute_freqs_cis(dim=32, end=10, constant= 10000.0)
    # print(freqs_cis.detach().numpy())

    q_new, k_new = apply_rotary_emb(xq=q, xk=k, freqs_cis=freqs_cis)
    print()

3、表示
看1中的公式表示,q0和q1相互作用,得到新的q0和q1,也即
q 0 n e w = q 0 ∗ c o s ( m θ 0 ) − q 1 ∗ s i n ( m θ 0 ) q^{new}_0=q_0*cos(m\theta_0)-q_1*sin(m\theta_0) q0new=q0cos(mθ0)q1sin(mθ0)
q 1 n e w = q 1 ∗ c o s ( m θ 0 ) + q 0 ∗ s i n ( m θ 0 ) q^{new}_1=q_1*cos(m\theta_0)+q_0*sin(m\theta_0) q1new=q1cos(mθ0)+q0sin(mθ0)
这里将 ( q 0 , q 1 ) (q_0,q_1) (q0,q1) ( q 0 n e w , q 1 n e w ) (q^{new}_0,q^{new}_1) (q0new,q1new)看做向量,很明显上式是向量旋转,旋转角度为逆时针 m θ 0 m\theta_0 mθ0

LLaMA中ROPE位置编码实现源码解析,Pytorch使用,NLP,llama,深度学习,pytorch,人工智能
可与PalM中ROPE实现方式做对比
PaLM中ROPE位置编码实现源码解析文章来源地址https://www.toymoban.com/news/detail-676944.html

到了这里,关于LLaMA中ROPE位置编码实现源码解析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 如何修改大模型的位置编码 --以LLama为例

    最近在看RoPE相关内容,一些方法通过简单修改位置编码就可以无需训练支持更长的文本内容。由于一些模型,已经训练好了,但是怎么修改已经训练好的模型位置编码。查了以下相关代码,记录一下。原理这里就不细讲了,贴几个相关博客。 十分钟读懂旋转编码(RoPE) Tr

    2024年04月15日
    浏览(33)
  • 《动手学深度学习 Pytorch版》 10.6 自注意力和位置编码

    在注意力机制中,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention),也被称为内部注意力(intra-attention)。本节将使用自注意力进行序列编码,以及使用序列的顺序作为补充信息。 给定一个由

    2024年02月06日
    浏览(40)
  • 概念解析 | 神经网络中的位置编码(Positional Encoding)

    注1:本文系“概念解析”系列之一,致力于简洁清晰地解释、辨析复杂而专业的概念。本次辨析的概念是:Positional Encoding A Gentle Introduction to Positional Encoding in Transformer Models, Part 1 在自然语言处理任务中,序列的顺序信息非常重要。例如,“小明去公园玩球”和“小明玩球去公园”

    2024年02月05日
    浏览(49)
  • 自编码器简单介绍—使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试

    自编码器是一种无监督学习算法,用于学习数据中的特征,并将这些特征用于重构与输入相似的新数据。自编码器由编码器和解码器两部分组成,编码器用于将输入数据压缩到一个低维度的表示形式,解码器将该表示形式还原回输入数据的形式。自编码器可以应用于多种领域

    2023年04月27日
    浏览(75)
  • Llama模型结构解析(源码阅读)

    参考资料: https://zhuanlan.zhihu.com/p/636784644 https://spaces.ac.cn/archives/8265 ——《Transformer升级之路:2、博采众长的旋转式位置编码》 前言 :本次阅读代码位置,在transformers库底下的modeling_llama.py,具体位置在:transformers/models/llama/modeling_llama.py,如下图所示: 代码如下 RMSNorm的公

    2024年02月10日
    浏览(29)
  • LLaMA-Adapter源码解析

    2024年02月06日
    浏览(36)
  • 【论文阅读随笔】RoPE/旋转编码:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING

    绝对位置编码比较简单,加或乘一个有次序的数 实现相对位置编码,也即意味着,要蕴含位置差的信息: 假设m是某个token的位置信息,n是另一个token的位置信息,要有类似 m − n m-n m − n 的信息,比较容易想到复数乘法会产生 m − n m-n m − n ,以及复数乘法和复数内积的性

    2024年03月11日
    浏览(42)
  • LLMs之Colossal-LLaMA-2:源码解读(train.py文件)基于给定数据集实现持续预训练LLaMA-2—解析命令行参数→初始化配置(分布式训练环境colossalai+训练日志+加速插

    LLMs之Colossal-LLaMA-2:源码解读(train.py文件)基于给定数据集实现持续预训练LLaMA-2—解析命令行参数→初始化配置(分布式训练环境colossalai+训练日志+加速插件)→数据预处理(初始化分词器+数据处理器+数据加载器)→模型训练(初始化模型/优化器/学习率调度器/梯度检查点/Flash-Att

    2024年02月06日
    浏览(40)
  • 2023年的深度学习入门指南(19) - LLaMA 2源码解析

    上一节我们学习了LLaMA 2的补全和聊天两种API的使用方法。本节我们来看看LLaMA 2的源码。 上一节我们讲了LLaMA 2的编程方法。我们来复习一下: 我们先来看看text_completion函数的参数是什么意思,该函数的原型为: 我们来看下这些参数的含义: prompts:这是一个字符串列表,每

    2024年02月15日
    浏览(48)
  • 机器学习&&深度学习——自注意力和位置编码(数学推导+代码实现)

    👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习深度学习——注意力分数(详细数学推导+代码实现) 📚订阅专栏:机器学习深度学习 希望文章对你们有所帮助 在深度学习中,经常使用CNN和RNN对序列进行编码。有了自注意力之后,我们

    2024年02月12日
    浏览(69)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包