1.安装nemo
pip install -U nemo_toolkit[all] ASR-metrics
2.下载ASR预训练模型到本地(建议使用huggleface,比nvidia官网快很多)
3.从本地创建ASR模型
asr_model = nemo_asr.models.EncDecCTCModel.restore_from("stt_zh_quartznet15x5.nemo")
3.定义train_mainfest,包含语音文件路径、时长和语音文本的json文件
{"audio_filepath": "test.wav", "duration": 8.69, "text": "诶前天跟我说昨天跟我说十二期利率是多少工号幺九零八二六十二期的话零点八一万的话分十二期利息八十嘛"}
4.读取模型的yaml配置
# 使用YAML读取quartznet模型配置文件
try:
from ruamel.yaml import YAML
except ModuleNotFoundError:
from ruamel_yaml import YAML
config_path ="/NeMo/examples/asr/conf/quartznet/quartznet_15x5_zh.yaml"
yaml = YAML(typ='safe')
with open(config_path) as f:
params = yaml.load(f)
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])
5.设置训练及验证manifest
train_manifest = "train_manifest.json"
val_manifest = "train_manifest.json"
params['model']['train_ds']['manifest_filepath']=train_manifest
params['model']['validation_ds']['manifest_filepath']=val_manifest
print(params['model']['train_ds']['manifest_filepath'])
print(params['model']['validation_ds']['manifest_filepath'])
asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])
6.使用pytorch_lightning训练
import pytorch_lightning as pl
trainer = pl.Trainer(accelerator='gpu', devices=1,max_epochs=10)
trainer.fit(asr_model)#调用‘fit’方法开始训练
7.保存训练好的模型
asr_model.save_to('my_stt_zh_quartznet15x5.nemo')
8.看看训练后的效果
my_asr_model = nemo_asr.models.EncDecCTCModel.restore_from("my_stt_zh_quartznet15x5.nemo")
queries=my_asr_model.transcribe(['test1.wav'])
print(queries)
#['诶前天跟我说的昨天跟我说十二期利率是多少工号幺九零八二六零十二期的话零点八一万的话分十二期利息八十嘛']
9.计算字错率
from ASR_metrics import utils as metrics
s1 = "诶前天跟我说昨天跟我说十二期利率是多少工号幺九零八二六十二期的话零点八一万的话分十二期利息八十嘛"#指定正确答案
s2 = " ".join(queries)#识别结果
print("字错率:{}".format(metrics.calculate_cer(s1,s2)))#计算字错率cer
print("准确率:{}".format(1-metrics.calculate_cer(s1,s2)))#计算准确率accuracy
#字错率:0.041666666666666664
#准确率:0.9583333333333334
10.增加标点符号输出
from zhpr.predict import DocumentDataset,merge_stride,decode_pred
from transformers import AutoModelForTokenClassification,AutoTokenizer
from torch.utils.data import DataLoader
def predict_step(batch,model,tokenizer):
batch_out = []
batch_input_ids = batch
encodings = {'input_ids': batch_input_ids}
output = model(**encodings)
predicted_token_class_id_batch = output['logits'].argmax(-1)
for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
out=[]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# compute the pad start in input_ids
# and also truncate the predict
# print(tokenizer.decode(batch_input_ids))
input_ids = input_ids.tolist()
try:
input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
except:
input_id_pad_start = len(input_ids)
input_ids = input_ids[:input_id_pad_start]
tokens = tokens[:input_id_pad_start]
# predicted_token_class_ids
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]
for token,ner in zip(tokens,predicted_tokens_classes):
out.append((token,ner))
batch_out.append(out)
return batch_out
if __name__ == "__main__":
window_size = 256
step = 200
text = queries[0]
dataset = DocumentDataset(text,window_size=window_size,step=step)
dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=5)
model_name = 'zh-wiki-punctuation-restore'
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)文章来源:https://www.toymoban.com/news/detail-644162.html
model_pred_out = []
for batch in dataloader:
batch_out = predict_step(batch,model,tokenizer)
for out in batch_out:
model_pred_out.append(out)
merge_pred_result = merge_stride(model_pred_out,step)
merge_pred_result_deocde = decode_pred(merge_pred_result)
merge_pred_result_deocde = ''.join(merge_pred_result_deocde)
print(merge_pred_result_deocde)
#诶前天跟我说的。昨天跟我说十二期利率是多少。工号幺九零八二六零十二期的话,零点八一万的话,分十二期利息八十嘛。文章来源地址https://www.toymoban.com/news/detail-644162.html
到了这里,关于NeMo中文/英文ASR模型微调训练实践的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!