Fine-Tuning NeMo Chinese/English ASR Models in Practice

This article walks through fine-tuning a NeMo Chinese/English ASR model. I hope it is useful; if anything is wrong or incomplete, corrections and suggestions are welcome.

1.Install NeMo

pip install -U nemo_toolkit[all] ASR-metrics

2.Download a pretrained ASR model to a local path (Hugging Face mirrors are usually much faster than the NVIDIA NGC site)

3.Create the ASR model from the local file

import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")

4.Define train_manifest, a JSON-lines file giving the audio file path, duration, and transcript for each utterance

{"audio_filepath": "test.wav", "duration": 8.69, "text": "诶前天跟我说昨天跟我说十二期利率是多少工号幺九零八二六十二期的话零点八一万的话分十二期利息八十嘛"}

5.Read the model's YAML configuration

# Read the quartznet model configuration file with ruamel YAML
try:
    from ruamel.yaml import YAML
except ModuleNotFoundError:
    from ruamel_yaml import YAML
config_path = "/NeMo/examples/asr/conf/quartznet/quartznet_15x5_zh.yaml"

yaml = YAML(typ='safe')
with open(config_path) as f:
    params = yaml.load(f)
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

6.Set the training and validation manifests

train_manifest = "train_manifest.json"
# For demonstration the validation manifest reuses the training manifest;
# in practice it should point to a separate held-out set.
val_manifest = "train_manifest.json"

params['model']['train_ds']['manifest_filepath'] = train_manifest
params['model']['validation_ds']['manifest_filepath'] = val_manifest
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])

asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])
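Other dataset options can be overridden in the same params dict before calling the setup methods. A sketch of a small helper for this; note that the batch_size and shuffle keys are common in NeMo ASR configs but should be verified against the actual YAML file, and the stand-in config dict below is illustrative only:

```python
# Sketch: override common dataset options in the loaded config dict before
# calling setup_training_data()/setup_validation_data(). The keys used here
# (batch_size, shuffle) commonly appear in NeMo ASR configs; verify them
# against the actual YAML before relying on them.
def override_ds_options(params, batch_size=None, shuffle=None):
    for split in ("train_ds", "validation_ds"):
        ds = params["model"][split]
        if batch_size is not None:
            ds["batch_size"] = batch_size
        # Shuffling only makes sense for the training split.
        if shuffle is not None and split == "train_ds":
            ds["shuffle"] = shuffle
    return params

# Example with a minimal stand-in config dict:
params = {"model": {"train_ds": {"manifest_filepath": "train_manifest.json"},
                    "validation_ds": {"manifest_filepath": "train_manifest.json"}}}
params = override_ds_options(params, batch_size=16, shuffle=True)
```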

7.Train with pytorch_lightning

import pytorch_lightning as pl

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
trainer.fit(asr_model)  # calling 'fit' starts training

8.Save the fine-tuned model

asr_model.save_to('my_stt_zh_quartznet15x5.nemo')

9.Check the fine-tuned model's output

my_asr_model = nemo_asr.models.EncDecCTCModel.restore_from("my_stt_zh_quartznet15x5.nemo")
queries = my_asr_model.transcribe(['test1.wav'])
print(queries)

#['诶前天跟我说的昨天跟我说十二期利率是多少工号幺九零八二六零十二期的话零点八一万的话分十二期利息八十嘛']

10.Compute the character error rate (CER)

from ASR_metrics import utils as metrics
s1 = "诶前天跟我说昨天跟我说十二期利率是多少工号幺九零八二六十二期的话零点八一万的话分十二期利息八十嘛"  # reference transcript
s2 = " ".join(queries)  # recognition result
print("CER: {}".format(metrics.calculate_cer(s1, s2)))  # character error rate
print("Accuracy: {}".format(1 - metrics.calculate_cer(s1, s2)))  # 1 - CER

#CER: 0.041666666666666664

#Accuracy: 0.9583333333333334
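CER is the Levenshtein (edit) distance between the hypothesis and the reference, divided by the reference length. A minimal self-contained sketch of the computation, in case the ASR-metrics package is unavailable:

```python
# Minimal CER: character-level Levenshtein distance divided by reference length.
def levenshtein(ref, hyp):
    # prev[j] holds the edit distance between ref[:i-1] and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution (0 if equal)
        prev = cur
    return prev[-1]

def cer(reference, hypothesis):
    return levenshtein(reference, hypothesis) / len(reference)
```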

11.Add punctuation to the output

from zhpr.predict import DocumentDataset,merge_stride,decode_pred
from transformers import AutoModelForTokenClassification,AutoTokenizer
from torch.utils.data import DataLoader

def predict_step(batch, model, tokenizer):
    batch_out = []
    batch_input_ids = batch

    encodings = {'input_ids': batch_input_ids}
    output = model(**encodings)

    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        out = []
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # Find where padding starts in input_ids and truncate
        # both the tokens and the predictions there.
        input_ids = input_ids.tolist()
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:
            input_id_pad_start = len(input_ids)
        input_ids = input_ids[:input_id_pad_start]
        tokens = tokens[:input_id_pad_start]

        # Map predicted class ids to label names, truncated at the pad start.
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

        for token, ner in zip(tokens, predicted_tokens_classes):
            out.append((token, ner))
        batch_out.append(out)
    return batch_out

if __name__ == "__main__":
    window_size = 256
    step = 200
    text = queries[0]
    dataset = DocumentDataset(text,window_size=window_size,step=step)
    dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=5)

    model_name = 'zh-wiki-punctuation-restore'  # local path or Hugging Face Hub id of the punctuation model
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch,model,tokenizer)
        for out in batch_out:
            model_pred_out.append(out)
        
    merge_pred_result = merge_stride(model_pred_out, step)
    merge_pred_result_decoded = decode_pred(merge_pred_result)
    merge_pred_result_decoded = ''.join(merge_pred_result_decoded)
    print(merge_pred_result_decoded)
#诶前天跟我说的。昨天跟我说十二期利率是多少。工号幺九零八二六零十二期的话,零点八一万的话,分十二期利息八十嘛。
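The window_size/step pair above makes the dataset slide a fixed-size window over the token sequence with overlap, so long inputs fit the model while context is preserved at chunk boundaries. A toy sketch of the chunking and a stride-based merge; this mimics the role of DocumentDataset and merge_stride but is a simplified assumption, not the actual zhpr implementation:

```python
# Toy sketch of overlapped sliding-window chunking and a stride-based merge.
# This illustrates the role of window_size/step; it is NOT the zhpr code.
def chunk(seq, window_size, step):
    """Yield overlapping windows of length <= window_size, advancing by step."""
    for start in range(0, max(len(seq) - window_size + step, 1), step):
        yield seq[start:start + window_size]

def merge(chunks, step):
    """Undo the overlap: keep the first `step` items of each chunk,
    plus the full tail of the last chunk."""
    out = []
    for i, c in enumerate(chunks):
        out.extend(c if i == len(chunks) - 1 else c[:step])
    return out
```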

That concludes this walkthrough of fine-tuning NeMo Chinese/English ASR models.
