音频数据处理+模型训练保存+Android模型移植

这篇具有很好参考价值的文章主要介绍了音频数据处理+模型训练保存+Android模型移植。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

音频数据处理+模型训练保存+Android模型移植

一个epoch , 表示: 所有的数据送入网络中, 完成了一次前向计算 + 反向传播的过程

把数据准备好,开始跑实验

1.分割数据集 scirpt.walk_file(path,out_path)

BirdsSong-2s-20spec
音频数据处理+模型训练保存+Android模型移植,音视频,android

2.生成csv(script.py)
音频数据处理+模型训练保存+Android模型移植,音视频,android

3.将wav音频文件中的音频浮点序列特征提出出来保存成pkl格式(注意数据是2s的,采样率是16000,SIGNAL_LENGYH=2);(get_pkl.py)

注意frames_train.reshape((len(frames_train),32000))。
音频数据处理+模型训练保存+Android模型移植,音视频,android

4.训练完成,生成pt文件(train2023.py)
音频数据处理+模型训练保存+Android模型移植,音视频,android

测试结果:
音频数据处理+模型训练保存+Android模型移植,音视频,android

5.本地测试加载模型,获取识别结果,注意模型保存时设置的维度,以及模型保存是cpu还是gpu

import librosa
import numpy as np
import torch.utils.data.distributed
import torch.nn.functional as F


'cpu加载pytorch模型'
net = torch.load('best_network.pth')
net = net.to('cpu')
net.eval()
waveform, sample_rate = librosa.load('./data/BirdsSong-2s-20spec/train/0009/111651_1.wav',sr=16000,mono=True, offset=0, duration=2)

# 升维
waveform = np.expand_dims(waveform,axis=0)
waveform = np.expand_dims(waveform,axis=0)

# 模型输入要求是张量
x = torch.Tensor(waveform)
output = net(x)

# 归一化处理,数据类型转换,最后保存成list格式[索引,概率]
prob = F.softmax(output,dim=1)
x = torch.max(prob)
max_value = x.item()
max_index =torch.argmax(prob,dim=1)
max_index =max_index.numpy().item()
print("max_value",max_value)
print("max_index",max_index)

result_list =[max_index,max_value]

print(result_list)

6.在Android端部署pt模型

  1. 添加PyTorch Mobile依赖:

    • 在Android项目的build.gradle文件中,添加PyTorch Mobile的依赖项。

      dependencies {
          // torch依赖
          implementation 'org.pytorch:pytorch_android:1.6.0'
          implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
      }
      
    • gradle.repositories部分添加以下代码,以确保能够访问PyTorch Mobile的存储库。

      android.useAndroidX=true
      android.enableJetifier=true
      
  2. 将TorchScript模型集成到Android应用:

    • 将转换后的your_scripted_model.pt文件复制到Android项目的assets目录下。

    • 在Android应用的代码中加载和执行模型:

      • 首先,通过编写PyTorchLoader类加载TorchScript模型。

        import org.pytorch.Module;
        import org.pytorch.IValue;
        import org.pytorch.Tensor;
        import org.pytorch.torchvision.TorchVision;
        
        public class PyTorchLoader {
            private Module mModule;
        
            public PyTorchLoader(AssetManager assetManager, String modelPath) throws IOException {
                mModule = Module.load(String.valueOf(assetManager.open(modelPath)));
            }
            public PyTorchLoader(Module module) throws IOException {
                mModule = module;
            }
        
            public float[] predict(float[] input) {
                Tensor inputTensor = Tensor.fromBlob(input, new long[]{1, 1, 32000});
                IValue inputs = IValue.from(inputTensor);
                // 执行推理,获取预测结果张量
                Tensor outputTensor = mModule.forward(inputs).toTensor();
                // 处理预测结果
                float[] outputData = outputTensor.getDataAsFloatArray();
                return outputData;
            }
        }
        
      • 在您的Android活动中,实例化PyTorchLoader类并使用加载的模型进行预测。文章来源地址https://www.toymoban.com/news/detail-811006.html

        import android.content.res.AssetManager;
        import android.os.Bundle;
        import androidx.appcompat.app.AppCompatActivity;
        
        import java.io.IOException;
        
        public class MainActivity extends AppCompatActivity {
        
            private PyTorchLoader mPyTorchLoader;
        
            @Override
            protected void onCreate(Bundle savedInstanceState) {
                super.onCreate(savedInstanceState);
                setContentView(R.layout.activity_main);
        
                try {
                    module = Module.load(assetFilePath(this, "tcnn20_for_android.pt"));
                    mPyTorchLoader = new PyTorchLoader(module);
                    
                } catch (IOException e) {
                    e.printStackTrace();
                }
        
                // 运行预测
                callPythonCodeFloat(); // 根据模型要求进行调整,获得一维输入数据(1, 1, 32000)
        
                Log.d(TAG,"input = "+ Arrays.toString(input));
                float[] output = mPyTorchLoader.predict(input);
                Log.d(TAG,"output[] = "+Arrays.toString(output));
                
                // 处理输出结果
                ...
            }
        }
        

到了这里,关于音频数据处理+模型训练保存+Android模型移植的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包