在使用pytorch模型训练完成之后,我们现在使用的比较多的一种方法是将pytorch模型转成onnx格式的模型中间文件,然后再根据使用的硬件来生成具体硬件使用的深度学习模型,比如TensorRT。
在从pytorch模型转为onnx时,我们可能会遇到部分算子无法转换的问题,本篇注意记录下解决方法。
在导出onnx时,如果出现报错的算子,可以先在下面的链接中查找onnx算子是否支持
https://github.com/onnx/onnx/blob/main/docs/Operators.md
pytorch中有,onnx中也有的算子
导出时使用的onnx op 版本低导致
这个就好解决了,把op库的版本提高就行,但是有可能提高了版本以后,又出现了原来支持的算子现在又不支持了,这个再说
pytorch中没有注册某个onnx算子
如果是这种情况,就按照下面的方式进行:
from torch.onnx import register_custom_op_symbolic
# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
# g: 就是graph,计算图
# 也就是说,在计算图中添加onnx算子
# 由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
# symblic的参数需要与Pytorch的asinh接口函数的参数对齐
# def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的
# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
另外一个写法
这个是类似于torch/onnx/symbolic_opset*.py中的写法
通过torch._internal中的registration来注册这个算子,让这个算子可以与底层C++实现的aten::asinh绑定
一般如果这么写的话,其实可以把这个算子直接加入到torch/onnx/symbolic_opset*.py中文章来源:https://www.toymoban.com/news/detail-787646.html
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
return g.op("Asinh", input)
pytorch中有,onnx中无的算子
继承torch.autograd.Function实现自定义算子
import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic
OperatorExportTypes = torch._C._onnx.OperatorExportTypes
class CustomOp(torch.autograd.Function):
@staticmethod
def symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:
return g.op("custom_domain::customOp2", x)
@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(x)
x = x.clamp(min=0)
return x / (1 + torch.exp(-x))
customOp = CustomOp.apply
然后再自己实现custom_domain::customOp2这个算子,如果用TensorRT,就需要自己实现一个插件。文章来源地址https://www.toymoban.com/news/detail-787646.html
到了这里,关于pytorch导出onnx时遇到不支持的算子怎么解决的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!