2023.7.18
MNIST百科:
MNIST数据集简介与使用_bwqiang的博客-CSDN博客
数据集官网:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
MNIST数据集获取并转换成图片格式:
数据集将按以图片和文件夹名为标签的形式保存:
代码:下载mnist数据集并转还为图片
import os
from PIL import Image
from torchvision import datasets, transforms
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量
transforms.Normalize((0.5,), (0.5,)) # 标准化
])
# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root=os.getcwd(), train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=os.getcwd(), train=False, transform=transform, download=True)
# 路径
train_path = './images/train'
test_path = './images/test'
# 将训练集中的图像保存为图片
for i in range(10):
file_name = train_path + os.sep + str(i)
if not os.path.exists(file_name):
os.mkdir(file_name)
for i in range(10):
file_name = test_path + os.sep + str(i)
if not os.path.exists(file_name):
os.mkdir(file_name)
for i, (image, label) in enumerate(train_dataset):
train_label = label
image_path = f'images/train/{train_label}/{i}.png'
image = image.squeeze().numpy() # 去除通道维度,并转换为 numpy 数组
image = (image * 0.5) + 0.5 # 反标准化,将范围调整为 [0, 1]
image = (image * 255).astype('uint8') # 将范围调整为 [0, 255],并转换为整数类型
Image.fromarray(image).save(image_path)
# 将测试集中的图像保存为图片
for i, (image, label) in enumerate(test_dataset):
text_label = label
image_path = f'images/test/{text_label}/{i}.png'
image = image.squeeze().numpy() # 去除通道维度,并转换为 numpy 数组
image = (image * 0.5) + 0.5 # 反标准化,将范围调整为 [0, 1]
image = (image * 255).astype('uint8') # 将范围调整为 [0, 255],并转换为整数类型
Image.fromarray(image).save(image_path)
训练代码:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
# 调动显卡进行计算
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class MyDataset(torch.utils.data.Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.names_list = []
for dirs in os.listdir(self.root_dir):
dir_path = self.root_dir + '/' + dirs
for imgs in os.listdir(dir_path):
img_path = dir_path + '/' + imgs
self.names_list.append((img_path, dirs))
def __len__(self):
return len(self.names_list)
def __getitem__(self, index):
image_path, label = self.names_list[index]
if not os.path.isfile(image_path):
print(image_path + '不存在该路径')
return None
image = Image.open(image_path)
label = np.array(label).astype(int)
label = torch.from_numpy(label)
if self.transform:
image = self.transform(image)
return image, label
# 定义卷积神经网络模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc = nn.Linear(16 * 14 * 14, 10)
def forward(self, x):
x = self.conv1(x) # 卷积
x = self.relu(x) # 激活函数
x = self.maxpool(x) # 最大值池化
x = x.view(x.size(0), -1)
x = self.fc(x) # 全连接层
return x
# 加载手写数字数据集
train_dataset = MyDataset('./dataset/images/train', transform=transforms.ToTensor())
val_dataset = MyDataset('./dataset/images/val', transform=transforms.ToTensor())
# 定义超参数
batch_size = 8192 # 批处理大小
learning_rate = 0.001 # 学习率
num_epochs = 30 # 迭代次数
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
# 实例化模型、损失函数和优化器
model = CNN().to(device)
criterion = nn.CrossEntropyLoss() # 损失函数
optimizer = optim.Adam(model.parameters(), lr=learning_rate) # 优化器
# 记录验证的次数
total_train_step = 0
total_val_step = 0
# 模型训练和验证
print("-------------TRAINING-------------")
total_step = len(train_loader)
for epoch in range(num_epochs):
print("Epoch=", epoch)
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
output = model(images)
loss = criterion(output, labels.long())
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_train_step = total_train_step + 1
print("train_times:{},Loss:{}".format(total_train_step, loss.item()))
# 测试验证
total_val_loss = 0
total_accuracy = 0
with torch.no_grad():
for i, (images, labels) in enumerate(val_loader):
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels.long())
total_val_loss = total_val_loss + loss.item() # 计算损失值的和
accuracy = 0
for j in labels: # 计算精确度的和
if outputs.argmax(1)[j] == labels[j]:
accuracy = accuracy + 1
total_accuracy = total_accuracy + accuracy
print('Accuracy =', float(total_accuracy / len(val_dataset))) # 输出正确率
torch.save(model, "cnn_{}.pth".format(epoch)) # 模型保存
# # 模型评估
# with torch.no_grad():
# correct = 0
# total = 0
# for images, labels in test_loader:
# outputs = model(images)
# _, predicted = torch.max(outputs.data, 1)
# total += labels.size(0)
# correct += (predicted == labels).sum().item()
测试代码:
import torch
from torchvision import transforms
import torch.nn as nn
import os
from PIL import Image
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 判断是否有GPU
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc = nn.Linear(16 * 14 * 14, 10)
def forward(self, x):
x = self.conv1(x) # 卷积
x = self.relu(x) # 激活函数
x = self.maxpool(x) # 最大值池化
x = x.view(x.size(0), -1)
x = self.fc(x) # 全连接层
return x
model = torch.load('cnn.pth') # 加载模型
path = "./dataset/images/test/" # 测试集
imgs = os.listdir(path)
test_num = len(imgs)
print(f"test_dataset_quantity={test_num}")
for img_name in imgs:
img = Image.open(path + img_name)
test_transform = transforms.Compose([transforms.ToTensor()])
img = test_transform(img)
img = img.to(device)
img = img.unsqueeze(0)
outputs = model(img) # 将图片输入到模型中
_, predicted = outputs.max(1)
pred_type = predicted.item()
print(img_name, 'pred_type:', pred_type)
分类正确率不错:文章来源:https://www.toymoban.com/news/detail-581537.html
文章来源地址https://www.toymoban.com/news/detail-581537.html
到了这里,关于Pytorch:搭建卷积神经网络完成MNIST分类任务:的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!