基于X-CUBE-AI的模型推理
本文所使用的版本如下:
- X-CUBE-AI:8.1.0
- STM32CubeMX:6.7.0
基于 STM32CubeMX 导出模型
首先需要在软件包(Software Packs)中选中 X-CUBE-AI:
导入模型并进行转换,运行时(Runtime)这里选择 STM32Cube.AI Runtime。
转换完成后,界面底部会显示模型的 RAM 与 ROM 占用情况:
基于STM32实现模型推理
ST 官方提供了相关的说明文档,可以在 pack 包的安装目录下找到。我的安装路径如下(每个人电脑上的路径可能不同):
file:///D:/IDE/STM32CUBEMX/Repository/Packs/STMicroelectronics/X-CUBE-AI/8.1.0/Documentation/how_to_run_a_model_locally.html
接下来,我们按照文档编写示例代码。本文所使用的模型输入为长度 2048 的一维浮点数据。
1.引入必要的头文件
#include "stdio.h"
#include <stdlib.h>
#include <time.h>
#include <string.h>
#include "network.h"
#include "network_data.h"
2.创建模型的输入输出以及句柄
/* Working memory for the network; passed to the runtime via acts[] in ai_init(). */
AI_ALIGNED(32)
static ai_u8 activations[AI_NETWORK_DATA_ACTIVATIONS_SIZE];
AI_ALIGNED(32)
static ai_float in_data[AI_NETWORK_IN_1_SIZE]; // Change ai_float to your model's input type if needed; size it with the *_SIZE (element count) macro, not the byte-size one
AI_ALIGNED(32)
static ai_float out_data[AI_NETWORK_OUT_1_SIZE]; // Likewise: size it with the *_SIZE (element count) macro
/* I/O buffer descriptors obtained from the runtime in ai_run(). */
ai_buffer *ai_input;
ai_buffer *ai_output;
/* Network instance handle; starts as AI_HANDLE_NULL and is created by ai_init(). */
ai_handle network = AI_HANDLE_NULL;
/* Error status returned by ai_network_create_and_init() in ai_init(). */
ai_error err;
/* Model information (name, signature, ...) filled in by ai_init(). */
ai_network_report report;
3.创建模型初始化代码
int ai_init()
{
const ai_handle acts[] = {activations};
err = ai_network_create_and_init(&network, acts, NULL);
if (err.type != AI_ERROR_NONE)
{
printf("ai init_and_create error\n");
return -1;
}
else
{
printf("ai init success\n");
}
if (ai_network_get_report(network, &report) != true)
{
printf("ai get report error\n");
return -1;
}
printf("Model name : %s\n", report.model_name);
printf("Model signature : %s\n", report.model_signature);
return 0;
}
4.赋值与推理
/**
 * @brief Copy one input sample into the model input buffer and run inference.
 *
 * @param in_data  Model input buffer; must hold AI_NETWORK_IN_1_SIZE elements.
 * @param out_data Model output buffer; must hold AI_NETWORK_OUT_1_SIZE elements.
 * @param data     Source sample (read only); the first @p length elements are used.
 * @param length   Number of elements in @p data, in 0..AI_NETWORK_IN_1_SIZE.
 *                 A shorter sample is zero-padded to the full input size.
 * @return 0 on success, -1 on failure (details are printed; on a run failure
 *         the runtime error is also stored in the global `err`).
 */
int ai_run(ai_float *in_data, ai_float *out_data, float *data, int length)
{
  ai_i32 n_batch;

  /* Guard against being called before ai_init() succeeded. */
  if (network == AI_HANDLE_NULL)
  {
    printf("run failed: network not initialised\r\n");
    return -1;
  }
  /* The model consumes AI_NETWORK_IN_1_SIZE elements from in_data, so a longer
   * sample would overflow the buffer; reject bad arguments up front. */
  if (in_data == NULL || out_data == NULL || length < 0 ||
      length > AI_NETWORK_IN_1_SIZE || (length > 0 && data == NULL))
  {
    printf("run failed: invalid arguments\r\n");
    return -1;
  }

  /* Copy the sample; zero-pad a short one so stale values are never fed in. */
  for (int i = 0; i < AI_NETWORK_IN_1_SIZE; i++)
  {
    in_data[i] = (i < length) ? (ai_float)data[i] : (ai_float)0;
  }

  ai_input = ai_network_inputs_get(network, NULL);
  ai_output = ai_network_outputs_get(network, NULL);
  if (ai_input == NULL || ai_output == NULL)
  {
    printf("run failed: cannot get network I/O buffers\r\n");
    return -1;
  }
  /* Point the runtime's I/O descriptors at our buffers. */
  ai_input[0].data = AI_HANDLE_PTR(in_data);
  ai_output[0].data = AI_HANDLE_PTR(out_data);

  /* A single sample is passed, so exactly one batch is expected back. */
  n_batch = ai_network_run(network, &ai_input[0], &ai_output[0]);
  if (n_batch != 1)
  {
    /* Keep and report the runtime error instead of discarding it. */
    err = ai_network_get_error(network);
    printf("run failed\r\n");
    printf("  error type=0x%02x code=0x%02x\r\n", (unsigned)err.type, (unsigned)err.code);
    return -1;
  }
  return 0; /* success */
}
接下来,我们就可以通过 out_data 查看推理结果:
(文章来源:https://www.toymoban.com/news/detail-861473.html)
for (int i = 0; i < AI_NETWORK_OUT_1_SIZE; i++)
{
printf("%.2f, ", out_data[i]);
}
输出结果与上位机(PC 端)的推理结果保持一致。
文章来源地址https://www.toymoban.com/news/detail-861473.html
全部代码
/* USER CODE BEGIN Header */
/**
******************************************************************************
* @file : main.c
* @brief : Main program body
******************************************************************************
* @attention
*
* Copyright (c) 2024 STMicroelectronics.
* All rights reserved.
*
* This software is licensed under terms that can be found in the LICENSE file
* in the root directory of this software component.
* If no LICENSE file comes with this software, it is provided AS-IS.
*
******************************************************************************
*/
/* USER CODE END Header */
/* Includes ------------------------------------------------------------------*/
#include "main.h"
/* Private includes ----------------------------------------------------------*/
/* USER CODE BEGIN Includes */
#include "stdio.h"
#include <stdlib.h>
#include <time.h>
#include <string.h>
#include "network.h"
#include "network_data.h"
/* USER CODE END Includes */
/* Private typedef -----------------------------------------------------------*/
/* USER CODE BEGIN PTD */
/* USER CODE END PTD */
/* Private define ------------------------------------------------------------*/
/* USER CODE BEGIN PD */
/* USER CODE END PD */
/* Private macro -------------------------------------------------------------*/
/* USER CODE BEGIN PM */
/* USER CODE END PM */
/* Private variables ---------------------------------------------------------*/
CRC_HandleTypeDef hcrc;     /* HAL handle for the CRC peripheral, set up in MX_CRC_Init() */
I2C_HandleTypeDef hi2c1;    /* HAL handle for I2C1, set up in MX_I2C1_Init() */
UART_HandleTypeDef huart1;  /* HAL handle for USART1; printf output is sent here by fputc() */
/* USER CODE BEGIN PV */
/* USER CODE END PV */
/* Private function prototypes -----------------------------------------------*/
/* CubeMX-generated initialisation routines, defined later in this file. */
void SystemClock_Config(void);
static void MX_GPIO_Init(void);
static void MX_CRC_Init(void);
static void MX_I2C1_Init(void);
static void MX_USART1_UART_Init(void);
/* USER CODE BEGIN PFP */
/* USER CODE END PFP */
/* Private user code ---------------------------------------------------------*/
/* USER CODE BEGIN 0 */
int fputc(int ch, FILE *f)
{
HAL_UART_Transmit(&huart1, (uint8_t *)&ch, 1, 0xFFFF);
return ch;
}
/* Working memory for the network; passed to the runtime via acts[] in ai_init(). */
AI_ALIGNED(32)
static ai_u8 activations[AI_NETWORK_DATA_ACTIVATIONS_SIZE];
/* Model input buffer: AI_NETWORK_IN_1_SIZE elements (element count, not bytes). */
AI_ALIGNED(32)
static ai_float in_data[AI_NETWORK_IN_1_SIZE];
/* Model output buffer: AI_NETWORK_OUT_1_SIZE elements, printed by main(). */
AI_ALIGNED(32)
static ai_float out_data[AI_NETWORK_OUT_1_SIZE];
/* I/O buffer descriptors obtained from the runtime in ai_run(). */
ai_buffer *ai_input;
ai_buffer *ai_output;
/* Network instance handle; starts as AI_HANDLE_NULL and is created by ai_init(). */
ai_handle network = AI_HANDLE_NULL;
/* Error status returned by ai_network_create_and_init() in ai_init(). */
ai_error err;
/* Model information (name, signature, ...) filled in by ai_init(). */
ai_network_report report;
//替换为自己的数据
float data[] ={};
/**
* @brief ai init
*
* @return int
*/
int ai_init()
{
const ai_handle acts[] = {activations};
err = ai_network_create_and_init(&network, acts, NULL);
if (err.type != AI_ERROR_NONE)
{
printf("ai init_and_create error\n");
return -1;
}
else
{
printf("ai init success\n");
}
if (ai_network_get_report(network, &report) != true)
{
printf("ai get report error\n");
return -1;
}
printf("Model name : %s\n", report.model_name);
printf("Model signature : %s\n", report.model_signature);
return 0;
}
/**
 * @brief Copy one input sample into the model input buffer and run inference.
 *
 * @param in_data  Model input buffer; must hold AI_NETWORK_IN_1_SIZE elements.
 * @param out_data Model output buffer; must hold AI_NETWORK_OUT_1_SIZE elements.
 * @param data     Source sample (read only); the first @p length elements are used.
 * @param length   Number of elements in @p data, in 0..AI_NETWORK_IN_1_SIZE.
 *                 A shorter sample is zero-padded to the full input size.
 * @return 0 on success, -1 on failure (details are printed; on a run failure
 *         the runtime error is also stored in the global `err`).
 */
int ai_run(ai_float *in_data, ai_float *out_data, float *data, int length)
{
  ai_i32 n_batch;

  /* Guard against being called before ai_init() succeeded. */
  if (network == AI_HANDLE_NULL)
  {
    printf("run failed: network not initialised\r\n");
    return -1;
  }
  /* The model consumes AI_NETWORK_IN_1_SIZE elements from in_data, so a longer
   * sample would overflow the buffer; reject bad arguments up front. */
  if (in_data == NULL || out_data == NULL || length < 0 ||
      length > AI_NETWORK_IN_1_SIZE || (length > 0 && data == NULL))
  {
    printf("run failed: invalid arguments\r\n");
    return -1;
  }

  /* Copy the sample; zero-pad a short one so stale values are never fed in. */
  for (int i = 0; i < AI_NETWORK_IN_1_SIZE; i++)
  {
    in_data[i] = (i < length) ? (ai_float)data[i] : (ai_float)0;
  }

  ai_input = ai_network_inputs_get(network, NULL);
  ai_output = ai_network_outputs_get(network, NULL);
  if (ai_input == NULL || ai_output == NULL)
  {
    printf("run failed: cannot get network I/O buffers\r\n");
    return -1;
  }
  /* Point the runtime's I/O descriptors at our buffers. */
  ai_input[0].data = AI_HANDLE_PTR(in_data);
  ai_output[0].data = AI_HANDLE_PTR(out_data);

  /* A single sample is passed, so exactly one batch is expected back. */
  n_batch = ai_network_run(network, &ai_input[0], &ai_output[0]);
  if (n_batch != 1)
  {
    /* Keep and report the runtime error instead of discarding it. */
    err = ai_network_get_error(network);
    printf("run failed\r\n");
    printf("  error type=0x%02x code=0x%02x\r\n", (unsigned)err.type, (unsigned)err.code);
    return -1;
  }
  return 0; /* success */
}
/* USER CODE END 0 */
/**
 * @brief Application entry point.
 *
 * Initialises the HAL, system clock and peripherals, creates the X-CUBE-AI
 * network, runs one inference on `data` and prints every output value over
 * USART1. It then toggles the heartbeat LED once per second forever.
 * Any AI failure is routed to Error_Handler() instead of returning from main().
 * On bare metal, what happens after main() returns depends on the toolchain's
 * startup code.
 *
 * @retval Never returns.
 */
int main(void)
{
  /* USER CODE BEGIN 1 */
  /* USER CODE END 1 */
  /* MCU Configuration--------------------------------------------------------*/
  /* Reset of all peripherals, Initializes the Flash interface and the Systick. */
  HAL_Init();
  /* USER CODE BEGIN Init */
  /* USER CODE END Init */
  /* Configure the system clock */
  SystemClock_Config();
  /* USER CODE BEGIN SysInit */
  /* USER CODE END SysInit */
  /* Initialize all configured peripherals */
  MX_GPIO_Init();
  MX_CRC_Init();
  MX_I2C1_Init();
  MX_USART1_UART_Init();
  /* USER CODE BEGIN 2 */
  if (ai_init() != 0)
  {
    Error_Handler();
  }
  if (ai_run(in_data, out_data, data, AI_NETWORK_IN_1_SIZE) != 0)
  {
    Error_Handler();
  }
  /* NOTE(review): printing floats with %f may need extra linker options on some
   * toolchains (e.g. newlib-nano's -u _printf_float) -- confirm for your build. */
  for (int i = 0; i < AI_NETWORK_OUT_1_SIZE; i++)
  {
    printf("%.2f, ", out_data[i]);
  }
  printf("\r\n"); /* terminate the result line */
  /* USER CODE END 2 */
  /* Infinite loop */
  /* USER CODE BEGIN WHILE */
  while (1)
  {
    /* USER CODE END WHILE */
    /* USER CODE BEGIN 3 */
    HAL_GPIO_TogglePin(LedHeart_GPIO_Port, LedHeart_Pin);
    HAL_Delay(1000);
  }
  /* USER CODE END 3 */
}
/**
* @brief System Clock Configuration
* @retval None
*/
void SystemClock_Config(void)
{
RCC_OscInitTypeDef RCC_OscInitStruct = {0};
RCC_ClkInitTypeDef RCC_ClkInitStruct = {0};
/** Configure the main internal regulator output voltage
*/
__HAL_RCC_PWR_CLK_ENABLE();
__HAL_PWR_VOLTAGESCALING_CONFIG(PWR_REGULATOR_VOLTAGE_SCALE1);
/** Initializes the RCC Oscillators according to the specified parameters
* in the RCC_OscInitTypeDef structure.
*/
RCC_OscInitStruct.OscillatorType = RCC_OSCILLATORTYPE_HSE;
RCC_OscInitStruct.HSEState = RCC_HSE_ON;
RCC_OscInitStruct.PLL.PLLState = RCC_PLL_ON;
RCC_OscInitStruct.PLL.PLLSource = RCC_PLLSOURCE_HSE;
RCC_OscInitStruct.PLL.PLLM = 4;
RCC_OscInitStruct.PLL.PLLN = 168;
RCC_OscInitStruct.PLL.PLLP = RCC_PLLP_DIV2;
RCC_OscInitStruct.PLL.PLLQ = 4;
if (HAL_RCC_OscConfig(&RCC_OscInitStruct) != HAL_OK)
{
Error_Handler();
}
/** Initializes the CPU, AHB and APB buses clocks
*/
RCC_ClkInitStruct.ClockType = RCC_CLOCKTYPE_HCLK | RCC_CLOCKTYPE_SYSCLK | RCC_CLOCKTYPE_PCLK1 | RCC_CLOCKTYPE_PCLK2;
RCC_ClkInitStruct.SYSCLKSource = RCC_SYSCLKSOURCE_PLLCLK;
RCC_ClkInitStruct.AHBCLKDivider = RCC_SYSCLK_DIV1;
RCC_ClkInitStruct.APB1CLKDivider = RCC_HCLK_DIV4;
RCC_ClkInitStruct.APB2CLKDivider = RCC_HCLK_DIV2;
if (HAL_RCC_ClockConfig(&RCC_ClkInitStruct, FLASH_LATENCY_5) != HAL_OK)
{
Error_Handler();
}
/** Enables the Clock Security System
*/
HAL_RCC_EnableCSS();
}
/**
* @brief CRC Initialization Function
* @param None
* @retval None
*/
static void MX_CRC_Init(void)
{
/* USER CODE BEGIN CRC_Init 0 */
/* USER CODE END CRC_Init 0 */
/* USER CODE BEGIN CRC_Init 1 */
/* USER CODE END CRC_Init 1 */
hcrc.Instance = CRC;
if (HAL_CRC_Init(&hcrc) != HAL_OK)
{
Error_Handler();
}
/* USER CODE BEGIN CRC_Init 2 */
/* USER CODE END CRC_Init 2 */
}
/**
* @brief I2C1 Initialization Function
* @param None
* @retval None
*/
static void MX_I2C1_Init(void)
{
/* USER CODE BEGIN I2C1_Init 0 */
/* USER CODE END I2C1_Init 0 */
/* USER CODE BEGIN I2C1_Init 1 */
/* USER CODE END I2C1_Init 1 */
hi2c1.Instance = I2C1;
hi2c1.Init.ClockSpeed = 100000;
hi2c1.Init.DutyCycle = I2C_DUTYCYCLE_2;
hi2c1.Init.OwnAddress1 = 0;
hi2c1.Init.AddressingMode = I2C_ADDRESSINGMODE_7BIT;
hi2c1.Init.DualAddressMode = I2C_DUALADDRESS_DISABLE;
hi2c1.Init.OwnAddress2 = 0;
hi2c1.Init.GeneralCallMode = I2C_GENERALCALL_DISABLE;
hi2c1.Init.NoStretchMode = I2C_NOSTRETCH_DISABLE;
if (HAL_I2C_Init(&hi2c1) != HAL_OK)
{
Error_Handler();
}
/* USER CODE BEGIN I2C1_Init 2 */
/* USER CODE END I2C1_Init 2 */
}
/**
* @brief USART1 Initialization Function
* @param None
* @retval None
*/
static void MX_USART1_UART_Init(void)
{
/* USER CODE BEGIN USART1_Init 0 */
/* USER CODE END USART1_Init 0 */
/* USER CODE BEGIN USART1_Init 1 */
/* USER CODE END USART1_Init 1 */
huart1.Instance = USART1;
huart1.Init.BaudRate = 115200;
huart1.Init.WordLength = UART_WORDLENGTH_8B;
huart1.Init.StopBits = UART_STOPBITS_1;
huart1.Init.Parity = UART_PARITY_NONE;
huart1.Init.Mode = UART_MODE_TX_RX;
huart1.Init.HwFlowCtl = UART_HWCONTROL_NONE;
huart1.Init.OverSampling = UART_OVERSAMPLING_16;
if (HAL_UART_Init(&huart1) != HAL_OK)
{
Error_Handler();
}
/* USER CODE BEGIN USART1_Init 2 */
/* USER CODE END USART1_Init 2 */
}
/**
* @brief GPIO Initialization Function
* @param None
* @retval None
*/
static void MX_GPIO_Init(void)
{
GPIO_InitTypeDef GPIO_InitStruct = {0};
/* GPIO Ports Clock Enable */
__HAL_RCC_GPIOH_CLK_ENABLE();
__HAL_RCC_GPIOA_CLK_ENABLE();
__HAL_RCC_GPIOD_CLK_ENABLE();
__HAL_RCC_GPIOB_CLK_ENABLE();
/*Configure GPIO pin Output Level */
HAL_GPIO_WritePin(LedHeart_GPIO_Port, LedHeart_Pin, GPIO_PIN_RESET);
/*Configure GPIO pin : LedHeart_Pin */
GPIO_InitStruct.Pin = LedHeart_Pin;
GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP;
GPIO_InitStruct.Pull = GPIO_NOPULL;
GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_LOW;
HAL_GPIO_Init(LedHeart_GPIO_Port, &GPIO_InitStruct);
}
/* USER CODE BEGIN 4 */
/* USER CODE END 4 */
/**
 * @brief Fatal-error trap: masks interrupts and parks the CPU forever.
 * @retval None (never returns)
 */
void Error_Handler(void)
{
  /* USER CODE BEGIN Error_Handler_Debug */
  /* User can add his own implementation to report the HAL error return state */
  __disable_irq();
  for (;;)
  {
    /* spin */
  }
  /* USER CODE END Error_Handler_Debug */
}
#ifdef USE_FULL_ASSERT
/**
 * @brief  Hook invoked when an assert_param() check fails; currently reports nothing.
 * @param  file Source file name where the assert_param error occurred.
 * @param  line Source line number of the failing assert_param.
 * @retval None
 */
void assert_failed(uint8_t *file, uint32_t line)
{
  /* USER CODE BEGIN 6 */
  /* User can add his own implementation to report the file name and line number,
     ex: printf("Wrong parameters value: file %s on line %d\r\n", file, line) */
  (void)file; /* unused until reporting is implemented */
  (void)line;
  /* USER CODE END 6 */
}
#endif /* USE_FULL_ASSERT */
到了这里,关于[STM32]:基于X-CUBE-AI的模型推理的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!