TenorFlow多层感知机识别手写体

这篇具有很好参考价值的文章主要介绍了TenorFlow多层感知机识别手写体。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403

数据准备

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters


WARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images     :', mnist.train.images.shape,
      'labels:'           , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,
      ' labels:'          , mnist.validation.labels.shape)
print('test images      :', mnist.test.images.shape,
      'labels:'           , mnist.test.labels.shape)
train images     : (55000, 784) labels: (55000, 10)
validation images: (5000, 784)  labels: (5000, 10)
test images      : (10000, 784) labels: (10000, 10)

建立模型

def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为None
    W = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重W
    b = tf.Variable(tf.random_normal([1, output_dim]))
    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,
         inputs=x ,activation=tf.nn.relu)  

建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,
         inputs=h1 ,activation=tf.nn.relu)  
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,
                inputs=h2,activation=None)

定义训练方式

建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(
                   tf.nn.softmax_cross_entropy_with_logits_v2
                       (logits=y_predict , 
                        labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
                    .minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化

定义评估模型的准确率

计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label  , 1),
                              tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

开始训练

trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):
    #执行15个训练周期
    #每个训练周期执行550批次训练
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据
        sess.run(optimizer,feed_dict={x: batch_x,
                                      y_label: batch_y})
        
    #使用验证数据计算准确率
    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={x: mnist.validation.images, #验证数据的features
                                   y_label: mnist.validation.labels})#验证数据的label

    epoch_list.append(epoch)
    loss_list.append(loss);accuracy_list.append(acc)    
    
    print("Train Epoch:", '%02d' % (epoch+1), \
          "Loss=","{:.9f}".format(loss)," Accuracy=",acc)
    
duration =time()-startTime
print("Train Finished takes:",duration)        
Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

评估模型的准确率

print("Accuracy:", sess.run(accuracy,
                           feed_dict={x: mnist.test.images, 
                                      y_label: mnist.test.labels}))
Accuracy: 0.9643

进行预测

prediction_result=sess.run(tf.argmax(y_predict,1),
                           feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num>25: num=25 
    for i in range(0, num):
        ax=plt.subplot(5,5, 1+i)
        
        ax.imshow(np.reshape(images[idx],(28, 28)), 
                  cmap='binary')
            
        title= "label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx]) 
            
        ax.set_title(title,fontsize=10) 
        ax.set_xticks([]);ax.set_yticks([])        
        idx+=1 
    plt.show()
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传文章来源地址https://www.toymoban.com/news/detail-830765.html

y_predict_Onehot=sess.run(y_predict,
                          feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,
        5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],
      dtype=float32)

找出预测错误

for i in range(400):
    if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
        print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),
              "predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=18   label= 3 predict= 8
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=233   label= 8 predict= 7
i=241   label= 9 predict= 8
i=245   label= 3 predict= 5
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=320   label= 9 predict= 1
i=340   label= 5 predict= 3
i=381   label= 3 predict= 7
i=386   label= 6 predict= 5
sess.close()

到了这里,关于TenorFlow多层感知机识别手写体的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • (神经网络)MNIST手写体数字识别MATLAB完整代码

            在此次实验中,笔者针对 MNIST 数据集,利用卷积神经网络进行训练与测试,提 出了一系列的改进方法,并对这些改进的方法进行了逐一验证,比较了改进方法与浅层 神经网络的优劣。         首先,笔者对实验中所用的 MNIST 数据集进行了简单的介绍;接着,

    2024年02月03日
    浏览(46)
  • 【MATLAB图像处理实用案例详解(16)】——利用概念神经网络实现手写体数字识别

    手写体数字属于光学字符识别(Optical Character Recognition,OCR)的范畴,但分类的分别比光学字符识别少得多,主要只需识别共10个字符。 使用概率神经网络作为分类器,对64*64二值图像表示的手写数字进行分类,所得的分类器对训练样本能够取得100%的正确率,训练时间短,比

    2024年02月06日
    浏览(46)
  • 实战:基于卷积的MNIST手写体分类

    前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。 1.  数据的准备 在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地

    2024年02月10日
    浏览(41)
  • 基于PyTorch的MNIST手写体分类实战

    第2章对MNIST数据做了介绍,描述了其构成方式及其数据的特征和标签的含义等。了解这些有助于编写合适的程序来对MNIST数据集进行分析和识别。本节将使用同样的数据集完成对其进行分类的任务。 3.1.1  数据图像的获取与标签的说明 MNIST数据集的详细介绍在第2章中已经完成

    2024年02月08日
    浏览(39)
  • 【Python机器学习】实验14 手写体卷积神经网络(PyTorch实现)

    LeNet-5是卷积神经网络模型的早期代表,它由LeCun在1998年提出。该模型采用顺序结构,主要包括7层(2个卷积层、2个池化层和3个全连接层),卷积层和池化层交替排列。以mnist手写数字分类为例构建一个LeNet-5模型。每个手写数字图片样本的宽与高均为28像素,样本标签值是0~

    2024年02月12日
    浏览(52)
  • 6.6 实现卷积神经网络LeNet训练并预测手写体数字

    isinstance(net,nn.Module)是Python的内置函数,用于判断一个对象是否属于制定类或其子类的实例。如果net是nn.Module类或子类的实例,那么表达式返回True,否则返回False. nn.Module是pytorch中用于构建神经网络模型的基类,其他神经网络都会继承它,因此使用 isinstance(net,nn.Module),可以确

    2024年02月14日
    浏览(49)
  • 多层感知机

    2024年02月11日
    浏览(41)
  • 多层感知机实战

    我们将继续使用Fashion-MNIST图像分类数据集 Fashion-MNIST中的每个图像由 28×28=784个灰度像素值组成。 所有图像共分为10个类别。 忽略像素之间的空间结构, 我们可以将每个图像视为具有784个输入特征 和10个类的简单分类数据集。 实现一个具有单隐藏层的多层感知机, 它包含

    2024年01月25日
    浏览(40)
  • 多层感知机(MLP)

    多层感知器(MLP,Multilayer Perceptron)是一种前馈人工神经网络模型,其将输入的多个数据集映射到单一的输出的数据集上。 它最主要的特点是有多个神经元层 ,因此也叫深度神经网络(DNN: Deep Neural Networks)。 感知机是单个神经元模型,是较大神经网络的前身。神经网络的强大

    2024年02月17日
    浏览(40)
  • 《动手学深度学习》——多层感知机

    参考资料: 《动手学深度学习》 隐藏层 + 激活函数能够模拟任何连续函数。 4.1.2.1 ReLu函数 ReLU ⁡ ( x ) = max ⁡ ( x , 0 ) operatorname{ReLU}(x) = max(x, 0) ReLU ( x ) = max ( x , 0 ) 当输入为负时,ReLU 的导数为 0 ;当输出为负时,ReLU 的导数为 1 。 ReLU的优势在于它的求导非常简单,要么让

    2024年02月12日
    浏览(56)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包