使用推测解码 (Speculative Decoding) 使 Whisper 实现 2 倍的推理加速

这篇具有很好参考价值的文章主要介绍了使用推测解码 (Speculative Decoding) 使 Whisper 实现 2 倍的推理加速。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

Open AI 推出的 Whisper 是一个通用语音转录模型,在各种基准和音频条件下都取得了非常棒的结果。最新的 large-v3 模型登顶了 OpenASR 排行榜,被评为最佳的开源英语语音转录模型。该模型在 Common Voice 15 数据集的 58 种语言中也展现出了强大的多语言性能,在 42 种语言上的单词错误率 (WER) 低于 30%。

尽管转录准确度非常优秀,但推理速度非常缓慢。即使利用 flash attention 、半精度和 分块 等优化推理技术,1 小时长度的音频在 16GB T4 GPU 上也需要超过 6 分钟的转录时间。

在本文中,我们将演示如何运用推测解码将 Whisper 的推理时间缩减 2 倍,同时在数学上确保完全取得与原模型 相同的输出。因此,这种方法可以完美地替换现有的 Whisper 流水线,因为它可以在不降低准确性的情况下免费获得 2 倍的加速。想要看附带有更简洁解释的全部代码,请参阅配套的 Google Colab。

推测解码

推测解码由 Yaniv Leviathan 等人在 Fast Inference from Transformers via Speculative Decoding 中提出。其思想是,一个更快的 辅助模型 通常会生成和更大的 主模型 相同的 token。

首先,辅助模型会通过自回归生成 \(N\)候选 token 序列: \(\hat{\boldsymbol{y}}_{1:N}\)。在下图中,辅助模型生成了一个包含 5 个候选 token 的序列: The quick brown sock jumps

尽管这些候选 token 可以快速生成,但它们可能与主模型预测的 token 不同。因此,在第二步中,候选 token 被传入主模型以进行“验证”。主模型将候选 token 作为输入,并执行 单次前馈传播。主模型的输出是每个步骤中“正确”token 的序列 $ \boldsymbol{y}_{1:N}$。

在上图中,我们看到主模型预测的前三个 token 与辅助模型的 token 一致: <span style="color:green"> The quick brown 但是,辅助模型的第四个候选 token: “ <span style="color:red"> sock”与主模型的正确 token: “ <span style="color:green"> fox”不一致。

我们知道,所有候选 token 一直到第一个不匹配之前都是正确的 ( <span style="color:green"> The quick brown),因为这些与主模型的预测一致。但是,在第一个不匹配之后,候选 token 开始偏离主模型实际预测的 token。因此,我们可以用主模型的正确 token ( <span style="color:green"> fox) 替换第一个不正确的候选 token ( <span style="color:red"> sock),并放弃之后所有预测的 token,因为这些已经逐渐偏离主模型的预测。经过校正的序列 The quick brown fox 现在成为辅助模型的新输入:

然后,辅助模型再次通过自回归推理,生成一组新的 \(N\) 个候选 token,这些 token 再次通过主模型的单次前馈传播进行验证。

由于我们在生成的时候使用的快速的辅助模型进行自回归,并且缓慢的主模型仅用于验证前馈传播,解码过程将大大加快。此外,经过主模型前馈传播验证后可以确保与仅使用主模型时获得完全相同的输出。这使得推测解码可以完美地替换现有的 Whisper 流水线,因为我们可以确定会取得相同质量的输出。

为了最大限度地减少延迟,辅助模型应该比主模型快得多,同时尽可能频繁地预测相同的 token 分布。实际上,这两个属性之间需要权衡: 模型越快,其准确度越低。然而,由于所有预测 token 中的 70-80% 往往是“较易”的 token,此权衡倾向于选择一个更快的模型,而不是一个更准确的模型。因此,辅助模型应该至少比主模型快 3 倍 (越快越好),同时在示例中正确预测所有较“易”token。剩余的 20-30% 更“难”的 token 可以由更大的主模型进行验证。

选择辅助模型的唯一约束是它必须与主模型使用相同的词汇表。也就是说,辅助模型必须使用与主模型完全一对一相同的分词器。因此,如果我们想对诸如 large-v2 (多语言) 的 Whisper 多语言版本使用推测解码,我们需要选择诸如 tiny 的 Whisper 多语言版本作为辅助模型。而如果我们想对诸如 medium.en 的 Whisper 英文版本使用推测解码,我们需要选择诸如 tiny.en 的 Whisper 英文版本作为辅助模型。目前,large-v3 是唯一一个扩展了词汇量的 Whisper 检查点,因此与以前的 Whisper 检查点不兼容。

现在我们已经了解了推测解码背后的原理,我们准备实际实现它。在 🤗 Transformers 库中,推测解码被实现为“辅助生成 (Assisted Generation)”推理策略。欲了解更多实现细节,建议读者阅读 Joao Gante 关于 辅助生成 的精彩博文。

英文语音转录

基准实现

我们首先使用 Whisper large-v2 进行基准测试,以获得推理速度的基准数值。我们可以通过便捷的 AutoModelForSpeechSeq2SeqAutoProcessor 类加载主模型及其对应的处理器。我们将以 float16 精度加载模型,并通过传递 low_cpu_mem_usage=True 确保加载时间尽可能少。此外,我们要确保模型以 safetensors 格式加载,方法是传递 use_safetensors=True。最后,我们将传递参数 attn_implementation="sdpa" ,以通过 PyTorch 的 SDPA 注意力内核 进行 Flash 注意力加速。

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

让我们加载将用于基准测试的英语语音转录数据集。我们将加载 LibriSpeech ASR 中验证数据集的 clean 分组中的 73 个样本组成的小型数据集。这大约有 9MB 的数据,因此非常轻量且可以快速下载到设备上。

from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

对于基准测试,我们只想测量生成时间,所以让我们编写一个简短的辅助函数来测量此步骤运行的时间。下面的函数将同时返回解码的 token 和运行模型所需的时间:

import time

def generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

现在我们可以迭代语音数据集中的音频样本,并统计整体生成时间:

from tqdm import tqdm

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
  
    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

Output:

100%|██████████| 73/73 [01:37<00:00, 1.33s/it]
72.99542546272278

很好!我们看到转录 73 个样本花了 73 秒。让我们检查一下预测的 WER:

from evaluate import load

wer = load("wer")
print(wer.compute(predictions=predictions, references=references))

Output:

0.03507271171941831

我们的最终基准数值为 73 秒,WER 为 3.5%。

推测解码

现在让我们加载推测解码的辅助模型。在此示例中,我们将使用 Whisper 蒸馏后的版本 distil-large-v2。蒸馏模型只使用了 Whisper 中 32 个解码器层中的 2 个编码器。因此,它比 Whisper 快 6 倍,同时在分布测试集上的 WER 性能相比于蒸馏前仅下降了 1%。这使其成为理想的辅助模型,因为它在转录准确性和生成速度方面都非常优秀\({}^1\)


\({}^1\) 我们即将发布 Distil-Whisper 的改进版本,在 token 分布中具有更佳的对齐性,这将进一步提高推测解码性能。关注 Distil-Whisper 存储库 来追踪最新的更新信息。


由于 Distil-Whisper 使用与 Whisper 模型完全相同的编码器,我们可以在主模型和辅助模型之间共享编码器。然后,我们只需要从 Distil-Whisper 加载 2 层解码器作为“仅解码器”模型。我们可以通过便捷的 AutoModelForCausalLM 自动类实现这一点。在实践中,相比于仅使用主模型,这仅增加了 8%的 VRAM 占用量。

from transformers import AutoModelForCausalLM

assistant_model_id = "distil-whisper/distil-large-v2"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device)

我们可以为推测解码的基准测试定义一个新的函数。与前面的函数唯一的区别是,我们在对 .generate 的调用中传递辅助模型:

def assisted_generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

让我们使用 Distil-Whisper 作为 Whisper 的助手运行推测解码的基准测试:

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
  
    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

Outputs:

100%|██████████| 73/73 [00:38<00:00, 1.88it/s]
32.69683289527893

使用推测解码,推理时间仅为 33 秒,比之前快 2.2 倍!让我们验证一下 WER 是否相同:

print(wer.compute(predictions=predictions, references=references))

Outputs:

0.03507271171941831

太完美了!再次达到 3.5%的 WER,因为我们的输出与仅使用主模型的时候完全相同。

推测解码也可以与基础的 🤗 Transformers pipeline API 一起用于推理。下面,我们使用模型和处理器实例化管道,然后使用它来转录测试数据集中的第一个样本。这可以扩展为转录任意长度的音频样本,包括进行批处理:

from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=4,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

Outputs:

 Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

使用 Whisper 和 Distil-Whisper 运行推测解码的端到端代码示例可在 Distil-Whisper 模型卡 中找到。它将本文中涵盖的推理阶段组合成一个代码示例。

多语言语音转录

Distil-Whisper 是英语语音转录的最佳辅助模型,因为它与原始 Whisper 模型的 WER 误差率仅相差 1%,而对短长语音样本的推理速度提高了 6 倍。然而,官方的 Distil-Whisper 检查点仅支持英语,这意味着它们无法用于多语言语音转录。

要使用推测解码进行多语言语音转录,您可以使用 官方 Whisper 多语言检查点 之一,或者 Whisper 的微调版本。在撰写本文时,Hugging Face Hub 上已有超过 5000 个微调过的 Whisper 检查点,支持超过 100 种语言。这些为选择表现出色的辅助模型提供了极好的起点。在此示例中,我们将使用最小的官方多语言检查点 Whisper tiny。您可以使用任意一个您的语言中微调过的不同检查点!

让我们为新的辅助模型 Whisper tiny 加载权重。由于 Whisper tiny 的编码器与 large-v2 不同,这次我们将使用 AutoModelForSpeechSeq2Seq 类同时加载编码器和解码器:

assistant_model_id = "openai/whisper-tiny"

assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device);

我们的基准数据集,将从 VoxPopuli 数据集的荷兰语 (“nl”) 部分中加载 73 个样本:

dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

非常好!现在我们可以像前面一样重新运行我们的 Whisper large-v2 模型的基准测试。我们所做的唯一更改是在 generate 函数中传递语言和任务参数,以确保执行语音转录 (而不是语音翻译)。推测解码完全兼容语音转录和翻译任务。只需如下所示设置任务参数即可:

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
  
    output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

Outputs:

100%|██████████| 73/73 [02:05<00:00, 1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146

没错!我们的基准时间为 117 秒,WER 为 12.8%。让我们使用推测解码重新运行生成过程:

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)

    output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

Outputs:

100%|██████████| 73/73 [01:08<00:00, 1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146

Nice!我们达到了 12.8% 的 WER,但这次的推理时间只有 62 秒,表示速度提高了 1.9 倍。考虑到加载辅助模型的低开销和确保获得完全相同输出的数学证明,推测解码为现有的 Whisper 管道提供了完美的即插即用的替代方案。

高效推测解码的策略

在本最终部分,我们将介绍两种策略,以确保使用推测解码时获得可能最快的推理时间。

辅助模型

我们的目标是选择一个至少比主模型快 3 倍 并且 正确转录至少 70-80% 的预测 token (通常是示例中的“更简单”token) 的辅助模型。如果您想要转录某种特定语言,一种有效的策略是训练两个不同大小的 Whisper 模型,并将其中一个用作另一个的辅助模型:

  • 首先,微调 Whisper large-v3 以用作主模型
  • 其次,在同一数据集上蒸馏 Whisper large-v3 以用作快速的辅助模型

微调和蒸馏都可以提高主模型和辅助模型在您选择的语言上的 WER 性能,同时最大化 token 分布的对齐。有关 Whisper 微调的完整指南,请参阅 此处,有关蒸馏的指南请参阅 此处。

批次大小

值得注意的是,使用推测解码获得的最大速度提升来自批次大小为 1。对于批处理推测解码,批处理中的所有候选 token 必须与验证 token 相匹配,才能被接受。如果批处理中给定位置的 token 不一致,则所有在该位置之前的候选 token 将被丢弃。因此,推测解码更倾向于较小的批次大小。在实践中,我们发现推测解码可以提供速度提升,直到批次大小达到 4 为止。当批次大小超过 4 时,推测解码的推理速度比仅用主模型还要慢。有关完整结果,请参阅 Distil-Whisper 论文 的第 D.3 节。

结论

在本博文中,我们介绍了推测解码的推理策略,以及如何将其应用于语音转录的 Whisper 模型。我们展示了如何实现 2 倍的速度提升,同时数学上确保获得与仅使用原始模型相同的输出。我们鼓励您尝试将推测解码用作现有 Whisper 管道的即插即用替代方案,因为使用额外的辅助模型的开销很小,并且可以保证获得相同的转录结果。

致谢

本博客由 Sanchit Gandhi 撰写。非常感谢 Patrick von Platen 和 Pedro Cuenca 的建设性意见,以及 Joao Gante 在 🤗 Transformers 中实现辅助生成的贡献。


英文原文: https://hf.co/blog/whisper-speculative-decoding

作者: Sanchit Gandhi

译者: Hu Yaoqi (yaoqi)文章来源地址https://www.toymoban.com/news/detail-779418.html

到了这里,关于使用推测解码 (Speculative Decoding) 使 Whisper 实现 2 倍的推理加速的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 使用openai-whisper实现语音转文字

    FFmpeg是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。采用LGPL或GPL许可证。它提供了录制、转换以及流化音视频的完整解决方案。 在官网上选择windows版本 在GitHub上可以选择最新版本,选择 ffmpeg-master-latest-win64-gpl.zip ; 如果python程序出现“

    2024年02月20日
    浏览(48)
  • C#使用whisper.net实现语音识别(语音转文本)

    目录 介绍 效果 输出信息  项目 代码 下载  github地址:https://github.com/sandrohanea/whisper.net Whisper.net. Speech to text made simple using Whisper Models 模型下载地址:https://huggingface.co/sandrohanea/whisper.net/tree/main/classic whisper_init_from_file_no_state: loading model from \\\'ggml-small.bin\\\' whisper_model_load: loading

    2024年02月05日
    浏览(40)
  • <img> decoding属性

    标签的 decoding 属性用于告诉浏览器使用何种方式解析图像数据。 该属性可以取以下三个值: sync: 同步解码图像,保证与其他内容一起显示。 async: 异步解码图像,加快显示其他内容。 auto: 默认模式,表示不偏好解码模式。由浏览器决定哪种方式更适合用户。 此属性类似于在

    2024年02月11日
    浏览(31)
  • 如和使用matlab实现香农编码和解码

    在网上看了好多 , 都是对香农进行编码的案例 , 却没有 进行解码的操作 , 今天就来补齐这个欠缺 定义一个字符串类型的变量text,其值为’你好’。 调用函数shannonCoding对文本信息进行编码,并将编码、解码、平均码长和编码效率作为四个返回值保存到变量encoded, decoded, avgC

    2024年02月08日
    浏览(45)
  • 使用AutoDecoder自动解码器实现简单MNIST特征向量提取

    自动解码器(AD)是论文\\\"DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation\\\" 中使用的一种方法,与传统编码-解码结构不同,AD无编码器,仅有一个解码器。 解码器实现特征向量(隐向量)与图片之间的转换 。 在训练过程中 同时优化 特征向量与神经网络参数。如

    2024年02月02日
    浏览(48)
  • 神经网络实用工具(整活)系列---使用OpenAI的翻译模型whisper实现语音(中、日、英等等)转中字,从此生肉变熟肉---基础篇

    最近在做神经网络的研究,偶然间看到OpenAI开源出了一个多国语音转文字的模型,脑海里突然想到余大嘴在华为发布会发布实时语音翻译时满屏弹幕的“???”和“!!!”,于是决定做一个多国语音转简体中文字幕的软件来玩一玩。 想法是这样的:通过OpenAI最新发布的翻

    2024年02月09日
    浏览(43)
  • 神经网络实用工具(整活)系列---使用OpenAI的翻译模型whisper实现语音(中、日、英等等)转中字,从此生肉变熟肉---提高篇(附带打包好的程序)

    上一篇文章介绍了怎么用OpenAI的翻译模型whisper实现语音转中字的基本操作,在文章中也明确了该操作存在的三个问题: 处理速度慢。 存在幻听现象,字幕准确度不太理想。 要安装比较多的环境才能运行,对一般用户不太友好。 本篇文章将逐一介绍解决这些遗留问题的方法

    2024年02月12日
    浏览(34)
  • Whisper实现语音识别转文本

    #教程 主要参考开源免费离线语音识别神器whisper如何安装, OpenAI开源模型Whisper——音频转文字 Whisper是一个开源的 自动语音识别 系统,它在网络上收集了680,000小时的多语种和多任务监督数据进行训练,使得它可以将多种语言的音频转文字。 Whisper的好处是 开源免费、支持多

    2024年03月19日
    浏览(49)
  • whisper生成字幕python代码实现

    whisper模型

    2024年02月15日
    浏览(46)
  • 查询效率至少提高4倍的MySQL技巧

    SQL语句中IN包含的值不应过多 MySQL对于IN做了相应的优化,即将IN中的常量全部存储在一个数组里面,而且这个数组是排好序的。但是如果数值较多,产生的消耗也是比较大的。再例如:select id from t where num in(1,2,3) 对于连续的数值,能用between就不要用in了;再或者使用连接来替

    2024年04月26日
    浏览(36)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包