手把手教你用MindSpore训练一个AI模型!

这篇具有很好参考价值的文章主要介绍了手把手教你用MindSpore训练一个AI模型!。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

手把手教你用MindSpore训练一个AI模型!

首先我们要先了解深度学习的概念和AI计算框架的角色(https://zhuanlan.zhihu.com/p/463019160),本篇文章将演示怎么利用MindSpore来训练一个AI模型。和上一章的场景一致,我们要训练的模型是用来对手写数字图片进行分类的LeNet5模型

请参考(http://yann.lecun.com/exdb/lenet/)。

手把手教你用MindSpore训练一个AI模型!

图1 MindSpore使用流程

安装MindSpore

MindSpore提供给用户使用的是Python接口(什么是Python,请参考:

https://zhuanlan.zhihu.com/p/462756985),所以我们首先需要安装MindSpore的whl包,安装之后就可以导入(import)MindSpore提供的方法接口了。安装whl包有两种方式:

方式一:进入MindSpore官网,根据自己的设备和Python版本选择安装命令。比如我的Python版本是3.7.5,我的设备是笔记本(CPU),那么我就复制下图红框中的命令进行安装:

手把手教你用MindSpore训练一个AI模型!

图2 MindSpore安装界面

安装过程如下:

手把手教你用MindSpore训练一个AI模型!

图3 MindSpore安装过程

注意:由于MindSpore还依赖于其他的Python三方库,所以在安装过程中,系统还会自动下载、安装其他的Python三方库,如numpy、pillow、scipy等等,安装结束后,如果能 import mindspore 成功,说明MindSpore安装成功了:

手把手教你用MindSpore训练一个AI模型!

图4 MindSpore安装成功

方式二:可以在版本列表中找到对应的whl包,点击就能下载:

手把手教你用MindSpore训练一个AI模型!

图5 MindSpore版本下载列表

下载完成后,把whl包放到自己的目录下,执行 pip install xxx.whl:

手把手教你用MindSpore训练一个AI模型!

图6 MindSpore第二种安装方式

定义模型

安装好MindSpore之后,我们就可以导入MindSpore提供的算子(卷积、全连接、池化等函数:https://zhuanlan.zhihu.com/p/463019160)来构建我们的模型了。可以这么比喻:我们构建一个AI模型就像建一个房子,而MindSpore提供给我们的算子就像是砖块、窗户、地板等基本组件。

手把手教你用MindSpore训练一个AI模型!

图7 定义LeNet5模型

如上图所示,我们用到的“砖块”都是mindspore.nn模块提供的。注意:这里用到了Python的类(class),由②和③两部分组成。我们这里定义的类是class LeNet5,它由初始化函数 __init__(self) 和构造函数construct(self, x)组成。初始化函数定义了我们构造模型所需要用到的算子,比如conv算子、relu算子、flatten算子等等,这些算子都是从mindspore.nn获取的;构造函数就是把我们在初始化函数中导入的算子按顺序排放,构成我们最终的模型。construct()函数的输入就是我们这个模型预测的对象,比如第一章讲的黑白图片像素矩阵;而“return y”中的就是预测的结果,对应于第一章讲到的10分类手写数字数据集,就是一个行10列的数组(这里的是指输入图片的数量,AI模型支持多张图片同时推理)。

导入训练数据集

什么是训练数据集?刚刚定义好的模型是不能对图片进行正确分类的,我们要通过“训练”过程来调整模型的参数矩阵的值。训练过程就需要用到训练样本,也就是打上了正确标签的图片。这就好比我们教小孩儿认识动物,需要拿几张图片给他们看,然后告诉他们这是什么、那是什么,教了几遍之后,小孩儿就能认识了。那么我们训练LeNet5模型就需要用到MNIST数据集,请参考(http://yann.lecun.com/exdb/mnist/)。这个数据集由两部分组成:训练集(6万张图片)和测试集(1万张图片),都是0~9的黑白手写数字图片。训练集是用来训练AI模型的,测试集是用来测试训练后的模型分类准确率的。

下载得到的数据集最初是压缩文件,还不能直接传给MindSpore的训练接口使用,我们要先用MindSpore提供的数据处理接口把他们读进来:

import mindspore.dataset as ds
mnist_ds = ds.MnistDataset(data_path)  # 导入下载的MNIST数据集

然后进行数据增强(比如把图片大小转化成相同的尺寸、像素值标准化、归一化等操作),提升训练效率:

import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# 定义数据增强函数
def create_dataset(data_path, batch_size=32):  # batch_size是每一步训练使用的图片数量,一般取32
    """
    create dataset for train or test

    Args:
        data_path (str): Data path
        batch_size (int): The number of data records in each group
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)  # 导入下载的MNIST数据集
    # define some parameters needed for data enhancement and rough justification
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # according to the parameters, generate the corresponding data enhancement method
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # using map to apply operations to a dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image")
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image")
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image")
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image")

    # process the generated dataset
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    return mnist_ds

 训练模型

训练数据集和模型定义完成之后呢,我们就可以开始训练模型了。但是在训练之前,我们还需要从MindSpore导入两个函数:

  • 损失函数,也就是衡量预测结果和真实标签之间的差距的函数。看过上一章的同学可能会记得,我们之前用的损失函数是真实值与预测值之差的2-范数:

手把手教你用MindSpore训练一个AI模型!

图8 2-范数损失

在这里,我们使用业界最常用的交叉熵损失函数SoftmaxCrossEntropyWithLogits,对于真实标签

手把手教你用MindSpore训练一个AI模型!

和预测值,它们之间的交叉熵损失计算公式为:

手把手教你用MindSpore训练一个AI模型!

其中J代表数组的下标,。从MindSpore导入损失函数:

from mindspore.nn import SoftmaxCrossEntropyWithLogits
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') 
  • 优化器,优化器就是用来求解损失函数关于模型参数的更新梯度的,它是整个训练过程中最重要的工具!我们这里用MindSpore提供的Momentum优化器:

import mindspore.nn as nn

lr = 0.01  # 定义学习率
momentum = 0.9  # 定义Momentum优化器的超参
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)  # 导入mindspore提供

 准备好损失函数和优化器之后我们就可以开始训练模型了,也非常简单,我们先把前面定义好的模型、损失函数、优化器封装成一个Model:

from mindspore import Model
net = LeNet5()
model = Model(net, net_loss , net_opt , metrics={'acc', 'loss'})

然后使用model.train接口就可以训练我们定义的LeNet5模型了:

loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())  # 用于监控训练过程中损失函数值的变化
ds_train = create_dataset(train_data_dir)  # 传入下载的训练集的路径
model.train(num_epochs, ds_train, callbacks=[loss_cb])  # num_epochs是训练的轮数,往往训练多轮才能使模型收敛

测试训练后的模型准确率

训练结束后,调用model.eval()计算训练后的模型在测试集上面的分类准确率:

ds_eval = create_dataset(test_data_dir)  # 传入下载的训练集的路径
metrics = model.eval(ds_eval)

小结

祝贺你耐心看完了MindSpore训练模型的完整过程,如果你想动手操作一遍,但是又没有现成的环境,那么你可以使用官网提供的“在线运行”来体验一番:

手把手教你用MindSpore训练一个AI模型!

图9 MindSpore官网提供的免费体验入口

这是体验过程的实操视频:

https://zhuanlan.zhihu.com/p/463229660

欢迎投稿

欢迎大家踊跃投稿,有想投稿技术干货、项目经验等分享的同学,可以添加MindSpore官方小助手:小猫子(mindspore0328)的微信,告诉猫哥哦!

昇思MindSpore官方交流QQ群 : 486831414群里有很多技术大咖助力答疑!

MindSpore官方资料

GitHub : https://github.com/mindspore-ai/mindspore

Gitee : https : //gitee.com/mindspore/mindspore

官方QQ群 : 486831 文章来源地址https://www.toymoban.com/news/detail-445378.html

到了这里,关于手把手教你用MindSpore训练一个AI模型!的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 手把手教你用Python编一个《我的世界》 2.材质及第一人称

    本次,我们将实现这样一个效果: 首先,导入ursina模块 创建app 定义Block类,继承自Button 然后,我们需要一个天空 定义Sky类 因为我们所有的方块包括天空都需要图片材质,所以我们在程序开头写以下代码: 然后咱们先创建一个超平坦地形,厚度就只有1层吧,因为方块多了很

    2024年02月04日
    浏览(51)
  • 手把手教你用SQLServer连接Visual Studio2019并编写一个学生信息管理页面

    目录 安装SQLServer 创建新项目 建数据库建表 窗体设计 代码实现  整体效果 ​ 用SQLServer连接Visual Studio,首先需要下载SQLServer app。 下载教程,我之前写过,可以点击如下链接先下载安装SQLServer: SQL Server(express)安装教程 安装好SQL之后,打开VisualStudio2019,新建一个window项目 ,步

    2024年02月12日
    浏览(40)
  • 手把手教你训练一个VAE生成模型一生成手写数字

    VAE(Variational Autoencoder)变分自编码器是一种使用变分推理的自编码器,其主要用于生成模型。 VAE 的编码器是模型的一部分,用于将输入数据压缩成潜在表示,即编码。 VAE 编码器包括两个子网络:一个是推断网络,另一个是生成网络。推断网络输入原始输入数据,并输出两

    2024年02月06日
    浏览(47)
  • 手把手教你用代码画架构图

    作者:京东物流 覃玉杰 本文将给大家介绍一种简洁明了软件架构可视化模型——C4模型,并手把手教大家如何使用 代码 绘制出精美的C4架构图。 阅读本文之后,读者画的架构图将会是这样的: 注:该图例仅作绘图示例使用,不确保其完整性、可行性。 C4是软件架构可视化

    2024年02月04日
    浏览(45)
  • 手把手教你用AirtestIDE无线连接手机

    一直以来,我们发现同学们都挺喜欢用无线的方式连接手机,正好安卓11出了个无线连接的新姿势,我们今天就一起来看看,如何用AirtestIDE无线连接你的Android设备~ 当 手机与电脑处在同一个wifi 下,即可尝试无线连接手机了,但是这种方式受限于网络连接的稳定性,可能会出

    2023年04月18日
    浏览(44)
  • 快收藏!手把手教你用AI绘画

    点个关注👆跟腾讯工程师学技术 最近看到一篇有趣的文章,一副名为《太空歌剧院》(如下图)的艺术品在某美术比赛上,获得了第一名的成绩, 有意思的是这件作品是通过AI来实现的画作, 顿时觉得非常神奇。结合近期科技媒体频频报道的AI作画爆火现象,深入了解了下

    2024年02月09日
    浏览(32)
  • 手把手教你用 Jenkins 自动部署 SpringBoot

    CI/CD 是一种通过在应用开发阶段引入自动化来频繁向客户交付应用的方法。 CI/CD 的核心概念可以总结为三点: 持续集成 持续交付 持续部署 CI/CD 主要针对在集成新代码时所引发的问题(俗称\\\"集成地狱\\\")。 为什么会有集成地狱这个“雅称”呢?大家想想我们一个项目部署的

    2024年02月02日
    浏览(44)
  • 手把手教你用Python编写邮箱脚本引擎

    版权声明:原创不易,本文禁止抄袭、转载需附上链接,侵权必究! 邮箱是传输信息方式之一,个人,企业等都在使用,朋友之间发消息,注册/登录信息验证,订阅邮箱,企业招聘,向客户发送消息等都是邮箱的使用场景;邮箱有两个较重要的协议:SMTP和POP3,均位于OSI7层

    2024年02月06日
    浏览(37)
  • 手把手教你用video实现视频播放功能

    哈喽。大家好啊 今天需要做一个视频播放列表,让我想到了video的属性 下面让我们先看看实现效果 这里是我的代码 width是当前播放页面的宽度 height是当前播放页面的高度 Controls属性用就是控制栏那些了 比如播放按钮 暂停按钮 autoplay是指的是自动播放 poster是指的是初始化进

    2024年02月12日
    浏览(42)
  • 手把手教你用jmeter做压力测试(详图)

    压力测试是每一个Web应用程序上线之前都需要做的一个测试,他可以帮助我们发现系统中的瓶颈问题,减少发布到生产环境后出问题的几率;预估系统的承载能力,使我们能根据其做出一些应对措施。所以压力测试是一个非常重要的步骤,下面我带大家来使用一款压力测试工

    2024年02月02日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包