LLaMa 原理+源码——拆解 (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)

这篇具有很好参考价值的文章主要介绍了LLaMa 原理+源码——拆解 (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

原理

Vanilla Transformer 与 LLaMa 的区别

Vanilla Transformer 与 LLaMa 的对比:LLaMa与普通的Transformer架构不同的地方,包括采用了前置了层归一化(Pre-normalization)并使用RMSNorm 归一化函数(Normalizing Function)、使用了旋转位置嵌入(RoPE)、激活函数由ReLU更换为SwiGLU,并且将self-attention改进为使用KV-Cache的Grouped Query,整体Transformer架构与GPT-2 类似。
vanilla llama,llama,embedding

LLaMa -> Alpaca -> Vicuna 的演进:

  • LLaMa:Meta开源的Pre-trained Model,模型参数从7B、13B、32B、65B不等,LLaMa-7B在大多数基准测试上超过了Text-davinci-003(即GPT3-173B),相比于ChatGPT或者GPT4来说,LLaMa可能效果上还有差距,目前hugging face已集成了LLaMa的代码实现和开源模型。学术界和工业界都可以在此基础上进行学习和研究。
    vanilla llama,llama,embedding

  • Alpaca:斯坦福在LLaMa-7B的基础上监督微调出来的模型,斯坦福是用OpenAI的Text-davinci-003(即GPT3-173B)的API配合self-instruct技术,使用175个提示语种子自动生成了52K条提示-回复的指示数据集,在LLaMa-7B上微调得到的模型,在8张80G的A100上训练了3小时。

  • Vicuna在LLaMa-13B的基础上使用监督微调得到的模型,数据集来自于ShareGPT 产生的用户对话数据,共70K条。使用Pytorch FSDP在8张A100上训练了一天。相较于Alpaca,Vicuna在训练中将序列长度由512扩展到了2048,并且通过梯度检测和flash attention来解决内存问题;调整训练损失考虑多轮对话,并仅根据模型的输出进行微调。通过GPT4来打分评测,Vicuna可以达到ChatGPT 90%的效果。

  • LLaMa2:采用了Llama 1的大部分预训练设置和模型架构。LLaMa2和LLaMa1的最大差别是增加了文本长度,并在训练34B、70B的模型中应用了GQA
    vanilla llama,llama,embedding

Embedding

Embedding的过程word -> token_id -> embedding_vector,其中第一步转化使用tokenizer的词表进行,第二步转化使用 learnable 的 Embedding layer

vanilla llama,llama,embedding

RMS Norm

对比 Batch Norm 和 Layer Norm:都是减去均值Mean,除以方差Var,最终将归一化为正态分布N(0,1)。只不过两者是在不同的维度(batch还是feature)求均值和方差,(其中,减均值:re-centering 将均值mean变换为0,除方差:re-scaling将方差varance变换为1)。
vanilla llama,llama,embedding

RMS Norm(Root Mean Layer Norm):RMS Norm认为,Layer Norm成功的原因是re-scaling,因为方差Var计算的过程中使用了均值Mean,因此RMS Norm不再使用均值Mean,而是构造了一个特殊的统计量RMS代替方差Var。为什么使用RMS Norm?(1)RMS Norm计算量更小。(2)RMS的效果和Layer Norm一样好。

针对输入向量 a 的RMS Norm 函数计算公式如下:

vanilla llama,llama,embedding

此外,RMSNorm 还可以引入可学习的缩放因子gi 和偏移参数bi,从而得到

vanilla llama,llama,embedding

RMSNorm 在HuggingFace Transformer 库中代码实现如下所示:

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps # eps 防止取倒数之后分母为0
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # weight 是末尾乘的可训练参数, 即g_i
        return (self.weight * hidden_states).to(input_dtype)

为了使得模型训练过程更加稳定,GPT-2 相较于GPT 就提出了将Layer Norm前置,将第一个层归一化移动到多头自注意力层之前,第二个层归一化也移动到了全连接层之前,同时残差连接的位置也调整到了多头自注意力层与全连接层之后。层归一化中也采用了RMSNorm 归一化函数。

Rotary Positional Encodding

普通绝对Positional Encodding的使用过程word -> token_id -> embedding_vector + position_encodding -> Encoder_Input,其中第一步转化使用tokenizer的词表进行,第二步转化使用 learnable 的 Embedding layer。将得到的embedding_vector 和 position_encodding 进行element-wise的相加,然后才做为input送入LLM的encoder。

vanilla llama,llama,embedding
对比Absolute PE 和 Relative PE

  • Absolute PE 绝对位置编码:每次单独1个token的PE,每个token的PE之间没有关系,是一组固定的vector,反映每个token在sequence中的绝对位置
  • Relative PE 相对位置编码:每次处理2个token的PE,只在计算attention时使用(在query@key时加在key上),反映2个token的相关度

vanilla llama,llama,embedding

旋转位置编码(RoPE):RoPE 借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。其目标是通过下述 f 运算,来给q,k 添加绝对位置信息m和n,得到˜qm 和˜kn,然后进行q@k

vanilla llama,llama,embedding

实际上,我们借助了复数的思想寻找了一个 g 运算来合并 f 运算q@k这两个操作,这样只需要token qk 以及两者的在seqence中的绝对位置mn即可:

vanilla llama,llama,embedding
可以看到与普通的相对位置编码不同,旋转相对位置编码用于在计算attention_score=q@k之后,对attention_score强调每个token之间的相对位置:

为什么叫旋转位置编码?因为使用欧拉公式构造旋转矩阵,将q@k的计算结果旋转到空间中对应的位置,实现对计算结果添加位置信息
vanilla llama,llama,embedding
上面是2维的例子,只有2个token xmxn,LLaMa中是n维的,n个token也是一样操作:
vanilla llama,llama,embedding

由于上述旋转矩阵Rn 具有稀疏性,有大量元素是0,因此可以使用逐位相乘⊗ 操作进一步加快计算速度。

vanilla llama,llama,embedding

RoPE 在HuggingFace Transformer 库中代码实现如下所示:

class LlamaRotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,
        dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
        
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`.
        # Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation
            # in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),
            persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),
            persistent=False)
    
        return (
        self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

SwiGLU Function

SwiGLU 激活函数是Shazeer 在文献中提出,并在PaLM等模中进行了广泛应用,并且取得了不错的效果,相较于ReLU 函数在大部分评测中都有不少提升。在LLaMA 中全连接层使用带有SwiGLU 激活函数的FFN(Position-wise Feed-Forward Network)的计算公式如下:

vanilla llama,llama,embedding

其中,σ(x) 是Sigmoid 函数。下图给出了Swish 激活函数在参数β 不同取值下的形状。可以看到当β 趋近于0 时,Swish 函数趋近于线性函数y = x,当β 趋近于无穷大时,Swish 函数趋近于ReLU 函数,β 取值为1 时,Swish 函数是光滑且非单调。

vanilla llama,llama,embedding
HuggingFace 的Transformer 库中 S w i s h β = 1 Swish_{\beta=1} Swishβ=1函数使用 SILU 函数 代替。

KV-Cache

首先来了解一下LLama的训练(下词预测任务):seq2seq的生成,但迭代T次,seq_len逐渐增加
vanilla llama,llama,embedding

下句预测时的Self-Attention:

  • timpstep=1时seq_len=1,给[SOS]时,预测Love;
    vanilla llama,llama,embedding
  • timpstep=2时seq_len=2,给[SOS] 和 Love时,预测that
    vanilla llama,llama,embedding
  • timpstep=4时seq_len=4,给[SOS] 和 Love 和 can 和 quickly时,预测seize…
    vanilla llama,llama,embedding

每个timestep我们只关注生成的最后一个token,但因为LLaMa是一个seq2seq的model,每次必须重新计算和生成前面的token,因此我们希望能将之前timestep计算生成过的token给缓存起来,下个timestep不用再次计算,这样的背景下,KV-Cache就产生了。

再来分析一下,每次个timestep的self-attention中我们到底需要哪些:因为我们只关注最后一个token的attention_output,如下图timestep=4,我们只需要attention_output的第4个token。

因此我们只需要Q的最后一个tokenK的所有token相乘,得到最后一个token的attention_score,然后用V的所有token再与attention_score点积(相乘求和),得到最后一个token的attention_output
vanilla llama,llama,embedding
由上分析可知,每个timestep,我们的Q只需要新增的那个token即可,而K和V要缓存之前timestep的token,保证token是全的每次计算出来的attention_output就是那个新增的token的attention。 这样就可以节省大量计算开销。

vanilla llama,llama,embedding

vanilla llama,llama,embedding
vanilla llama,llama,embedding

Grouped Multi-Query Attention

回顾原始的多头注意力Multi-Head Attention:时间开销的瓶颈在于矩阵的运算matrix computation

vanilla llama,llama,embedding

当我们使用KV-Cache后:时间开销的瓶颈在于内存的访问memory access

vanilla llama,llama,embedding

Multi Query Attention

多查询注意力(Multi Query Attention,MQA 是多头注意力的一种变体。其主要区别在于,在多查询注意力中不同的注意力头共享一个键和值的集合,每个头只单独保留了一份查询参数。 具体操作上,去除 K和V 的head维度,只为Q保留head维度。因此这就是被叫做Multi Query Attention的原因。

vanilla llama,llama,embedding

因此K和V的矩阵仅有一份(不分head),这大幅度减少了显存占用,使其更高效。由于多查询注意力改变了注意力机制的结构,因此模型通常需要从训练开始就支持多查询注意力。

研究结果表明,可以通过对已经训练好的模型进行微调来添加多查询注意力支持,仅需要约 5% 的原始训练数据量就可以达到不错的效果。包括Falcon、SantaCoder、StarCoder等在内很多模型都采用了多查询注意力机制。

vanilla llama,llama,embedding

以LLM Foundry 为例,多查询注意力实现代码如下,与LLM Foundry 中实现的多头自注意力代码相对比,其区别仅在于建立Wqkv 层上:

class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
additive bias.
"""
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear( # Multi-Query Attention 创建
            d_model,
            d_model + 2 * self.head_dim, # 只创建查询的头向量,所以只有1 个d_model
            device=device, # 而键和值则共享各自的一个head_dim 的向量
        )
        self.attn_fn = scaled_multihead_dot_product_attention
        self.out_proj = nn.Linear(
            self.d_model,
            self.d_model,
            device=device
        )
        self.out_proj._is_residual = True # type: ignore
    def forward(
        self,
        x,
    ):
        qkv = self.Wqkv(x) # (1, 512, 960)
        query, key, value = qkv.split( # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
            dim=2 # value -> (1, 512, 96)
        )
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            multiquery=True,
    )
        return self.out_proj(context), attn_weights, past_key_value
Grouped Multi-Query Attention

就是在 Multi-Query Attention的基础上,对input进行分组,每组都有自己的K,V,以及多头Q。

vanilla llama,llama,embedding

源码

[LLMs 实践] 01 llama、alpaca、vicuna 整体介绍及 llama 推理过程文章来源地址https://www.toymoban.com/news/detail-823704.html

到了这里,关于LLaMa 原理+源码——拆解 (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 结合源码拆解Handler机制

    作者:Pingred 当初在讲App启动流程的时候,它的整个流程涉及到的类可以汇总成下面这张图: 那时着重讲了AMS、PMS、Binder这些知识点,有一个是没有对它进行详细讲解的,那就是常见的Handler,它不仅在这个流程里作用在ApplicationThread和ActivityThread进行通信,它在整个安卓体系

    2024年02月11日
    浏览(48)
  • 华为HPLC模组全拆解之电力载波收发原理分析

    目录 一、前言 二、华为HPLC模组简介 三、HPLC模组拆解过程 四、模组电路原理图逆向 五、电力载波收发原理分析 六、通用单片机实现电力载波收发 七、结束语        电力线载波通信(PLC)是一种使用电力线进行数据传输的通信技术,即利用现有电网作为信号的传输介质,

    2024年01月18日
    浏览(39)
  • 双滤光片(IR-CUT)原理及拆解

            IR-CUT双滤光片切换器是由:滤光片(一片红外截止滤光片和一片全透光谱滤光片) + 动力部分(可以是电磁、电机或其他动力源)构成。         白天光线充分,电路控制板驱使切换器中切换到红外截止滤光片工作,CCD还原出真实彩色,解决红外光进入成像设

    2024年02月14日
    浏览(36)
  • MG996R 舵机内部驱动电路原理图和拆解实物图

     此原理图是180°舵机结构,将电位器去掉就是360°舵机的结构了,360°舵机相当于当电机使用了  图中PIN脚为PWM引脚 舵机内部拆解图如下:  

    2024年02月16日
    浏览(66)
  • 【数据结构与算法】Vue3实现选择排序动画效果与原理拆解

    删除有序数组中的重复项 JavaScript实现选择排序 选择排序(Selection Sort)是一种简单的排序算法,其基本思想是从待排序的数据中选择最小(或最大)的元素,然后将其放到已排序的序列的末尾(或开头)。该算法的时间复杂度为O(n^2),其中n是待排序数据的数量,因此在大规

    2024年02月13日
    浏览(38)
  • drop cache原理分析

    通过 echo 到文件/proc/sys/vm/drop_cache的处理函数为drop_caches_sysctl_handler,其中echo 1 /proc/sys/vm/drop_cache为释放page 页cache,echo 2 /proc/sys/vm/drop_cache为释放slab cache, echo 3 /proc/sys/vm/drop_cache为释放page 页cache和slab cache。 echo 2 /proc/sys/vm/drop_cache为释放slab cache其主要处理函数为drop_slab-dro

    2024年04月10日
    浏览(46)
  • python装饰器原理 | 常用装饰器使用(@cache, @lru_cache)

    🚀 关于python的装饰器原理介绍可看这里,讲的挺简洁易懂:python装饰器原理 ⭐ 弄懂装饰器原理后,来学学常用装饰器。 也就是一种装饰在被执行的函数上,将其执行的结果缓存起来,当下次请求的时候,如果请求该函数的传参未变则直接返回缓存起来的结果而不再执行函

    2023年04月25日
    浏览(57)
  • 拆解Spring boot:Springboot为什么如此丝滑而简单?源码剖析解读自动装配

    🎉🎉欢迎光临,终于等到你啦🎉🎉 🏅我是苏泽,一位对技术充满热情的探索者和分享者。🚀🚀 🌟持续更新的专栏 《Spring 狂野之旅:从入门到入魔》 🚀 本专栏带你从Spring入门到入魔   这是苏泽的个人主页可以看到我其他的内容哦👇👇 努力的苏泽 http://suzee.blog.csdn

    2024年03月23日
    浏览(44)
  • Python数据分析实战-*和**实现可变多参数的传入或变量的拆解(附源码和实现效果)

    实现功能 *和**实现多参数的传入或变量的拆解 实现代码 实现效果   本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python机器学习、深度学习、数据挖掘基础知识

    2024年02月12日
    浏览(51)
  • 计算机组成原理(4)-----Cache的原理及相关知识点

    目录 1.Cache的原理 2.Cache的性能 3.Cache和主存的映射方式  (1)全相联映射 (2)直接映射 (3)组相联映射 4.替换算法 (1)随机算法(RAND) (2)先进先出算法(FIFO) (3)近期最少使用(LRU) (4)最近不经常使用(LFU) 5.Cache写策略 (1)写命中 •写回法 •全写法 (2)写不命中 •写分配法 •非写分

    2024年02月21日
    浏览(60)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包