LeNet-5 是经典卷积神经网络之一,于 1998 年由 Yann LeCun 等人提出。LeNet-5 网络使用了卷积层、池化层和全连接层,实现可以应用于手写体识别的卷积神经网络。TensorFlow 内置了 MNIST 手写体数据集,可以很方便地读取数据集,并应用于后续的模型训练过程中。本文主要记录了如何使用 TensorFlow 2.0 实现 MNIST 手写体识别模型。
目录
1 数据集准备
2 模型建立
3 模型训练与评估
1 数据集准备
TensorFlow 内置了 MNIST 手写体数据集,安装 TensorFlow 之后,使用如下代码就可以加载 MNIST 数据集:
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()
使用 Matplotlib 查看前 25 张图片,并打印对应的标签。
from matplotlib import pyplot as plt
# 查看训练集
plt.figure(figsize=(3,3))
for i in range(25):
plt.subplot(5,5,i+1)
plt.imshow(train_x[i], cmap=plt.cm.binary)
plt.xticks([])
plt.yticks([])
plt.show()
接着使用 tf.one_hot() 函数,对图像的标签进行独热码编码。
# 预处理
train_y = tf.one_hot(train_y, depth=10)
test_y = tf.one_hot(test_y, depth=10)
2 模型建立
MNIST 手写体数据集中,每张图像的大小是 28 × 28 × 1,按照 LeNet-5 模型的思路,构建卷积神经网络模型。选择 5 × 5 的卷积核,卷积层之后是 2 × 2 的平均池化,激活函数选择 sigmoid(除了最后一层)。
# the first layer can receive an 'input_shape' argument
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=6,kernel_size=5,padding='valid',activation='sigmoid',input_shape=(28,28,1)),
tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=2,padding='valid'),
tf.keras.layers.Conv2D(filters=16,kernel_size=5,padding='valid',activation='sigmoid'),
tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=2,padding='valid'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(120,activation='sigmoid'),
tf.keras.layers.Dense(84,activation='sigmoid'),
tf.keras.layers.Dense(10,activation='softmax')
])
使用 model.summary() 查看模型信息。
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 24, 24, 6) 156
average_pooling2d (AverageP (None, 12, 12, 6) 0
ooling2D)
conv2d_1 (Conv2D) (None, 8, 8, 16) 2416
average_pooling2d_1 (Averag (None, 4, 4, 16) 0
ePooling2D)
flatten (Flatten) (None, 256) 0
dense (Dense) (None, 120) 30840
dense_1 (Dense) (None, 84) 10164
dense_2 (Dense) (None, 10) 850
=================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
_________________________________________________________________
3 模型训练与评估
使用 compile() 函数配置模型,优化算法为 Adam 算法,学习率为 0.001,损失函数为交叉熵损失函数。
# 模型配置
model.compile(
optimizer=tf.keras.optimizer.Adam(learning_rate=1e-3),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=['accuracy']
)
# 模型训练
model.fit(
x=train_x,
y=train_y,
validation_split=0.0,
epochs=10
)
Epoch 1/10
1875/1875 [==============================] - 72s 38ms/step - loss: 0.5806 - accuracy: 0.8206
Epoch 2/10
1875/1875 [==============================] - 70s 37ms/step - loss: 0.1254 - accuracy: 0.9620
Epoch 3/10
1875/1875 [==============================] - 75s 40ms/step - loss: 0.0870 - accuracy: 0.9735
Epoch 4/10
1875/1875 [==============================] - 82s 43ms/step - loss: 0.0699 - accuracy: 0.9785
Epoch 5/10
1875/1875 [==============================] - 69s 37ms/step - loss: 0.0604 - accuracy: 0.9809
Epoch 6/10
1875/1875 [==============================] - 68s 36ms/step - loss: 0.0530 - accuracy: 0.9833
Epoch 7/10
1875/1875 [==============================] - 72s 38ms/step - loss: 0.0477 - accuracy: 0.9854
Epoch 8/10
1875/1875 [==============================] - 70s 38ms/step - loss: 0.0436 - accuracy: 0.9863
Epoch 9/10
1875/1875 [==============================] - 70s 37ms/step - loss: 0.0399 - accuracy: 0.9873
Epoch 10/10
1875/1875 [==============================] - 68s 36ms/step - loss: 0.0357 - accuracy: 0.9883
<keras.callbacks.History at 0x20a56b65660>
使用 model.evaluate() 函数评估模型,model.predict() 函数用于预测输出。
model.evaluate(test_x,test_y)
313/313 [==============================] - 1s 2ms/step - loss: 0.0914 - accuracy: 0.9701
[0.09142322838306427, 0.9700999855995178]文章来源:https://www.toymoban.com/news/detail-682241.html
# 预测输出
pred_y = model.predict(test_x)
print(pred_y[:25].argmax(axis=1).reshape(5,5))
文章来源地址https://www.toymoban.com/news/detail-682241.html
到了这里,关于【卷积神经网络】MNIST 手写体识别的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!