深度学习模型的Android部署方法

这篇具有很好参考价值的文章主要介绍了深度学习模型的Android部署方法。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

使用背景:

将python中训练的深度学习模型(图像分类、目标检测、语义分割等)部署到Android中使用。


Step1:下载并集成Pytorch Android库

1、下载Pytorch Android库。
在Pytorch的官网pytorch.org上找到最新版本的库。下载后,将其解压缩到项目的某个目录下。

2、配置项目gradle文件
配置项目的gradle文件,向项目添加Pytorch Android库的依赖项。打开项目的build.gradle文件,添加以下代码:

repositories {
    // 添加以下两行代码
    maven {
        url "https://oss.sonatype.org/content/repositories/snapshots/"
    }
}

dependencies {
    // 添加以下两行代码
    implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
    implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
}

3、将库文件添加到项目中
将Pytorch Android库的库文件添加到项目中。可以将其复制到“libs”文件夹中,并在项目的gradle文件中添加以下代码:

android {
    sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }
}

4、配置NDK版本
确保项目使用了支持Pytorch Android库的NDK版本。打开项目的local.properties文件,添加以下代码:

//NDK目录
ndk.dir=/path/to/your/ndk 

5、同步gradle文件
在Android Studio中,点击“Sync Project with Gradle Files”按钮,等待同步完成。
到这就集成了Pytorch Android库。可以在应用程序中使用Pytorch Android库提供的API加载模型文件并进行预测。


Step2:准备Pytorch导出的.pt模型文件

假如我们的深度学习模型输入图片大小尺寸为(640,640,3),并且已经在python中训练好了my_model.pth,那么我们需要将其转换为.pt格式:

import torch

# 加载PyTorch模型
model = torch.load("my_model.pth")

# 将PyTorch模型转换为TorchScript格式
traced_script_module = torch.jit.trace(model, torch.randn(1, 3, 640, 640))
traced_script_module.save("my_model.pt")

转换Pytorch模型为TorchScript格式时,需要确保使用的所有操作都是TorchScript支持的。否则,在转换模型时可能会出现错误。

Step3:导入Pytorch模型文件

要在Android Studio中创建新项目并将m_model.pt模型文件放入该项目中,包含以下步骤:
1、打开Android Studio,并选择“Create New Project”选项。
2、在“Create New Project”向导中,输入项目名称,选择项目保存位置,并选择“Phone and Tablet”作为您的应用程序目标设备。然后,单击“Next”继续。
3、选择“Empty Activity”模板,并单击“Next”继续。
4、在“Configure Activity”对话框中,输入活动名称并单击“Finish”完成项目创建过程。
5、在项目中创建一个名为“assets”的文件夹。要创建该文件夹,请右键单击项目根目录,选择“New” -> “Folder” -> “Assets Folder”。
6、将m_model.pt模型文件复制到“assets”文件夹中。要将文件复制到“assets”文件夹中,右键单击该文件夹,选择“Show in Explorer”或“Show in Finder”,然后将文件复制到打开的文件夹中。
7、在代码中加载模型文件使用以下代码示例加载模型文件:

AssetManager assetManager = getAssets();
String modelPath = "m_model.pt";
File modelFile = new File(getCacheDir(), modelPath);

try (InputStream inputStream = assetManager.open(modelPath);
     FileOutputStream outputStream = new FileOutputStream(modelFile)) {
    byte[] buffer = new byte[4 * 1024];
    int read;
    while ((read = inputStream.read(buffer)) != -1) {
        outputStream.write(buffer, 0, read);
    }
    outputStream.flush();
} catch (IOException e) {
    e.printStackTrace();
}

// 加载PyTorch模型
Module model = Module.load(modelFile.getAbsolutePath());

在这里需要注意将模型文件保存到应用程序的缓存目录中,而不是将其保存在项目资源中。这是因为在运行时,Android应用程序不能直接读取项目资源,而是需要使用AssetManager类从“assets”文件夹中读取文件。


Step4:模型的调用及使用示例

接下来示例运行模型、获取模型输出和在主线程中更新UI的代码:

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.os.Handler;
import android.os.Looper;
import android.util.Log;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.lifecycle.ProcessCameraProvider;
import androidx.camera.view.PreviewView;
import androidx.core.content.ContextCompat;
import androidx.lifecycle.LifecycleOwner;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class MainActivity extends AppCompatActivity {

    private static final String MODEL_PATH = "m_model.pt";
    private static final int INPUT_SIZE = 224;

    private Module mModule;
    private ExecutorService mExecutorService;
    private Handler mHandler;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // 加载PyTorch模型和创建执行线程池
        loadModel();

        // 创建主线程处理程序
        mHandler = new Handler(Looper.getMainLooper());

        // 启动相机
        startCamera();
    }

    private void loadModel() {
        // 加载PyTorch模型
        try {
            AssetManager assetManager = getAssets();
            InputStream inputStream = assetManager.open(MODEL_PATH);
            mModule = Module.load(inputStream);
        } catch (IOException e) {
            Log.e("MainActivity", "Error reading model file: " + e.getMessage());
            finish();
        }

        // 创建执行线程池
        mExecutorService = Executors.newSingleThreadExecutor();
    }

    private void startCamera() {
        // 创建PreviewView
        PreviewView previewView = findViewById(R.id.preview_view);

        // 配置相机生命周期所有者
        LifecycleOwner lifecycleOwner = this;

        // 配置相机预览
        Preview preview = new Preview.Builder().build();
        preview.setSurfaceProvider(previewView.getSurfaceProvider());

        // 配置图像分析
        ImageAnalysis imageAnalysis =
                new ImageAnalysis.Builder()
                        .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
                        .build();

        // 设置图像分析的处理程序
        imageAnalysis.setAnalyzer(
                mExecutorService,
                new ImageAnalysis.Analyzer() {
                    @Override
                    public void analyze(ImageProxy image, int rotationDegrees) {
                        // 将ImageProxy转换为Bitmap
                        Bitmap bitmap =
                                Bitmap.createScaledBitmap(
                                        image.getImage(),
                                        INPUT_SIZE,
                                        INPUT_SIZE,
                                        false);

                        // 将Bitmap转换为Tensor
                        Tensor tensor =
                                TensorImageUtils.bitmapToFloat32Tensor(
                                        bitmap,
                                        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                                        TensorImageUtils.TORCHVISION_NORM_STD_RGB);

                        // 创建输入列表
                        final IValue[] inputs = {IValue.from(tensor)};

                        // 运行模型
                        Tensor outputTensor = mModule.forward(inputs).toTensor();

                        // 获取模型输出
                        float[] scores = outputTensor.getDataAsFloatArray();

                        // 查找最高分数
                        float maxScore = -Float.MAX_VALUE;
                        int maxScoreIndex = -1;
                        for (int i = 0; i < scores.length; i++) {
                            if (scores[i] > maxScore) {
                                maxScore = scores[i];
                                maxScoreIndex = i;
                            }
                        }

                        // 获取分类标签
                        String[] labels = getLabels();
                        String predictedLabel = labels[maxScoreIndex];

                        // 更新UI
                        updateUI(predictedLabel);
                    }
                });

        // 绑定相机生命周期所有者
        CameraX.bindToLifecycle(lifecycleOwner, preview, imageAnalysis);
    }

    private String[] getLabels() {
        // 在此处替换为标签文件
        return new String[]{
                "tench",
                "goldfish",
                "great white shark",
                "tiger shark",
                // ...
        };
    }

    private void updateUI(String predictedLabel) {
        mHandler.post(
                new Runnable() {
                    @Override
                    public void run() {
                        // 更新UI
                        // 例如,将预测标签写入TextView
                        // TextView textView = findViewById(R.id.text_view);
                        // textView.setText(predictedLabel);
                    }
                });
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();

        // 释放模型和执行线程池
        mModule.destroy();
        mExecutorService.shutdown();
    }
}


当模型预测输入图像时,它将返回一个整数,该整数表示模型预测的图像类型的索引。可以使用该索引来查找对应的标签并更新UI。例如:

// 查找最高分数
float maxScore = -Float.MAX_VALUE;
int maxScoreIndex = -1;
for (int i = 0; i < scores.length; i++) {
    if (scores[i] > maxScore) {
        maxScore = scores[i];
        maxScoreIndex = i;
    }
}

// 获取分类标签
String[] labels = getLabels();
String predictedLabel = labels[maxScoreIndex];

// 更新UI
updateUI(predictedLabel);

Step5:调试程序

编译和运行应用程序,并在Android Studio调试上测试图像识别功能。文章来源地址https://www.toymoban.com/news/detail-495174.html

到了这里,关于深度学习模型的Android部署方法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 使用OpenCV与深度学习去除图像背景:Python实现指南

    第一部分:简介和OpenCV的背景去除 在现代的图像处理和计算机视觉应用中,背景去除是一个常见的需求。这不仅用于产品摄影和电商平台,还广泛应用于各种图像分析任务。在这篇文章中,我们将使用OpenCV和深度学习技术来实现此功能,并通过Python进行实现。本教程会介绍两

    2024年01月20日
    浏览(46)
  • 深度学习模型部署-番外-TVM机器学习编译

    图片来自知乎大佬的文章 机器学习编译是指:将模型从训练形式转变为部署模式 训练模式:使用训练框架定义的模型 部署模式:部署所需要的模式,包括模型每个步骤的实现代码,管理资源的控制器,与应用程序开发环境的接口。 这个行为和传统的编译很像,所以称为机器

    2024年03月18日
    浏览(40)
  • 使用OpenCV工具包成功实现人脸检测与人脸识别,包括传统视觉和深度学习方法(附完整代码,模型下载......)

    要实现人脸识别功能,首先要进行人脸检测,判断出图片中人脸的位置,才能进行下一步的操作。 参考链接: 1、OpenCV人脸检测 2、【OpenCV-Python】32.OpenCV的人脸检测和识别——人脸检测 3、【youcans 的图像处理学习课】23. 人脸检测:Haar 级联检测器 4、OpenCV实战5:LBP级联分类器

    2024年02月08日
    浏览(45)
  • Gradio快速搭建机器学习模型的wedui展示用户界面/深度学习网页模型部署

    官网 Gradio 是一个开源 Python 包,可让您快速为机器学习模型、API 或任何任意 Python 函数构建演示或 Web 应用程序。然后,您可以使用 Gradio 的内置共享功能在几秒钟内共享演示或 Web 应用程序的链接。无需 JavaScript、CSS 或网络托管经验! 只需几行 Python 代码就可以创建一个像上

    2024年04月23日
    浏览(39)
  • 深度学习模型部署——Flask框架轻量级部署+阿里云服务器

    ​因为参加一个比赛,需要把训练好的深度学习模型部署到web端,第一次做,在网上也搜索了很多教程,基本上没有适合自己的,只有一个b站up主讲的还不错 https://www.bilibili.com/video/BV1Qv41117SR/?spm_id_from=333.999.0.0vd_source=6ca6a313467efae52a28428a64104c10 https://www.bilibili.com/video/BV1Qv41117

    2024年02月07日
    浏览(83)
  • 深度学习模型压缩方法综述

    深度学习因其计算复杂度或参数冗余,在一些场景和设备上限制了相应的模型部署,需要借助 模型压缩 、系统优化加速等方法突破瓶颈,本文主要介绍模型压缩的各种方法,希望对大家有帮助。 我们知道,一定程度上, 网络越深,参数越多,模型也会越复杂,但其最终效果

    2024年02月10日
    浏览(39)
  • 深度学习模型部署综述(ONNX/NCNN/OpenVINO/TensorRT)

    点击下方 卡片 ,关注“ 自动驾驶之心 ”公众号 ADAS巨卷干货,即可获取 今天自动驾驶之心很荣幸邀请到 逻辑牛 分享深度学习部署的入门介绍,带大家盘一盘ONNX、NCNN、OpenVINO等框架的使用场景、框架特点及代码示例。如果您有相关工作需要分享,请在文末联系我们! 点击

    2024年02月08日
    浏览(49)
  • 模型驱动的深度学习方法

           本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在 学习摘录和笔记专栏 :         学习摘录和

    2024年02月16日
    浏览(42)
  • 深度学习提高模型准确率方法

    我们已经收集好了一个数据集,建立了一个神经网络,并训练了模型,在测试和验证阶段最后得到的准确率不高不到90%。或者没有达到业务的期望(需要100%)。 下面列举一些提高模型性能指标的策略或技巧,来提高模型的准确率。 使用更多数据 最简单的方法就是增加数据集

    2024年02月03日
    浏览(55)
  • 深度学习模型部署(六)TensorRT工作流and入门demo

    官方给出的步骤: 总结下来可以分为两大部分: 模型生成:将onnx经过一系列优化,生成tensorrt的engine模型 选择batchsize,选择精度precision,模型转换 模型推理:使用python或者C++进行推理 生成trt模型: 然后就坐等输出模型,我们可以根据log信息看一下tensorRT都干了什么: 得到

    2024年03月13日
    浏览(52)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包