聊聊拉长LLaMA的一些经验

这篇具有很好参考价值的文章主要介绍了聊聊拉长LLaMA的一些经验。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Sequence Length是指LLM能够处理的文本的最大长度,越长,自然越有优势:

  1. 更强的记忆性。更多轮的历史对话被拼接到对话中,减少出现遗忘现象

  2. 长文本场景下体验更佳。比如文档问答、小说续写等

当今开源LLM中的当红炸子鸡——LLaMA,第一版上下文长度是2048,第二版长度是4096。相比之下ChatGPT、GPT4已经支持到16k,Claude甚至支持到了100k。足以见得将LLaMA拉长是如此的任重而道远。本文将会介绍三种在旋转位置编码(RoPE)基础上扩充上下文的高性价比方案,在文末会介绍我的实践经验。

线性插值法

Kaiokendev的博客[1]中提到了方法,和Meta的一篇工作[2]不谋而合,其思想主要是将目标长度压缩到原始长度。如下图所示,LLaMA-1预训练的长度为2048,如果我们想把它拉长到4096:

  • 方法一:推理时直接拉长到4096。这考虑位置编码的外推性(即在短文本上训练,长文本上推理的能力[2]),而RoPE的外推性则是相当一般[2]。由于训练时长度都是小于2048的,超过2048部分Attention分数会飙升,导致困惑度急剧上升。

  • 方法二:在原始模型基础上做长度为4096的继续训练。这里先岔开介绍另一款模型——MPT-30B的做法,根据官方博客[3]的介绍:

    As mentioned earlier, MPT-30B was trained with a long context window of 8k tokens (vs. 2k for LLaMa and Falcon) and can handle arbitrarily long context windows via ALiBi or with fine-tuning. To build 8k support into MPT-30B efficiently, we first pre-trained on 1T tokens using sequences that were 2k tokens long, and continued training for an additional 50B tokens using sequences that were 8k tokens long.

    MPT-30B采用ALiBi位置编码(外推性优于RoPE),在2k的长度进行1T token的训练,然后在8k长度上进行50B token的预训练——这是在外推性强于RoPE的ALiBi上的情况。LLaMA-1预训练的token数是1T以上,想要在长度为4096样本上效果不下降,那需要训练足够多的token数才行,这就需要较大的成本了。

  • 方法三:另一种思路则是将4096的位置编码通过线性插值法压缩到2048内,这样只需要在少量的长度为4096的数据上进行继续预训练,便可达到不错的效果。

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

来自论文[2]

代码实现

线性插值法的实现代码相当的简单,这需要在原始RoPE上进行微小的改动,即加上下图的scale参数。

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

效果

Meta的工作[2]中进行了充足实验和公式推导证明,如果想看具体的代码,建议看lmsys.org(Vicuna的出品方)的一篇工作[4]:How Long Can Open-Source LLMs Truly Promise on Context Length? 他们对比了商用模型、号称支持长文本的开源模型和“Vicuna+线性插值法”的效果,并给出了几个结论:

  1. 商用模型在长文本的效果上很能打!而那些号称支持长文本的开源模型,在长文本上则表现不佳。

  2. 随着文本长度的增加,越接近边界,Vicuna+线性插值法的效果降低越明显。这可能是因为训练数据存在短文本的情况。

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

来自[4]

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

来自[4]

写这篇文章的同时,ChatGLM团队更新了ChatGLM2-6B-32K,也是使用了插值法。同时推出了长文本的中英评测集LongBench,在这个评测集上ChatGLM2-6B-32K展示了强大的实力,但值得注意的是,该评测集的评测方式是使用ChatGLM2-6B来进行评估的。

NTK插值法

NTK插值法的提出于一篇Reddit帖子[5],它提出使用Neural Tangent Kernel (NTK)来解决这个问题。

if you apply Neural Tangent Kernel (NTK) theory to this problem, it becomes clear that simply interpolating the RoPE's fourier space "linearly" is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by. Borrowing from NTK literature, scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta that suggests an upper bound of ~600x)

Instead of the simple linear interpolation scheme, I've tried to design a nonlinear interpolation scheme using tools from NTK literature. Basically this interpolation scheme changes the base of the RoPE instead of the scale, which intuitively changes the "spinning" speed which each of the RoPE's dimension vectors compared to the next. Because it does not scale the fourier features directly, all the positions are perfectly distinguishable from eachother, even when taken to the extreme (eg. streched 1million times, which is effectively a context size of 2 Billion)

帖子中作者还用了时钟的例子来解释线性插值和NTK插值的异同:

RoPE behaves like a clock. Your 12 hours wall clock is basically a RoPE of dimension 3 with a base of 60. So for each second, the minute hand turns 1/60th of a minute, and for each minute, the hour hand turns 1/60th.

Now if you slowed down time by a factor of 4x, that is a linear RoPE scaling used in SuperHOT. Unfortunately now it is really really hard to distinguish each second, because now the seconds hand barely moves each second. So if someone gave you two different times, which is only different by a single second, you won't be able to distinguish them from afar (let's say the NNs have myopia because that's basically what NTK predicts)

Now NTK-Aware RoPE scaling does not slow down the seconds. One second is still one second, but it slows down minutes by a factor of let's say 1.5, and the hours by a factor of 2. This way you can fit 90 minutes in a hour, and fit 24 hours in half a day. So now you basically have a clock that can measure 129.6k seconds instead of 43.2k seconds.

Because you don't need a precise measurement of the hour hand when looking at the time, scaling the hours more compared to seconds is crucial. You don't want to lose the precision of the seconds hand, but you can afford to lose precision on the minutes hand and even more on the hours hand.

Then, it's just a matter of deriving the base change formula in order to obtain such a scaling. (where less precise dimensions are scaled more and more)

代码实现

NTK的实现则更加简单了,根据超参数alpha,对应修改base变量即可:

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

效果

在效果上,帖子中也给出了NTK插值法和线性插值法的PPL比较,可以看到,在二者都不做Finetune的情况下,NTK插值法具备更低的PPL。

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

来自[5]

动态插值法

动态插值法同样出自于一篇Reddit帖子[6],它的出发点很简单:

My idea was to use the exact position values for the first 2k context (after all, why mess with a good thing?) and then re-calculate the position vector for every new sequence length as the model generates token by token.

这种做法可以和先前的两种方法相结合,[7]中也给出了详细的代码实现。

效果

作者和前两种方法做了对比,展示了动态插值法在PPL下降上的优势。

聊聊拉长LLaMA的一些经验,llama,人工智能,chatgpt,算法,数据挖掘

实践经验

我在实践的过程中,评估效果主要使用longChat[4]中使用的评估方式,以下是一些takeaway tips,欢迎大家一起交流。

  1. 线性插值法具备完整的理论支持和大量的实验证明,在我的实践中,“线性插值法+Finetune”取得了最佳效果。

  2. NTK插值法的实验中,对比的是不做Finetune的情况,在我的实践中,“NTK插值+Finetune”效果会明显优于单独的“NTK插值”,但它的收敛速度会慢于“线性插值法+Finetune”。

  3. 动态插值法的实验同样是在不做Finetune的情况对比的,目前为止我并没有尝试过这种方法。在Reddit的评论区有人提出一个很好的问题:如果采取这种方法,逐token推理时,文本的长度是在变化的,则导致无法使用kv-cache,这会对性能产生很大的影响。

最后,拉长LLaMA的方案可以不从RoPE入手(如:LongLLaMA[8]),但“线性插值法+Finetune”无疑是一种性价比很高的方案,推荐大家尝试!

——2023.07.31

Reference

[1] Extending Context is Hard…but not Impossible†

[2] EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

[3] MPT-30B: Raising the bar for open-source foundation models

[4] How Long Can Open-Source LLMs Truly Promise on Context Length?

[5] NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.

[6] Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

[7] GitHub - jquesnelle/scaled-rope

[8] Focused Transformer: Contrastive Training for Context Scaling文章来源地址https://www.toymoban.com/news/detail-624781.html

到了这里,关于聊聊拉长LLaMA的一些经验的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • LLAMA模型部署与一些关键定义

    这个有很长的路要走,当前先不讲了,后面开一个专题讲讲。 生成一个新的kernel环境 在bash中切换到这个环境 克隆项目 安装相关包 安装相关依赖包 下载llama模型包 命令行执行 其中 ckpt_dir 是指模型文件存放的文件夹名称 tokenizer_path 是指分词器所存放的文件夹位置 nproc_per_

    2024年02月11日
    浏览(29)
  • 关于LLaMA Tokenizer的一些坑...

    使用LLaMA Tokenizer对 jsonl 文件进行分词,并将分词结果保存到 txt 文件中,分词代码如下: 从以上代码可以看出, txt 文件中的每行内容实际上是 jsonl 文件对应行的文档的分词结果,分词之间以空格分隔。理论上,这意味着 txt 文件的行数应与 jsonl 文件的行数 相匹配 , 均等同

    2024年02月20日
    浏览(26)
  • 关于生成式语言大模型的一些工程思考 paddlenlp & chatglm & llama

    生成式语言大模型,随着chatgpt的爆火,市场上涌现出一批高质量的生成式语言大模型的项目。近期百度飞桨自然语言处理项目paddlenlp发布了2.6版本。更新了以下特性:全面支持主流开源大模型Bloom, ChatGLM, GLM, Llama, OPT的训练和推理;Trainer API新增张量训练能力, 简单配置即可开

    2024年02月12日
    浏览(42)
  • 中山大学人工智能学院——考研上岸经验贴

    首先是初试成绩,中山大学在2.21号就公布了成绩和 排名 ,这点很不错,有很多学校只公布成绩而没有排名。我的初试总分386,总排名第二,各个科目还是比较平均的: 要说的是,2023年,人工智能学院专硕本来只招3个人,后来在复试前扩了4个名额,所以最后专硕录取7个名

    2024年02月08日
    浏览(41)
  • 【实践探索】人工智能语音转换技术的实践经验和优化建议

    [toc] 【实践探索】人工智能语音转换技术的实践经验和优化建议 随着人工智能技术的快速发展,语音识别技术作为其基础应用之一,也得到了越来越广泛的应用。针对目前市场上主流的人工智能语音识别技术,本文将深入探讨其原理、实现过程以及优化建议。本文将重点分析

    2024年02月06日
    浏览(70)
  • 人工智能中一些看不懂的代码

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) - Tuple[Tensor, Tensor]: # noqa: F811         pass forward ,它的第一个参数 input 是一个 Tensor 类型的变量,第二个参数 hx 是一个可选的 Tensor 类型变量,这里使用了 Python 3.7 引入的类型注解语法。 函数返回值类型是一个由两个 Tensor 类

    2023年04月21日
    浏览(37)
  • 【人工智能】关于人类大脑模型的一些数学公式

    关于人类大脑建模的数学公式主要涉及到神经元网络、激活函数、学习算法等方面。这里是一些常见的数学公式(使用Markdown和LaTeX语法)。 神经网络的万能逼近定理(Universal Approximation Theorem)是关于在一定条件下神经网络能够逼近任意连续函数的定理。有多个版本的定理针

    2024年02月07日
    浏览(62)
  • 6款超好用AI写作神器,写作效率秒拔高! #经验分享#人工智能#知识分享

    在当今信息爆炸的时代,写作成为了人们表达思想、分享知识和传递情感的重要方式之一。对于很多人来说,写作并非易事。我们会陷入困境,无法找到灵感,我们会苦恼于语言表达的准确性,还有时候我们可能遭遇到了创作瓶颈,随着科技的进步和人工智能技术的发展,

    2024年04月15日
    浏览(51)
  • 机器之心 AI 技术--人工智能助力个性化视频实战经验分享(文末送书)

    在视频生成即将迎来技术和应用大爆发之际,为了帮助企业和广大从业者掌握技术前沿,把握时代机遇,机器之心AI论坛就将国内的视频生成技术力量齐聚一堂,共同分享国内顶尖力量的技术突破和应用实践。 论坛将于2024.01.20在北京举办,现场汇聚领域内专家和一线开发者,

    2024年02月03日
    浏览(51)
  • 玩转AIGC(人工智能生成内容)需要一些小技巧

    玩转AIGC(人工智能生成内容)的确需要一些技巧,而Prompt提示词的选择非常关键,可以影响到生成的答案。以下是一些与AI对话的技巧和咒语示例: 确保你的Prompt清晰明了,包括主题、问题或指令,以便AI能够更好地理解你的需求。 有点像小学语文,老师会要求你用一句话描

    2024年02月05日
    浏览(83)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包