ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例

这篇具有很好参考价值的文章主要介绍了ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

java yolo8,java,YOLO,开发语言

ONNX(Open Neural Network Exchange)是一个开源项目,旨在建立一个开放的标准,使深度学习模型可以在不同的软件平台和工具之间轻松移动和重用

ONNX模型可以用于各种应用场景,例如机器翻译、图像识别、语音识别、自然语言处理等。

由于ONNX模型的互操作性,开发人员可以使用不同的框架来训练,模型可以更容易地在不同的框架之间转换,例如从PyTorch转换到TensorFlow,或从TensorFlow转换到MXNet等。然后将其部署到不同的环境中,例如云端、边缘设备或移动设备等

ONNX还提供了一组工具和库,帮助开发人员更容易地创建、训练和部署深度学习模型。

ONNX模型是由多个节点(node)组成的图(graph),每个节点代表一个操作或一个张量(tensor)。ONNX模型还包含了一些元数据,例如模型的版本、输入和输出张量的名称等。

onnx官网

ONNX | Home

pytorch官方使用onnx模型格式举例

(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.2.0+cu121 documentation

TensorFlow官方使用onnx模型格式举例

https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb

Netron可视化模型结构工具

Netron

你可通过该工具看到onnx具体的模型结构,点击每层都能看到其对应的内容信息

java yolo8,java,YOLO,开发语言

onnxRuntime  | 提供各种编程语言推导onnx格式模型的接口

ONNX Runtime | Home

比如我需要在java环境下调用一个onnx模型,我可以先导入onnxRuntime的依赖,对数据预处理后,调用onnx格式模型正向传播导出数据,然后将数据处理成我要的数据。 

onnxRuntime也提供了其他编程语言的接口,如C++、C#、JavaScript、python等等。

实际案例举例

python部分

python下利用ultralytics从网上下载并导出yolov8的onnx格式模型,用java调用onnxruntim接口,正向传播推导模型数据。

pip install ultralytics
from ultralytics import YOLO

# 加载模型
model = YOLO('yolov8n.pt')  # 加载官方模型
#加载自定义训练的模型
#model = YOLO('F:\\File\\AI\\Object\\yolov8_test\\runs\\detect\\train\\weights\\best.pt')  

# 导出模型
model.export(format='onnx')

java部分

前提安装java的opencv(Get Started - OpenCV),我这安装的是opencv480

maven依赖

<dependencies>
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.12.0</version>
        </dependency>


        <!-- 加载lib目录下的opencv包 -->
        <dependency>
            <groupId>org.opencv</groupId>
            <artifactId>opencv</artifactId>
            <version>4.8.0</version>
            <scope>system</scope>
            <!--通过路径加载OpenCV480的jar包-->
            <systemPath>${basedir}/lib/opencv-480.jar</systemPath>
        </dependency>


        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>2.0.32</version>
        </dependency>
</dependencies>

java完整代码

package com.sky;

//天宇 2023/12/21 20:23:13


import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONObject;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import java.nio.FloatBuffer;
import java.text.DecimalFormat;
import java.util.*;
import java.util.List;

/**
 * onnx学习笔记  GTianyu
 */
public class onnxLoadTest01 {
    public static OrtEnvironment env;
    public static OrtSession session;
    public static JSONObject names;
    public static long count;
    public static long channels;
    public static long netHeight;
    public static long netWidth;
    public static  float srcw;
    public static  float srch;
    public static float confThreshold = 0.25f;
    public static float nmsThreshold = 0.5f;
    static Mat src;

    public static void load(String path) {
         String weight = path;
        try{
            env = OrtEnvironment.getEnvironment();
            session = env.createSession(weight, new OrtSession.SessionOptions());
            OnnxModelMetadata metadata = session.getMetadata();
            Map<String, NodeInfo> infoMap = session.getInputInfo();
            TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
            String nameClass = metadata.getCustomMetadata().get("names");
            System.out.println("getProducerName="+metadata.getProducerName());
            System.out.println("getGraphName="+metadata.getGraphName());
            System.out.println("getDescription="+metadata.getDescription());
            System.out.println("getDomain="+metadata.getDomain());
            System.out.println("getVersion="+metadata.getVersion());
            System.out.println("getCustomMetadata="+metadata.getCustomMetadata());
            System.out.println("getInputInfo="+infoMap);
            System.out.println("nodeInfo="+nodeInfo);
            System.out.println(nameClass);
            names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
            count = nodeInfo.getShape()[0];//1 模型每次处理一张图片
            channels = nodeInfo.getShape()[1];//3 模型通道数
            netHeight = nodeInfo.getShape()[2];//640 模型高
            netWidth = nodeInfo.getShape()[3];//640 模型宽
            System.out.println(names.get(0));
            // 加载opencc需要的动态库
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
    }


    public static Map<Object, Object> predict(String imgPath) throws Exception {
        src=Imgcodecs.imread(imgPath);
        return predictor();
    }

    public static Map<Object, Object> predict(Mat mat) throws Exception {
        src=mat;
        return predictor();
    }


    public static OnnxTensor transferTensor(Mat dst){
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
        dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
        float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
        dst.get(0, 0, whc);
        float[] chw = whc2cwh(whc);
        OnnxTensor tensor = null;
        try {
            tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight});
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
        return tensor;
    }


    //宽 高 类型 to 类 宽 高
    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch < 3; ++ch) {
            for (int i = ch; i < src.length; i += 3) {
                chw[j] = src[i];
                j++;
            }
        }
        return chw;
    }

    public static Map<Object, Object> predictor() throws Exception{
        srcw = src.width();
        srch = src.height();
        System.out.println("width:"+srcw+" hight:"+srch);
        System.out.println("resize: \n width:"+netWidth+" hight:"+netHeight);
        float scaleW=srcw/netWidth;
        float scaleH=srch/netHeight;
        // resize
        Mat dst=new Mat();
        Imgproc.resize(src, dst, new Size(netWidth, netHeight));
        // 转换成Tensor数据格式
        OnnxTensor tensor = transferTensor(dst);
        OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
        System.out.println("res Data: "+result.get(0));
        OnnxTensor res = (OnnxTensor)result.get(0);
        float[][][] dataRes = (float[][][])res.getValue();
        float[][] data = dataRes[0];

        // 将矩阵转置
        // 先将xywh部分转置
        float rawData[][]=new float[data[0].length][6];
        System.out.println(data.length-1);
        for(int i=0;i<4;i++){
            for(int j=0;j<data[0].length;j++){
                rawData[j][i]=data[i][j];
            }
        }
        // 保存每个检查框置信值最高的类型置信值和该类型下标
        for(int i=0;i<data[0].length;i++){
            for(int j=4;j<data.length;j++){
                if(rawData[i][4]<data[j][i]){
                    rawData[i][4]=data[j][i];           //置信值
                    rawData[i][5]=j-4;                  //类型编号
                }
            }
        }
        List<ArrayList<Float>> boxes=new LinkedList<ArrayList<Float>>();
        ArrayList<Float> box=null;
        // 置信值过滤,xywh转xyxy
        for(float[] d:rawData){
            // 置信值过滤
            if(d[4]>confThreshold){
                // xywh(xy为中心点)转xyxy
                d[0]=d[0]-d[2]/2;
                d[1]=d[1]-d[3]/2;
                d[2]=d[0]+d[2];
                d[3]=d[1]+d[3];
                // 置信值符合的进行插入法排序保存
                box=new ArrayList<Float>();
                for(float num:d) {
                    box.add(num);
                }
                if(boxes.size()==0){
                    boxes.add(box);
                }else {
                    int i;
                    for(i=0;i<boxes.size();i++){
                        if(box.get(4)>boxes.get(i).get(4)){
                            boxes.add(i,box);
                            break;
                        }
                    }
                    // 插入到最后
                    if(i==boxes.size()){
                        boxes.add(box);
                    }
                }
            }
        }

        // 每个框分别有x1、x1、x2、y2、conf、class
        //System.out.println(boxes);

        // 非极大值抑制
        int[] indexs=new int[boxes.size()];
        Arrays.fill(indexs,1);                       //用于标记1保留,0删除
        for(int cur=0;cur<boxes.size();cur++){
            if(indexs[cur]==0){
                continue;
            }
            ArrayList<Float> curMaxConf=boxes.get(cur);   //当前框代表该类置信值最大的框
            for(int i=cur+1;i<boxes.size();i++){
                if(indexs[i]==0){
                    continue;
                }
                float classIndex=boxes.get(i).get(5);
                // 两个检测框都检测到同一类数据,通过iou来判断是否检测到同一目标,这就是非极大值抑制
                if(classIndex==curMaxConf.get(5)){
                    float x1=curMaxConf.get(0);
                    float y1=curMaxConf.get(1);
                    float x2=curMaxConf.get(2);
                    float y2=curMaxConf.get(3);
                    float x3=boxes.get(i).get(0);
                    float y3=boxes.get(i).get(1);
                    float x4=boxes.get(i).get(2);
                    float y4=boxes.get(i).get(3);
                    //将几种不相交的情况排除。提示:x1y1、x2y2、x3y3、x4y4对应两框的左上角和右下角
                    if(x1>x4||x2<x3||y1>y4||y2<y3){
                        continue;
                    }
                    // 两个矩形的交集面积
                    float intersectionWidth =Math.max(x1, x3) - Math.min(x2, x4);
                    float intersectionHeight=Math.max(y1, y3) - Math.min(y2, y4);
                    float intersectionArea =Math.max(0,intersectionWidth * intersectionHeight);
                    // 两个矩形的并集面积
                    float unionArea = (x2-x1)*(y2-y1)+(x4-x3)*(y4-y3)-intersectionArea;
                    // 计算IoU
                    float iou = intersectionArea / unionArea;
                    // 对交并比超过阈值的标记
                    indexs[i]=iou>nmsThreshold?0:1;
                    //System.out.println(cur+" "+i+" class"+curMaxConf.get(5)+" "+classIndex+"  u:"+unionArea+" i:"+intersectionArea+"  iou:"+ iou);
                }
            }
        }


        List<ArrayList<Float>> resBoxes=new LinkedList<ArrayList<Float>>();
        for(int index=0;index<indexs.length;index++){
            if(indexs[index]==1) {
                resBoxes.add(boxes.get(index));
            }
        }
        boxes=resBoxes;

        System.out.println("boxes.size : "+boxes.size());
        for(ArrayList<Float> box1:boxes){
            box1.set(0,box1.get(0)*scaleW);
            box1.set(1,box1.get(1)*scaleH);
            box1.set(2,box1.get(2)*scaleW);
            box1.set(3,box1.get(3)*scaleH);
        }
        System.out.println("boxes: "+boxes);
        //detect(boxes);
        Map<Object,Object> map=new HashMap<Object,Object>();
        map.put("boxes",boxes);
        map.put("classNames",names);
        return map;
    }


    public static Mat showDetect(Map<Object,Object> map){
        List<ArrayList<Float>> boxes=(List<ArrayList<Float>>)map.get("boxes");
        JSONObject names=(JSONObject) map.get("classNames");
        Imgproc.resize(src,src,new Size(srcw,srch));
        // 画框,加数据
        for(ArrayList<Float> box:boxes){
            float x1=box.get(0);
            float y1=box.get(1);
            float x2=box.get(2);
            float y2=box.get(3);
            float config=box.get(4);
            String className=(String)names.get((int)box.get(5).intValue());;
            Point point1=new Point(x1,y1);
            Point point2=new Point(x2,y2);
            Imgproc.rectangle(src,point1,point2,new Scalar(0,0,255),2);
            String conf=new DecimalFormat("#.###").format(config);
            Imgproc.putText(src,className+" "+conf,new Point(x1,y1-5),0,0.5,new Scalar(255,0,0),1);
        }
        HighGui.imshow("image",src);
        HighGui.waitKey();
        return src;
    }


    public static void main(String[] args) throws Exception {
        String modelPath="C:\\Users\\tianyu\\IdeaProjects\\test1\\src\\main\\java\\com\\sky\\best.onnx";
        String path="C:\\Users\\tianyu\\IdeaProjects\\test1\\src\\main\\resources\\img\\img.png";
        onnxLoadTest01.load(modelPath);
        Map<Object,Object> map=onnxLoadTest01.predict(path);
        showDetect(map);
    }
}

效果:

java yolo8,java,YOLO,开发语言
java yolo8,java,YOLO,开发语言

参考文献:

使用 java-onnx 部署 yolovx 目标检测_java onnx-CSDN博客文章来源地址https://www.toymoban.com/news/detail-815211.html

到了这里,关于ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • ONNX:C++通过onnxruntime使用.onnx模型进行前向计算【下载的onnxruntime是编译好的库文件,可直接使用】

    微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准–ONNX,旨在将所有模型格式统一为一致,更方便地实现模型部署。现在大多数的深度学习框架都支持ONNX模型转出并提供相应的导出接口。 ONNXRuntime(Open Neural Network Exchange)是微软推出的一款针对ONNX模型格式

    2024年02月15日
    浏览(49)
  • C++使用onnxruntime/opencv对onnx模型进行推理(附代码)

    结果: current image classification : French bulldog, possible : 16.17 对两张图片同时进行推理 current image classification : French bulldog, possible : 16.17 current image class ification : hare, possible : 8.47 https://download.csdn.net/download/qq_44747572/87810859 https://blog.csdn.net/qq_44747572/article/details/131631153

    2024年02月05日
    浏览(51)
  • Android+OnnxRuntime+Opencv+Onnx模型操作图片擦除多余内容

    今年来AI的发展非常迅速,在工业、医疗等等行业逐渐出现相应的解决方案,AI也逐渐成为各行业基础设施建设重要的一环,未来发展的大趋势,不过这也需要一个漫长的过程,需要很多技术型人才加入其中,除了工业设施的基础建设,在娱乐方向也有很多有趣的能力,不如图

    2024年04月13日
    浏览(43)
  • 【深度学习】【Opencv】【CPU】Python/C++调用onnx模型【基础】

    提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 OpenCV是一个基于BSD许可发行的跨平台计算机视觉和机器学习软件库(开源),可以运行在Linux、Windows、Android和Mac OS操作系统上。可以将pytorch中训练好的模型使用ONNX导出,再使用opencv中的dnn模块直接进行

    2024年02月04日
    浏览(46)
  • 【深度学习】【Opencv】【GPU】python/C++调用onnx模型【基础】

    提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 OpenCV是一个基于BSD许可发行的跨平台计算机视觉和机器学习软件库(开源),可以运行在Linux、Windows、Android和Mac OS操作系统上。可以将pytorch中训练好的模型使用ONNX导出,再使用opencv中的dnn模块直接进行

    2024年02月04日
    浏览(56)
  • TRT4-trt-integrate - 3 使用onnxruntime进行onnx的模型推理过程

    onnx是microsoft开发的一个中间格式,而onnxruntime简称ort是microsoft为onnx开发的推理引擎。 允许使用onnx作为输入进行直接推理得到结果。 建立一个InferenceSession,塞进去的是onnx的路径,实际运算的后端选用的是CPU 也可以选用cuda等等 之后就是预处理 session.run就是运行的inference过程

    2024年02月15日
    浏览(43)
  • 深度学习模型部署综述(ONNX/NCNN/OpenVINO/TensorRT)

    点击下方 卡片 ,关注“ 自动驾驶之心 ”公众号 ADAS巨卷干货,即可获取 今天自动驾驶之心很荣幸邀请到 逻辑牛 分享深度学习部署的入门介绍,带大家盘一盘ONNX、NCNN、OpenVINO等框架的使用场景、框架特点及代码示例。如果您有相关工作需要分享,请在文末联系我们! 点击

    2024年02月08日
    浏览(48)
  • Pytorch图像分类模型转ONNX(同济子豪兄学习笔记)

    安装配置环境 代码运行云GPU平台:公众号 人工智能小技巧 回复 gpu 同济子豪兄 2022-8-22 2023-4-28 2023-5-8 安装 Pytorch 安装 ONNX 安装推理引擎 ONNX Runtime 安装其它第三方工具包 验证安装配置成功 Pytorch图像分类模型转ONNX-ImageNet1000类 把Pytorch预训练ImageNet图像分类模型,导出为ONNX格

    2024年02月09日
    浏览(49)
  • Torch 模型 onnx 文件的导出和调用

    Open Neural Network Exchange (ONNX,开放神经网络交换) 格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移 Torch 所定义的模型为动态图,其前向传播是由类方法定义和实现的 但是 Python 代码的效率是比较底下的,试想把动态图转化为静态图,模型的推理速

    2024年02月02日
    浏览(35)
  • AI模型部署 | onnxruntime部署YOLOv8分割模型详细教程

    本文首发于公众号【DeepDriving】,欢迎关注。 0. 引言 我之前写的文章《基于YOLOv8分割模型实现垃圾识别》介绍了如何使用 YOLOv8 分割模型来实现垃圾识别,主要是介绍如何用自定义的数据集来训练 YOLOv8 分割模型。那么训练好的模型该如何部署呢? YOLOv8 分割模型相比检测模型

    2024年04月24日
    浏览(38)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包