使用 java-onnx 部署 Meta-ai Segment anything 分割一切

这篇具有很好参考价值的文章主要介绍了使用 java-onnx 部署 Meta-ai Segment anything 分割一切。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

java onnx,深度学习/机器学习/强化学习/算法,人工智能,opencv,meta-ai,sam,分割一切java onnx,深度学习/机器学习/强化学习/算法,人工智能,opencv,meta-ai,sam,分割一切

java onnx,深度学习/机器学习/强化学习/算法,人工智能,opencv,meta-ai,sam,分割一切

 

近日,Meta AI在官网发布了基础模型 Segment Anything Model(SAM)并开源,其本质是用GPT的方式(基于Transform 模型架构)让计算机具备理解了图像里面的一个个“对象”的通用能力。SAM模型建立了一个可以接受文本提示、基于海量数据(603138)训练而获得泛化能力的图像分割大模型。图像分割是计算机视觉中的一项重要任务,有助于识别和确认图像中的不同物体,把它们从背景中分离出来,这在自动驾驶(检测其他汽车、行人和障碍物)、医学成像(提取特定结构或潜在病灶)等应用中特别重要。

下面是 java 使用 onnx 进行推理的分割代码,提示抠图点进行分割,目前还没有文本交互式提示的部署按理。代码如下:文章来源地址https://www.toymoban.com/news/detail-752540.html

package tool.deeplearning;

import ai.onnxruntime.*;
import org.opencv.core.*;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import javax.imageio.ImageIO;
import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.nio.FloatBuffer;
import java.util.*;


/**
*   @desc : meta-ai sam , 使用抠图点进行分割
*   @auth : tyf
*   @date : 2023-04-25  09:34:40
*/
public class metaai_sam_test {


    // 模型1
    public static OrtEnvironment env1;
    public static OrtSession session1;

    // 模型2
    public static OrtEnvironment env2;
    public static OrtSession session2;


    // 模型1
    public static void init1(String weight) throws Exception{
        // opencv 库
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        env1 = OrtEnvironment.getEnvironment();
        session1 = env1.createSession(weight, new OrtSession.SessionOptions());

        // 打印模型信息,获取输入输出的shape以及类型:
        System.out.println("---------模型1输入-----------");
        session1.getInputInfo().entrySet().stream().forEach(n->{
            String inputName = n.getKey();
            NodeInfo inputInfo = n.getValue();
            long[] shape = ((TensorInfo)inputInfo.getInfo()).getShape();
            String javaType = ((TensorInfo)inputInfo.getInfo()).type.toString();
            System.out.println(inputName+" -> "+ Arrays.toString(shape)+" -> "+javaType);
        });
        System.out.println("---------模型1输出-----------");
        session1.getOutputInfo().entrySet().stream().forEach(n->{
            String outputName = n.getKey();
            NodeInfo outputInfo = n.getValue();
            long[] shape = ((TensorInfo)outputInfo.getInfo()).getShape();
            String javaType = ((TensorInfo)outputInfo.getInfo()).type.toString();
            System.out.println(outputName+" -> "+Arrays.toString(shape)+" -> "+javaType);
        });
//        session1.getMetadata().getCustomMetadata().entrySet().forEach(n->{
//            System.out.println("元数据:"+n.getKey()+","+n.getValue());
//        });

    }

    // 模型2
    public static void init2(String weight) throws Exception{
        // opencv 库
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        env2 = OrtEnvironment.getEnvironment();
        session2 = env2.createSession(weight, new OrtSession.SessionOptions());

        // 打印模型信息,获取输入输出的shape以及类型:
        System.out.println("---------模型2输入-----------");
        session2.getInputInfo().entrySet().stream().forEach(n->{
            String inputName = n.getKey();
            NodeInfo inputInfo = n.getValue();
            long[] shape = ((TensorInfo)inputInfo.getInfo()).getShape();
            String javaType = ((TensorInfo)inputInfo.getInfo()).type.toString();
            System.out.println(inputName+" -> "+ Arrays.toString(shape)+" -> "+javaType);
        });
        System.out.println("---------模型2输出-----------");
        session2.getOutputInfo().entrySet().stream().forEach(n->{
            String outputName = n.getKey();
            NodeInfo outputInfo = n.getValue();
            long[] shape = ((TensorInfo)outputInfo.getInfo()).getShape();
            String javaType = ((TensorInfo)outputInfo.getInfo()).type.toString();
            System.out.println(outputName+" -> "+Arrays.toString(shape)+" -> "+javaType);
        });
//        session2.getMetadata().getCustomMetadata().entrySet().forEach(n->{
//            System.out.println("元数据:"+n.getKey()+","+n.getValue());
//        });

    }


    public static class ImageObj{

        // 原始图片
        Mat src;
        Mat dst_3_1024_1024;
        BufferedImage image_3_1024_1024;
        float[][][] image_embeddings;
        ArrayList<float[]> points;
        float[][] info;
        public ImageObj(String image) {
            this.src = this.readImg(image);
            this.dst_3_1024_1024 = this.resizeWithoutPadding(src,1024,1024);
            this.image_3_1024_1024 = mat2BufferedImage(this.dst_3_1024_1024);
        }
        public void setPoints(ArrayList<float[]> points) {

            this.points = points;
        }

        public Mat readImg(String path){
            Mat img = Imgcodecs.imread(path);
            return img;
        }
        public Mat resizeWithoutPadding(Mat src,int inputWidth,int inputHeight){
            Mat resizedImage = new Mat();
            Size size = new Size(inputWidth, inputHeight);
            Imgproc.resize(src, resizedImage, size, 0, 0, Imgproc.INTER_AREA);
            return resizedImage;
        }
        public float[] chw2chw(float[][][] chw,int c,int h,int w){
            float[] res = new float[ c * h * w ];

            int index = 0;
            for(int i=0;i<c;i++){
                for(int j=0;j<h;j++){
                    for(int k=0;k<w;k++){
                        float d = chw[i][j][k];
                        res[index] = d;
                        index++;
                    }
                }
            }

            return res;
        }

        // 推理1
        public void infenence1() throws Exception{
            Imgproc.cvtColor(dst_3_1024_1024, dst_3_1024_1024, Imgproc.COLOR_BGR2RGB);
            dst_3_1024_1024.convertTo(dst_3_1024_1024, CvType.CV_32FC1);
            float[] whc = new float[ 3 * 1024 * 1024];
            dst_3_1024_1024.get(0, 0, whc);
            float[] chw = new float[whc.length];
            int j = 0;
            for (int ch = 0; ch < 3; ++ch) {
                for (int i = ch; i < whc.length; i += 3) {
                    chw[j] = whc[i];
                    j++;
                }
            }
            float mean = 0.0f;
            float std = 0.0f;
            for (int i = 0; i < chw.length; i++) {
                mean += chw[i];
            }
            mean /= chw.length;
            for (int i = 0; i < chw.length; i++) {
                std += Math.pow(chw[i] - mean, 2);
            }
            std = (float) Math.sqrt(std / chw.length);
            for (int i = 0; i < chw.length; i++) {
                chw[i] = (chw[i] - mean) / std;
            }
            OnnxTensor tensor = OnnxTensor.createTensor(env1, FloatBuffer.wrap(chw), new long[]{1,3,1024,1024});
            OrtSession.Result res = session1.run(Collections.singletonMap("x", tensor));
            float[][][] image_embeddings  = ((float[][][][])(res.get(0)).getValue())[0];
            this.image_embeddings = image_embeddings;

        }

        // 推理2
        public void infenence2() throws Exception{


            float[] chw = this.chw2chw(this.image_embeddings,256,64,64);
            OnnxTensor _image_embeddings = OnnxTensor.createTensor(env2, FloatBuffer.wrap(chw), new long[]{1,256, 64, 64});

            float[] pc = new float[points.size()*2];
            float[] pc_label = new float[points.size()];

            for(int i=0;i<points.size();i++){
                float[] xyl = points.get(i);
                pc[i*2] = xyl[0] * 1024f / Float.valueOf(src.width());
                pc[i*2+1] = xyl[1] * 1024f / Float.valueOf(src.height());
                pc_label[i] = xyl[2];
            }

            OnnxTensor _point_coords = OnnxTensor.createTensor(env2, FloatBuffer.wrap(pc), new long[]{1,points.size(),2});
            OnnxTensor _point_labels = OnnxTensor.createTensor(env2, FloatBuffer.wrap(pc_label), new long[]{1,points.size()});

            OnnxTensor _orig_im_size = OnnxTensor.createTensor(env2, FloatBuffer.wrap(new float[]{1024,1024}), new long[]{2});


            OnnxTensor _has_mask_input = OnnxTensor.createTensor(env2, FloatBuffer.wrap(new float[]{0}), new long[]{1});

            float[] ar_256_156 = new float[256*256];
            for(int i=0;i<256*156;i++){
                ar_256_156[i] = 0;
            }
            OnnxTensor _mask_input = OnnxTensor.createTensor(env2, FloatBuffer.wrap(ar_256_156), new long[]{1,1,256,256});

            // 封装参数
            Map<String,OnnxTensor> in = new HashMap<>();
            in.put("image_embeddings",_image_embeddings);
            in.put("point_coords", _point_coords);
            in.put("point_labels",_point_labels);
            in.put("has_mask_input",_has_mask_input);
            in.put("orig_im_size",_orig_im_size);
            in.put("mask_input",_mask_input);


            // 推理
            OrtSession.Result res = session2.run(in);

            float[][][] masks  = ((float[][][][])(res.get(0)).getValue())[0];
            float[][] iou_predictions  = ((float[][])(res.get(1)).getValue());
            float[][][][] low_res_masks  = ((float[][][][])(res.get(2)).getValue());


            int count = masks.length;

            for(int i=0;i < count;i++){
                float[][] info = masks[i];
                this.info = info;
                break;

            }

        }

        public BufferedImage mat2BufferedImage(Mat mat){
            BufferedImage bufferedImage = null;
            try {
                MatOfByte matOfByte = new MatOfByte();
                Imgcodecs.imencode(".jpg", mat, matOfByte);
                byte[] byteArray = matOfByte.toArray();
                ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteArray);
                bufferedImage = ImageIO.read(byteArrayInputStream);
            } catch (Exception e) {
                e.printStackTrace();
            }
            return bufferedImage;
        }

        public BufferedImage resize(BufferedImage img, int newWidth, int newHeight) {
            Image scaledImage = img.getScaledInstance(newWidth, newHeight, Image.SCALE_SMOOTH);
            BufferedImage scaledBufferedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_ARGB);
            Graphics2D g2d = scaledBufferedImage.createGraphics();
            g2d.drawImage(scaledImage, 0, 0, null);
            g2d.dispose();
            return scaledBufferedImage;
        }

        public void show(){


            int sub_w = info.length;
            int sub_h = info[0].length;

            for(int j=0;j<sub_w;j++){
                for(int k=0;k<sub_h;k++){
                    float da = info[j][k];
                    da = da + 1;
                    if(da>0.5){
                        // 修改颜色为绿色
                        image_3_1024_1024.setRGB(k,j, Color.GREEN.getRGB());
                    }
                }
            }

            BufferedImage showImg = resize(image_3_1024_1024,src.width(),src.height());

            // 弹窗显示
            JFrame frame = new JFrame();
            frame.setTitle("Meta-ai: SAM");
            JPanel content = new JPanel();
            content.add(new JLabel(new ImageIcon(showImg)));
            frame.add(content);
            frame.pack();
            frame.setVisible(true);


        }

    }


    public static void main(String[] args) throws Exception{


        init1(new File("").getCanonicalPath()+
                "\\src\\main\\resources\\deeplearning\\metaai_sam\\encoder-vit_b.quant.onnx");

        init2(new File("").getCanonicalPath()+
                "\\src\\main\\resources\\deeplearning\\metaai_sam\\decoder-vit_b.quant.onnx");


        // 图片
        ImageObj imageObj = new ImageObj(new File("").getCanonicalPath()+
                "\\src\\main\\resources\\deeplearning\\metaai_sam\\truck.jpg");


        // 提示,这里使用抠图点进行提示,可以设置多个提示点
        ArrayList<float[]> points = new ArrayList<>();
        points.add(new float[]{514,357,1});// 车窗户
        points.add(new float[]{555,377,1});// 车窗户
        points.add(new float[]{556,387,1});// 车窗户
        imageObj.setPoints(points);

        // 推理
        imageObj.infenence1();
        imageObj.infenence2();

        // 显示
        imageObj.show();


    }


}

到了这里,关于使用 java-onnx 部署 Meta-ai Segment anything 分割一切的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Meta AI最新出品,全能的分割模型SAM:掀桌子的Segment Anything,CV届的ChatGPT已经到来!

    本来不打算再发关于分割的相关内容的,但是13小时前,2023年4月5号,Meta AI在Arxiv网站发布了文章《Segment Anything》,并将SAM模型代码和数据开源。作为通用的分割网络,SAM或许将成为,甚至是已经成为了CV届的ChatGPT。简简单单的两个词Segment Anything,简单粗暴却不失优雅。 说

    2023年04月15日
    浏览(39)
  • Meta:segment anything

    介绍地址:https://ai.facebook.com/research/publications/segment-anything/ 演示地址:https://segment-anything.com/demo# 论文:https://scontent-akl1-1.xx.fbcdn.net/v/t39.2365-6/10000000_900554171201033_1602411987825904100_n.pdf?_nc_cat=100ccb=1-7_nc_sid=3c67a6_nc_ohc=Ald4OYhL6hgAX-FZV7S_nc_ht=scontent-akl1-1.xxoh=00_AfDDJRfDV85B3em0zMZvyCIp882H7Ha

    2024年02月05日
    浏览(34)
  • AI模型部署落地综述(ONNX/NCNN/TensorRT等)

    导读 费尽心血训练好的深度学习模型如何给别人展示?只在服务器上运行demo怎么吸引别人的目光?怎么才能让自己的成果落地?这篇文章带你进入模型部署的大门。 模型部署的步骤: 训练一个深度学习模型; 使用不同的推理框架对模型进行推理转换; 在应用平台运行转换

    2024年01月22日
    浏览(40)
  • Meta的分割一切模型SAM( Segment Anything )测试

    Meta不久前开源发布了一款图像处理模型,即分割一切模型:Segment Anything Model,简称 SAM,号称要从任意一张图片中分割万物,源码地址为: 打开后看到目录结构大概这样: 一般一个开源项目中都会有项目介绍和示例代码。本示例中的文件 README.md 即为项目概况介绍,主要说明

    2023年04月27日
    浏览(38)
  • CV不存在了?体验用Segment Anything Meta分割清明上河图

    在图像处理与计算机视觉领域, 图像分割(image segmentation) 是在像素级别将一个完整图像划分为若干具有特定语义 区域(region) 或 对象(object) 的过程。每个分割区域是一系列拥有相似特征——例如颜色、强度、纹理等的像素集合,因此图像分割也可视为 以图像属性为特征空间,

    2023年04月20日
    浏览(35)
  • 【多模态】14、Segment Anything | Meta 推出超强悍可分割一切的模型 SAM

    论文:Segment Anything 官网:https://segment-anything.com/ 代码:https://github.com/facebookresearch/segment-anything 出处:Meta、FAIR 时间:2023.04.05 贡献点: 首次提出基于提示的分割任务,并开源了可以分割一切的模型 SAM 开源了一个包含 1100 万张图像(约包含 10 亿 masks)的数据集 SA-1B,是目前

    2024年02月16日
    浏览(40)
  • 【多模态】12、Segment Anything | Meta 推出超强悍可分割一切的模型 SAM

    论文:Segment Anything 官网:https://segment-anything.com/ 代码:https://github.com/facebookresearch/segment-anything 出处:Meta、FAIR 时间:2023.04.05 贡献点: 首次提出基于提示的分割任务,并开源了可以分割一切的模型 SAM 开源了一个包含 1100 万张图像(约包含 10 亿 masks)的数据集 SA-1B,是目前

    2024年02月17日
    浏览(39)
  • 【AIGC】6、Segment Anything | Meta 推出超强悍可分割一切的模型 SAM

    论文:Segment Anything 官网:https://segment-anything.com/ 代码:https://github.com/facebookresearch/segment-anything 出处:Meta、FAIR 时间:2023.04.05 贡献点: 首次提出基于提示的分割任务,并开源了可以分割一切的模型 SAM 开源了一个包含 1100 万张图像(约包含 10 亿 masks)的数据集 SA-1B,是目前

    2023年04月23日
    浏览(49)
  • segment-anything本地部署使用

    前言 Segment Anything Model(SAM)是一种先进的图像分割模型,它基于Facebook AI在2020年发布的Foundation Model3,能够根据简单的输入提示(如点或框)准确地分割图像中的任何对象,并且无需额外训练就能适应不熟悉的对象和图像4。它利用了传统的计算机视觉技术和深度学习算法,

    2024年02月07日
    浏览(32)
  • ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例

    ONNX(Open Neural Network Exchange)是一个开源项目,旨在建立一个开放的标准,使深度学习模型 可以在不同的软件平台和工具之间轻松移动和重用 。 ONNX模型可以用于各种应用场景,例如机器翻译、图像识别、语音识别、自然语言处理等。 由于ONNX模型的互操作性,开发人员 可以

    2024年01月22日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包