1. pth2onnx
pth模型转onnx模型采用torch框架现有转换函数:torch.onnx.export(),函数使用过程中注意要设置输入输出name,以及将batch维度设置为动态,以便后续onnx转tflite。
具体转换代码如下:
def pth2onnx(pth_path, onnx_path):
//加载模型
model = MyModel() //实例化modle对象
model.load_state_dict(torch.load(pth_path))
//模型转换
dummy_input = torch.randn(1, 3, 256, 256).type(torch.FloatTensor) #.to(self.device_num)
dummy_input = dummy_input.to(next(model.parameters()).device)
input_names = ["inputs"]
output_names = ["outputs"]
dynamic_axes = {"inputs":{0: "batch_size"}}
torch.onnx.export(model,
dummy_input,
onnx_path, //设置onnx模型输出路径,例如:c:/xxx.onnx
export_params=True,
verbose=False,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes
)
print("-----pth to onnx trans successed.")
2. onnx2tflite
网上其他转换方案,虽然可以成功转换,但是存在一个问题,由于pth模型默认的数据存储是[b, c, h, w],转换之后的tflite也是[b, c, h, w],但是tflite的常规数据格式是[b, h, w, c],这个问题曾经让我排查了1天。
最后借助开源项目完成转换,亲测有效,先上连接,再次感谢这位大佬。文章来源:https://www.toymoban.com/news/detail-544599.html
转换代码如下:文章来源地址https://www.toymoban.com/news/detail-544599.html
import sys
sys.path.append("./onnx2tflite_MPolaris")
from converter import onnx_converter
def onnx2tflite(onnx_path, tflite_path):
onnx_converter(
onnx_model_path = onnx_path,
need_simplify = False,
output_path = "./result/",
target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
weight_quant = False,
int8_model = False,
int8_mean = None,
int8_std = None,
image_root = None
)
到了这里,关于pth转onnx,onnx转tflite,亲测有效的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!