以Llama-2为例,在生成模型中使用自定义StoppingCriteria

这篇具有很好参考价值的文章主要介绍了以Llama-2为例,在生成模型中使用自定义StoppingCriteria。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1. 前言

在之前的文章中,介绍了使用transformers模块创建的模型,其generate方法的详细原理和使用方法,文章链接:

以beam search为例,详解transformers中generate方法(上)
以beam search为例,详解transformers中generate方法(下)

其中提到了用户参与生成过程的两个关键组件,logits_processorstopping_criteria,使用这两个类,是用户控制生成过程的主要手段。其中,logits_processor用来在生成过程中,根据用户设置的指定规则,强行修改当前step在词表空间上的概率分布,而stopping_criteria,根据用户所规定的规则来中止生成。

这两个组件在transformers模块中都有一些预设的类可以直接使用,预设类的基本信息介绍可参考以beam search为例,详解transformers中generate方法(上)。

本文将结合实际应用场景,介绍用户如何根据自己的需求来设计并实现一个自定义的stopping_criteria,来控制生成过程提前结束。

2. 场景介绍

这次介绍的场景是,使用Llama-2的生成能力对一段新闻进行概括,希望能够生成一句简短的话,来概括新闻中发生的最核心的事情。

通过给定对话背景,结合历史样例的方式,希望Llama-2能够输出期望的结果。

对话的prompt构造方法可以参考之前的内容:NLP实践——Llama-2 多轮对话prompt构建。

然而,即便是采用了in-context learning的方式,Llama-2生成的结果仍然过于冗长。

例如对于这样一篇新闻:

text = """, Photo Credit : Associated Press Four air crew members were missing after an Australian army helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States, officials said Saturday. The MRH-90 Taipan helicopter went down near Lindeman Island, a Great Barrier... ..."""  
# 后边忽略若干内容

模型生成的结果为:

Four Australian army air crew members are missing after an Australian army MRH-90 Taipan helicopter ditched into waters off the Queensland state coast during joint military exercises with the United States. The helicopter went down near Lindeman Island, a Great Barrier Reef tourist resort, at around 11 pm on Friday. A search involving US, Canadian, and Australian personnel is underway to find the missing crew, who are all Australian men. Debris that appeared to be from a helicopter has been recovered, according to Queensland Police Assistant Commissioner Douglas McDonald. The Taipan was taking part in Talisman Sabre, a biennial joint US-Australian military exercise that is largely based in Queensland. This year's exercise involves 13 nations and over 30,000 military personnel. Defense Minister Richard Marles said the helicopter ditched, which refers to an emergency landing on water. He added that defense exercises, which are so necessary for the readiness of our defense force, are serious and carry risk. US Defense Secretary... ...
# 后边忽略若干内容

可以看出,并不是模型生成的结果不好,但是它太啰嗦了,而对于我的需求而言,模型只需要输出其中的第一句话就足够了。

这时候可能有人就会觉得:“那我分句然后把第一句话保留下来不就好了?”

——这样做虽然也可以达成效果,但是这个生成过程,时间和算力已经被消耗了。

所以需要采取方法,让模型在生成到第一个句号的时候,就停止生成,返回结果。于是就需要用到今天的主角——Stopping Criteria。

3. 解决方法

transformers模块中内置了几个默认的stopping criteria,然而,在很多情况下,它们并不能满足需求,这时,就需要创建自定义的stopping criteria。

首先需要引用基类:

from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, \
    STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings

其中,

  • StoppingCriteriaList是一个容器,需要将所有的criteria都添加到其中,generate时传入的是这个容器;
  • StoppingCriteria是基础类,自定义的criteria需要继承这个基础类。

接下来就实现一个criteria,效果是,遇到指定的token时,就停止生成:

class StopAtSpecificTokenCriteria(StoppingCriteria):
    """
    当生成出第一个指定token时,立即停止生成
    ---------------
    ver: 2023-08-02
    by: changhongyu
    """
    def __init__(self, token_id_list: List[int] = None):
        """
        :param token_id_list: 停止生成的指定token的id的列表
        """
        self.token_id_list = token_id_list
        
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list
        # 储存scores会额外占用资源,所以直接用input_ids进行判断
        return input_ids[0][-1].detach().cpu().numpy() in self.token_id_list

那么,如果希望遇到句号就停止生成,那就用句号对应的token_id去实例化一个这样的stopping criteria,并将它添加到容器中:

# Llama-2的词表中,英文句号的id是29889
stopping_criteria = StoppingCriteriaList()
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[29889]))

然后,在生成的时候,假如原本的生成指令是:

model.generate(**inputs)

那么再把stopping criteria作为参数传入进去,就可以发挥效果了:

model.generate(stopping_criteria=stopping_criteria, **inputs)

4. 结语

Stopping Criteria用于在每一个step的生成结束时,判断生成过程是否要结束,是用户控制生成过程的有效手段,其发挥作用的方式也比较直接,实现自定义criteria也并不复杂,只需要确保该类的调用方法返回值是bool值,并覆盖全部情况即可。

Logits Processor是用户控制生成的另一个有效工具,在接下来的博客中,还将介绍自定义logits processor是如何使用的,欢迎感兴趣的同学继续关注。文章来源地址https://www.toymoban.com/news/detail-625482.html

到了这里,关于以Llama-2为例,在生成模型中使用自定义StoppingCriteria的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包