pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

这篇具有很好参考价值的文章主要介绍了pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

近期需要将pytorch模型运行到android手机上实验,在查阅网上博客后,发现大多数流程需要借助多个框架或软件,横跨多个编程语言、IDE。本文参考以下两篇博文,力求用更简洁的流程实现模型部署。

https://blog.csdn.net/xiaodidididi521/article/details/123985612
https://blog.csdn.net/m0_67391683/article/details/125401357

向两位作者表示感谢!本文进一步详细描述了实现流程。

一、pytorch模型转化

pytorch模型无法直接被Android调用,需要转化为特定格式.pt。本文使用pycharm IDE完成这一步,工程目录结构如下:
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)#pic_center
其中,vgg16bn_CIFAR10.pth和另一个pth文件是需要部署到手机上的模型,models.py是自己定义的网络结构。在此默认读者熟悉pytorch,对models.py不做赘述。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

执行以下代码实现转换:

import torch.utils.data.distributed

'定义转化后的模型名称'
model_ori_pt ='model_ori.pt'
model_pruned_pt ='model_pruned.pt'

'加载pytorch模型'
model_ori = torch.load('vgg16bn_CIFAR10.pth')
model_pruned = torch.load('vgg16bn_CIFAR10_pruned.pth')

'模型在cpu上运行'
device = torch.device('cpu')
model_ori.to(device)
model_pruned.to(device)
model_ori.eval()
model_pruned.eval()

'定义输入图片的大小'
input_tensor = torch.rand(1, 3, 32, 32)

'转化模型并存储'
mobile_ori = torch.jit.trace(model_ori, input_tensor)
model_pruned = torch.jit.trace(model_pruned, input_tensor)
mobile_ori.save(model_ori_pt)
model_pruned.save(model_pruned_pt)

请注意,让模型在cpu上,或cuda上执行eval()均可,但要保证模型与input_tensor在同一设备上,否则将运行出错。运行后,会得到model_ori.ptmodel_pruned.pt两个文件,即可以用于android上的文件。此时目录结构如下:
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

二、新建Android Studio工程

首先,需要在本地安装Android Studio,安装流程建议参照:

https://m.runoob.com/android/android-studio-install.html?ivk_sa=1024320u
然后打开Android Studio新建Empy Activity
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

点击Next。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

点击Finsh。SDK建议选择7.0以往的安卓版本。**首次新建工程底部会长时间出现加载进度条,请耐心等待加载完成。**接下来,我们需要有一部手机调试工程,本文使用Android Studio自带的模拟器。首先点击顶部工具栏的Device Manager。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
点击create device
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
接下来选择机型、安卓版本、内存等,如不想麻烦可一直点击next。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
finsh后,Android Studio需要下载安卓版本包,需要耐心等待。下载完成后即可启动虚拟机。

pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
再shift+F10即可在模拟机里运行程序。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

三、转化后的模型部署安卓

首先,新建assets文件夹,请不要直接新建,需右键app->Folder->Assets Folder。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
之后将转化好的两个模型及侧视图放入assets文件夹。本文使用的是CIFAR10数据集,可在以下网址下载:

http://www.cs.toronto.edu/~kriz/cifar.html
然后在gradle Scripts 文件夹中的build.gradle(Module :app)文件中的depencies里添加:

implementation 'org.pytorch:pytorch_android:1.12.1'
implementation 'org.pytorch:pytorch_android_torchvision:1.12.1'

请注意**1.12.1是本文使用的pytorch版本,读者应该为对应的版本号。**然后点击工具栏下的sync now,再耐心等待运行按钮变绿。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
双击res->layout->activity_main.xml并切换到code。
pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)
删除所有代码,复制以下代码段:

<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <ImageView
        android:id="@+id/image"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:scaleType="fitCenter" />

    <TextView
        android:id="@+id/text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_gravity="top"
        android:textSize="24sp"
        android:textColor="@android:color/holo_red_light" />

</FrameLayout>

然后右键java里的com.example.工程名 文件夹,New->Java Class。本文新建的类名是CIfarClassed,类内代码:

package com.example.工程名;

public class CifarClassed {
    public static String[] IMAGENET_CLASSES = new String[]{
            "ddd",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
    };
}

最后打开java->com.example.工程名->MainActivity,删除原代码,用以下代码替代:

package com.example.dnna;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;

import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import androidx.appcompat.app.AppCompatActivity;

import com.example.dnna.CifarClassed;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        Bitmap bitmap = null;
        Module module_ori = null;
        Module module_pruned = null;
        try {
            // creating bitmap from packaged into app android asset 'image.jpg',
            // app/src/main/assets/image.jpg
            bitmap = BitmapFactory.decodeStream(getAssets().open("x.png"));
            // loading serialized torchscript module from packaged into app android asset model.pt,
            // app/src/model/assets/model.pt
            module_ori = Module.load(assetFilePath(this, "model_ori.pt"));
            module_pruned = Module.load(assetFilePath(this, "model——pruned.pt"));
        } catch (IOException e) {
            Log.e("PytorchHelloWorld", "Error reading assets", e);
            finish();
        }

        // showing image on UI
        ImageView imageView = findViewById(R.id.image);
        imageView.setImageBitmap(bitmap);

        // preparing input tensor
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

        // running the model
        long startTime_ori = System.currentTimeMillis();
        final Tensor outputTensor_ori = module_ori.forward(IValue.from(inputTensor)).toTensor();
        long endTime_ori = System.currentTimeMillis();
        long InferenceTimeOri=endTime_ori - startTime_ori;

        long startTime_pruned = System.currentTimeMillis();
        final Tensor outputTensor_pruned = module_pruned.forward(IValue.from(inputTensor)).toTensor();
        long endTime_pruned = System.currentTimeMillis();
        long InferenceTimePruned=endTime_pruned - startTime_pruned;

        // getting tensor content as java array of floats
        final float[] scores = outputTensor_ori.getDataAsFloatArray();

        // searching for the index with maximum score
        float maxScore = -Float.MAX_VALUE;
        int maxScoreIdx = -1;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }
        System.out.println(maxScoreIdx);
        String className = CifarClassed.IMAGENET_CLASSES[maxScoreIdx];

        // showing className on UI
        TextView textView = findViewById(R.id.text);
        String tex="推理结果:"+className+"\n原始模型推理时间:"+InferenceTimeOri+"ms"+"\n剪枝模型推理时间:"+InferenceTimePruned+"ms";
        textView.setText(tex);
    }

    /**
     * Copies specified asset to the file in /files app directory and returns this file absolute path.
     *
     * @return absolute file path
     */
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
}

运行效果如下图:

pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)

四、结语

本文的主要流程是:

  • 使用pytorch转化模型
  • 新建Android Studio工程与虚拟机
  • 修改Android Studio工程代码

本人目前希望提升自己的博客撰写水平,如读者在实现过程中遇到困难,或在阅读本文时感到困惑,欢迎留言或添加我的QQ:1106295085。我将在周日下午回复,并积极修改本文。文章来源地址https://www.toymoban.com/news/detail-400097.html

到了这里,关于pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • python在手机上怎么运行,手机怎么用python程序

    大家好,本文将围绕python在手机上怎么运行展开说明,手机怎么用python程序是一个很多人都想弄明白的事情,想搞清楚手机版的python怎么用需要先了解以下几个事情。 大家好,小编来为大家解答以下问题,手机上的python怎么运行程序,手机的python怎么运行文件,今天让我们一

    2024年02月03日
    浏览(50)
  • 手机上的python怎么运行,python在手机上怎么操作

    大家好,小编来为大家解答以下问题,python在手机上怎么操作,手机上的python怎么运行,现在让我们一起来看看吧! 手机浏览器运行python是因为手机浏览器和python两者之间是可以互相兼容的,手机浏览器可以对python的内容数据进行解压和储存显示,所以才会出现手机浏览器运

    2024年02月08日
    浏览(40)
  • python在手机上怎么运行,手机版的python怎么用

    这篇文章主要介绍了python在手机上怎么运行,具有一定借鉴价值,需要的朋友可以参考下。希望大家阅读完这篇文章后大有收获,下面让小编带着大家一起了解一下。 如何用手机编程Python? 1.QPython3:这是一个在安卓手机上运行python3的脚本引擎,整合了python3解释器、控制台

    2024年01月16日
    浏览(54)
  • 【Android】在AndroidStudio开发工具运行Java程序

    在Android Studio开发工具中,Android系统开始就是用java语言开发的,还可以java代码来写程序,控制台,桌面应用,还可以写可调用的模块,这里讲一下创建Java程序步骤,方便入门java语言开发。 新建一个Android项目时,要选择第一个,就是空的(不带模板)的项目,这里打开会有

    2024年02月11日
    浏览(55)
  • Android studio 软件介绍及运行到手机上

    下载好软件后,双击打开 注意第一次打开软件是如下图显示,如果是已经打开过项目,会默认打开最后关闭的项目 如果是已经是打开状态,点击左上角File - New - New Project 新建项目 创建项目 1, 选择模板 2. 基础设置,点击完成 进入页面,如下图 1,先说一下布局和类位置 类

    2024年02月10日
    浏览(79)
  • Android 手机运行 JoyCon Droid 并且使用 Amiibo

    PS: 整个过程耗时耗力,经常会断开连接,有些不想搞那么麻烦的人就不要搞了,以免遭受刺激啊,哈哈。 如果想使用并刷Amiibo,必须同时满足以下几个条件: 1. 蓝牙版本需要 5.0 。 2. 可以解开 BootLoader 锁。 3. 安卓版本 = 9.0。 经过了多次刷机,最后终于取得了成功,中间的

    2024年02月04日
    浏览(27)
  • 关于为在手机上开发/运行Python程序的研究报告以及为手机打包Python应用的研究。

    前一段时间莫名地想用Python开发手机应用。经过日日夜夜在互联网上的挖掘于是有了这样一篇导航性的文章兼入坑/踩坑记录。必须承认Python在手机领域的进展还停留在研发阶段,作者也是真心希望更多的大佬参与到这个领域的先驱部队中,开发出一款完备的引擎之类的。 如

    2024年02月14日
    浏览(66)
  • 使用android studio编译app到自己的手机上运行,却读取不了手机里面的图片

    问题描述: 使用android studio编译app到自己的手机上运行,却读取不了手机里面的图片 问题分析: 这个是由于这个app没有申请手机端的 媒体文件访问权限,所以读取不了 解决:(我的是Android 10,新版本可能会有不同) 查看AndroidManifest.xml这个文件,发现原来只有permission.CAM

    2024年01月17日
    浏览(52)
  • vue项目运行后使用ip地址在手机上打开

    window+r,输入cmd按回车后在输入ipconfig ipv4地址就是你了 (1)vue.config.js文件中修改 host localhost 为 host 0.0.0.0 (2)你的vue.config.js文件可以由于配置较多,配置在了config文件夹的index.js文件中 (3)修改package.json文件中添加–host 0.0.0.0,然后重新重启项目。 “scripts”: { “dev”

    2024年02月12日
    浏览(44)
  • Android开发:使用AndroidStudio开发记单词APP(带数据库)

    实现功能 :设计与开发记单词系统的四个界面,分别是用户登录、用户注册、单词操作以及忘记密码。 指标要求 :通过用户登录、用户注册、单词操作、忘记密码掌握界面设计的基础,其中包括界面布局、常用控件、事件处理等相关内容,通过所学内容设计与开发的界面要

    2024年02月12日
    浏览(41)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包