LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

这篇具有很好参考价值的文章主要介绍了LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

前言

最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。

关于长度外推性:https://kexue.fm/archives/9431

关于RoPE:https://kexue.fm/archives/8265

1、线性插值法

论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

链接:https://arxiv.org/pdf/2306.15595.pdf

思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录,自然语言处理,大语言模型,llama,自然语言处理,大语言模型,人工智能,算法

源码阅读(附注释)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
        super().__init__()
        # 相比RoPE增加scale参数
        self.scale = scale
        # inv_freq为基值向量
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        # 构建max_seq_len_cached大小的张量t
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # 张量t归一化,RoPE没有这一步
        t /= self.scale
        # einsum计算频率矩阵
        # 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 在-1维做一次拷贝、拼接
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        # 注册为模型的缓冲区cos_cached和sin_cached
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        # seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            t /= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源码阅读:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
        super().__init__()
        # 与线性插值法相比,实现更简单,alpha仅用来改变base
        base = base * alpha ** (dim / (dim-2))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

3、动态插值法

动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源码阅读

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
        super().__init__()
        # 是否开启NTK(Neural Tangent Kernel)
        self.ntk = ntk
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        # inv_freq为基值向量
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # emb:[max_seq_len_cached, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            if self.ntk:
                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))
                # 计算新的inv_freq
                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
                self.register_buffer("inv_freq", inv_freq)
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            if not self.ntk:
                # 缩放
                t *= self.max_position_embeddings / seq_len
            # 得到新的频率矩阵freqs
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            # freqs与自身拼接得到新的emb
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # 注册为模型的缓冲区cos_cached和sin_cached
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)

        # 长度裁剪
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

总结

本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。文章来源地址https://www.toymoban.com/news/detail-641228.html

到了这里,关于LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • POE:性价比最高的 AI 整合网站

    创作不易,如果本文对你有帮助,胖友记得一键三连 😭。更多 AI 优质内容推荐请关注主页 “AI” 专栏,笔者会不定期更新觉得自己用下来还不错的 AI 相关产品。 Poe 是一款同时整合了 ChatGPT、Sage、GPT-4、Claude+、Claude-instant 和 NeevaAI 的网站,也是目前笔者使用体验很不错的一

    2024年02月03日
    浏览(54)
  • 万字解读|怎样激活 TDengine 最高性价比?

    不知不觉间,TDengine 已经 6 岁多了。在这 6 年多的时间里,我们从零开始,在一行又一行代码的淬炼下,TDengine 从 1.6 走过 2.0,终于走到如今的 3.0 时代。 自 2022 年下旬发布以来,经过我们不断地打磨优化之后,TDengine 3.0 在性能、功能、稳定性各个方面均有大幅提升,已经

    2024年02月07日
    浏览(46)
  • 低价位高性价比keychron机器键盘推荐

    目录 前言         1、为什么要使用机器键盘?         2、什么是keychron机器键盘? ​ 一、入坑keychron机器键盘         1、keychron入坑历程         2、键盘介绍        2.1键盘按键区介绍         2.2 风格介绍          2.3实物场景图片  二、综合实测   

    2024年02月10日
    浏览(54)
  • 不是 ES 用不起,而是 ClickHouse 更具“性价比”?

    云原生架构是一种基于云计算、容器化和微服务的架构模式。业内预测,到2025年,预计超过95%的工作负载将迁移到云端,云原生架构成为业务的必需品。 经过十三年的发展,某快递公司目前C端累计注册用户超2.5亿、P端(专业用户)累计注册快递员及网点经营者超130万、B端

    2024年01月25日
    浏览(75)
  • VPS服务器”性价比之王”系列:RackNerd

    2023 黑五!!!新 Ryzen 系列 洛杉矶dc02机房重新补货! 支付方式:支付宝、PayPal、信用卡、数字货币 “流量翻倍”活动参加方法在最后 CPU 内存 硬盘(SSD) 流量 带宽 价格(续费同价) 购买链接 1核 768 MB 15GB 1TB 1Gbps $10.8/年 直达链接 1核 2 GB 30GB 2.5TB 1Gbps $16.98/年 直达链接 2核 2.5

    2024年02月03日
    浏览(50)
  • 高性价比中兴智能穿墙路由京东预售详情介绍(附活动网址)

    随着科技的发展,人们对网络越来越依赖。而 路由器 作为家庭网络的入口也引来了大批厂商的兴趣。随着小米,360,百度的加入,更是将路由智能化的话题推向了一个新的高峰。但是单从功能上来说,智能路由更多是一个概念上,而没有给用户带来更多的实用。近日京东、

    2024年02月06日
    浏览(56)
  • AI智能音箱高性价比出好音质的功放芯片

    近几年 人工智能 等技术的不断发展,AI智能音箱已成为炙手可热的爆款;众多企业纷纷加入其中;如我们熟知的天猫精灵、小爱同学、小度智能音箱、华为AI音箱、腾讯叮当等等智能音箱;据不完全统计,目前国内做智能音箱的企业已有近百来家。 智能音箱虽然形态较小,但

    2024年02月04日
    浏览(78)
  • 蓝奥声开发高性价比智能wifi插座进军智能家居

    :智能家居、家用插座、WiFi插座、高性价比插座 智能硬件的大潮袭来让智能家居这一并不新鲜的概念再次火热起来,关于智能家居的各种场景的描述给了我们很大的想象空间,然而落到实处真正开始走进生活时却又显得那么骨感,一时间作为智能家居的控制中介,小

    2024年02月12日
    浏览(45)
  • 全面评测安全企业邮箱加密服务,推荐高性价比提供商

    安全电子邮件是加密形式的电子邮件。有权访问密钥的人只能阅读电子邮件。有许多安全的电子邮件发送工具可以避免业​​务风险并保护电子邮件中写入的信息。这些工具使您能够使用安全的端到端电子邮件加密来发送和接收消息。Zoho Mail企业邮箱最适合多用户帐户、小型

    2024年02月08日
    浏览(73)
  • ipad触控笔是哪几款?开学季性价比电容笔推荐

    随着新学期的临近,很多同学都在询问,步入新学期的时候,应该买什么类型的电容笔?苹果的电容笔价格不菲,有必要去选购吗?因为苹果笔拥有着一种特殊的重力压感,所以其的价格很贵,但是其的价格却让很多人望而却步。以下,我将为大家介绍一些具有较高性价比和

    2024年02月09日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包