Custom BERT inference error in onnxruntime: TypeError: run(): incompatible function arguments
Inference code
# text embedding
toks = self.tokenizer([text])
if self.debug:
    print('toks', toks)
text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
Error message
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]
Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899, 102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None
Core error
TypeError: run(): incompatible function arguments. The following argument types are supported:
1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]
Solution
Check the arguments. The supported signature expects:
arg0: List[str]
arg1: Dict[str, object]
The arguments actually passed are:
output_names=['output'], input_feed=toks
arg0=['output']: the type is correct (a list of str).
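arg1=toks: this also looks fine at first glance, so print the type of each value in toks.
type(toks['input_ids']) prints <class 'torch.Tensor'>, but the input actually needs to be a <class 'numpy.ndarray'>.
One more thing worth checking: the "Invoked with" line in the traceback prints the values as array(...), and pybind11 only converts a real Python dict to Dict[str, object]. If the tokenizer returns a dict-like wrapper instead (for example Hugging Face's BatchEncoding, which subclasses UserDict rather than dict), run() fails with this same TypeError even when the values are already NumPy arrays. Copying the values into a new plain dict, as the modified code below does, covers both cases.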
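The check above can be done in one go with a small diagnostic helper. This is a minimal sketch: `describe_feed` is a hypothetical name and is not part of onnxruntime. It reports whether the feed object is a plain `dict` and gives the module-qualified type of each value, so a non-NumPy value or a dict-like wrapper stands out. The `UserDict` in the demo is only a stand-in for a tokenizer output.

```python
from collections import UserDict

import numpy as np


def describe_feed(feed):
    """Report what would be passed to InferenceSession.run() as input_feed.

    Hypothetical debugging helper; `feed` must provide an .items() method.
    """
    return {
        # Name of the container type, e.g. 'dict' or 'UserDict'
        "container": type(feed).__name__,
        # pybind11 needs a real dict here, not just a dict-like object
        "is_plain_dict": isinstance(feed, dict),
        # Module-qualified type of every value, e.g. 'numpy.ndarray'
        "values": {
            name: f"{type(value).__module__}.{type(value).__qualname__}"
            for name, value in feed.items()
        },
    }


if __name__ == "__main__":
    # Stand-in for a tokenizer output: a dict-like wrapper around a NumPy array
    fake_toks = UserDict({"input_ids": np.array([[101, 3899, 102]])})
    print(describe_feed(fake_toks))
    # {'container': 'UserDict', 'is_plain_dict': False, 'values': {'input_ids': 'numpy.ndarray'}}
```

In the article's code this would be `print(describe_feed(toks))`, placed right after the tokenizer call.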
Modified code
# text embedding
toks = self.tokenizer([text])
if self.debug:
    print('toks', toks)
# Copy the inputs into a plain dict of numpy.ndarray values, which is what onnxruntime expects
text_input = {}
text_input['input_ids'] = toks['input_ids'].numpy()
text_input['token_type_ids'] = toks['token_type_ids'].numpy()
text_input['attention_mask'] = toks['attention_mask'].numpy()
text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)
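The modified code hard-codes the three BERT input names. A slightly more general sketch converts whatever mapping the tokenizer returns into a plain dict of NumPy arrays. `to_ort_feed` is a hypothetical helper. The optional `input_names` filter is meant to be filled from `session.get_inputs()`, where each entry has a `.name`. The `dtype` argument exists because the integer type an exported model expects is model-specific: int64 is common for BERT exports, but that is an assumption, so check your own model.

```python
from collections.abc import Mapping

import numpy as np


def to_ort_feed(toks, input_names=None, dtype=None):
    """Copy tokenizer output into a plain dict of NumPy arrays for onnxruntime.

    toks:        any mapping (dict, UserDict, BatchEncoding, ...) whose values are
                 torch.Tensor, numpy.ndarray or nested lists.
    input_names: optional iterable of the model's input names; other keys are dropped.
    dtype:       optional target dtype (e.g. np.int64); None keeps the inferred dtype.
    """
    if not isinstance(toks, Mapping):
        raise TypeError(f"expected a mapping, got {type(toks).__name__}")
    wanted = None if input_names is None else set(input_names)
    feed = {}  # a real dict, which is what InferenceSession.run() accepts
    for name, value in toks.items():
        if wanted is not None and name not in wanted:
            continue  # the model does not declare this input
        if hasattr(value, "detach"):
            # torch.Tensor (duck-typed so torch need not be imported here):
            # drop autograd history and move to CPU before calling .numpy()
            value = value.detach().cpu().numpy()
        feed[name] = np.asarray(value, dtype=dtype)
    return feed


if __name__ == "__main__":
    # Hypothetical usage with the article's session and tokenizer output:
    #   names = [i.name for i in self.text_model_session.get_inputs()]
    #   text_input = to_ort_feed(toks, input_names=names, dtype=np.int64)
    #   text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)
    from collections import UserDict

    demo = UserDict({"input_ids": [[101, 3899, 102]], "token_type_ids": [[0, 0, 0]]})
    print(to_ort_feed(demo, input_names=["input_ids"], dtype=np.int64))
```

Filtering on `get_inputs()` keeps the feed from containing keys that the exported graph does not declare, which onnxruntime would otherwise reject as invalid feed names.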
Run the code again: it now works without errors.
Article source: https://www.toymoban.com/news/detail-817054.html