音频数据处理+模型训练保存+Android模型移植
一个epoch , 表示: 所有的数据送入网络中, 完成了一次前向计算 + 反向传播的过程
把数据准备好,开始跑实验
1.分割数据集 scirpt.walk_file(path,out_path)
BirdsSong-2s-20spec
2.生成csv(script.py)
3.将wav音频文件中的音频浮点序列特征提出出来保存成pkl格式(注意数据是2s的,采样率是16000,SIGNAL_LENGYH=2);(get_pkl.py)
注意frames_train.reshape((len(frames_train),32000))。
4.训练完成,生成pt文件(train2023.py)
测试结果:
5.本地测试加载模型,获取识别结果,注意模型保存时设置的维度,以及模型保存是cpu还是gpu
import librosa
import numpy as np
import torch.utils.data.distributed
import torch.nn.functional as F
'cpu加载pytorch模型'
net = torch.load('best_network.pth')
net = net.to('cpu')
net.eval()
waveform, sample_rate = librosa.load('./data/BirdsSong-2s-20spec/train/0009/111651_1.wav',sr=16000,mono=True, offset=0, duration=2)
# 升维
waveform = np.expand_dims(waveform,axis=0)
waveform = np.expand_dims(waveform,axis=0)
# 模型输入要求是张量
x = torch.Tensor(waveform)
output = net(x)
# 归一化处理,数据类型转换,最后保存成list格式[索引,概率]
prob = F.softmax(output,dim=1)
x = torch.max(prob)
max_value = x.item()
max_index =torch.argmax(prob,dim=1)
max_index =max_index.numpy().item()
print("max_value",max_value)
print("max_index",max_index)
result_list =[max_index,max_value]
print(result_list)
6.在Android端部署pt模型
-
添加PyTorch Mobile依赖:
-
在Android项目的
build.gradle
文件中,添加PyTorch Mobile的依赖项。dependencies { // torch依赖 implementation 'org.pytorch:pytorch_android:1.6.0' implementation 'org.pytorch:pytorch_android_torchvision:1.6.0' }
-
在
gradle.repositories
部分添加以下代码,以确保能够访问PyTorch Mobile的存储库。android.useAndroidX=true android.enableJetifier=true
-
-
将TorchScript模型集成到Android应用:
-
将转换后的
your_scripted_model.pt
文件复制到Android项目的assets目录下。 -
在Android应用的代码中加载和执行模型:
-
首先,通过编写
PyTorchLoader
类加载TorchScript模型。文章来源:https://www.toymoban.com/news/detail-811006.htmlimport org.pytorch.Module; import org.pytorch.IValue; import org.pytorch.Tensor; import org.pytorch.torchvision.TorchVision; public class PyTorchLoader { private Module mModule; public PyTorchLoader(AssetManager assetManager, String modelPath) throws IOException { mModule = Module.load(String.valueOf(assetManager.open(modelPath))); } public PyTorchLoader(Module module) throws IOException { mModule = module; } public float[] predict(float[] input) { Tensor inputTensor = Tensor.fromBlob(input, new long[]{1, 1, 32000}); IValue inputs = IValue.from(inputTensor); // 执行推理,获取预测结果张量 Tensor outputTensor = mModule.forward(inputs).toTensor(); // 处理预测结果 float[] outputData = outputTensor.getDataAsFloatArray(); return outputData; } }
-
在您的Android活动中,实例化
PyTorchLoader
类并使用加载的模型进行预测。文章来源地址https://www.toymoban.com/news/detail-811006.htmlimport android.content.res.AssetManager; import android.os.Bundle; import androidx.appcompat.app.AppCompatActivity; import java.io.IOException; public class MainActivity extends AppCompatActivity { private PyTorchLoader mPyTorchLoader; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); try { module = Module.load(assetFilePath(this, "tcnn20_for_android.pt")); mPyTorchLoader = new PyTorchLoader(module); } catch (IOException e) { e.printStackTrace(); } // 运行预测 callPythonCodeFloat(); // 根据模型要求进行调整,获得一维输入数据(1, 1, 32000) Log.d(TAG,"input = "+ Arrays.toString(input)); float[] output = mPyTorchLoader.predict(input); Log.d(TAG,"output[] = "+Arrays.toString(output)); // 处理输出结果 ... } }
-
-
到了这里,关于音频数据处理+模型训练保存+Android模型移植的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!