21个项目玩转PyTorch实战
通过经典项目入门 PyTorch,通过前沿项目提升 PyTorch,基于PyTorch玩转深度学习,本书适合人工智能、机器学习、深度学习方面的人员阅读,也适合其他 IT 方面从业者,另外,还可以作为相关专业的教材。
京东自营购买链接:https://item.jd.com/13522327.html
PyTorch 是基于 Torch 库的开源机器学习库,它主要由 Meta(原 Facebook)的人工智能研究实验室开发,在自然语言处理和计算机视觉领域都具有广泛的应用。本书介绍了简单且经典的入门项目,方便快速上手,如 MNIST数字识别,读者在完成项目的过程中可以了解数据集、模型和训练等基础概念。本书还介绍了一些实用且经典的模型,如 R-CNN 模型,通过这个模型的学习,读者可以对目标检测任务有一个基本的认识,对于基本的网络结构原理有一定的了解。另外,本书对于当前比较热门的生成对抗网络和强化学习也有一定的介绍,方便读者拓宽视野,掌握前沿方向。
正文
在使用Python进行机器学习时,你有多个选项可供选择使用哪个库或框架。但是,如果你正在向深度学习迈进,那么你应该使用TensorFlow
或PyTorch
这两个最著名的深度学习框架之一。
在本文中,我们将快速介绍PyTorch
框架,从最初的概念一直到第一张图像分类模型的训练和测试。
我们不会深入学习复杂的概念和数学,因为本文旨在成为一个更加实用的PyTorch
工具入门文章,而不是一个深度学习概念的入门文章。
因此,我们假设你具有一些中级Python知识-包括类和面向对象编程-并且熟悉深度学习的主要概念。
PyTorch
PyTorch
是一个功能强大、易于使用的Python深度学习库,主要用于计算机视觉和自然语言处理等应用。
虽然TensorFlow
是由谷歌开发的,但PyTorch
是由Facebook
的AI研究小组开发的,该小组最近将该框架的管理转移到了新创建的PyTorch
基金会下,该基金会受Linux基金会的监督。
PyTorch
的灵活性允许轻松集成新的数据类型和算法,而且该框架也是高效且可扩展的,因为它被设计为最小化所需的计算数量,并与各种硬件架构兼容。
张量
在深度学习中,张量是一种基本数据结构,非常类似于数组和矩阵,可以在大量数据集上高效执行数学运算。张量可以表示为矩阵,也可以表示为向量、标量或高维数组。
为了更容易地可视化,你可以将张量视为包含标量或其他数组的简单数组。在PyTorch
上,张量是一个非常类似于ndarray
的结构,但它们能够在GPU
上运行,从而大大加快了计算过程。
从NumPy
创建张量很简单:
import torch
import numpy as np
ndarray = np.array([0, 1, 2])
t = torch.from_numpy(ndarray)
print(t)
tensor([0, 1, 2])
PyTorch
上的张量具有三个属性:
形状:张量的大小
数据类型:张量中存储的数据类型
设备:张量存储的设备
如果我们从创建的张量中打印属性,我们将得到以下结果:
print(t.shape)
print(t.dtype)
print(t.device)
torch.Size([3])
torch.int64
cpu
这意味着我们有一个包含整数的一维张量,大小为3,存储在CPU中。
我们也可以从Python列表实例化一个张量:
t = torch.tensor([0, 1, 2])
print(t)
tensor([0, 1, 2])
张量也可以是多维的:
ndarray = np.array([[0, 1, 2], [3, 4, 5]])
t = torch.from_numpy(ndarray)
print(t)
tensor([[0, 1, 2],
[3, 4, 5]])
还可以从另一个张量创建张量。在这种情况下,新张量继承初始张量的特性。下面的示例根据先前创建的张量创建具有随机数字的张量:
new_t = torch.rand_like(t, dtype=torch.float)
print(new_t)
tensor([[0.1366, 0.5994, 0.3963],
[0.1126, 0.8860, 0.8233]])
请注意,该函数创建一个形状为(2,2)的新张量。但是,由于函数返回值从0到1,因此我们必须将数据类型覆盖为float。
我们还可以仅从我们预期的形状创建一个随机张量:
my_shape = (3, 3)
rand_t = torch.rand(my_shape)
print(rand_t)
tensor([[0.8099, 0.8816, 0.3071],
[0.1003, 0.3190, 0.3503],
[0.9088, 0.0844, 0.0547]])
张量操作
与NumPy一样,我们可以使用张量执行多种可能的操作-例如切片、转置和矩阵乘法等。
切片张量与Python中的任何其他数组结构完全相同。考虑下面的张量:
zeros_tensor = torch.zeros((2, 3))
print(zeros_tensor)
tensor([[0., 0., 0.],
[0., 0., 0.]])
我们可以轻松地索引第一行或第一列:
print(zeros_tensor[1])
print(zeros_tensor[:, 0])
tensor([0., 0., 0.])
tensor([0., 0.])
我们还可以将此张量转置:
transposed = zeros_tensor.T
print(transposed)
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
最后,我们可以将张量相乘:
ones_tensor = torch.ones(3, 3)
product = torch.matmul(zeros_tensor, ones_tensor)
print(product)
tensor([[0., 0., 0.],
[0., 0., 0.]])
请注意,我们使用了and函数创建一个仅包含零和一的张量,其形状与我们传递的形状相同。
这些操作只是PyTorch
可以执行的一小部分。但是,本文的目的不是涵盖它们中的每一个,而是提供对它们如何工作的一般概念。如果你想了解更多信息,PyTorch
有完整的文档。
加载数据
PyTorch自带一个内置模块,为许多深度学习应用提供了现成的数据集,例如计算机视觉、语音识别和自然语言处理。这意味着可以构建自己的神经网络,而无需收集和处理数据。
作为示例,我们将下载MNIST数据集。MNIST是一个手写数字图像数据集,包含6万个样本和一组包含1万张图像的测试集。
我们将使用来自torchvision的模块来下载数据集:
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.MNIST(root=".", train=True, download=True, transform=ToTensor())
test_data = datasets.MNIST(root=".", train=False, download=True, transform=ToTensor())
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s]
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s]
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
在下载函数内部,我们有以下参数:
root:数据将保存的目录。你可以传递一个带有目录路径的字符串。点(如示例中所示)将文件保存在你所在的相同目录中。
train:用于告知PyTorch你是否正在下载训练集还是测试集。
download:如果指定的路径上已经没有数据,是否下载数据。
transform:转换数据。在我们的代码中,我们选择张量。
如果我们打印训练集的第一个元素,我们将看到以下内容:
training_data[0]
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]]), 5)
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
上述张量只是整个元素的一小部分,因为它太大了无法显示。
这一串数字对我们可能没有什么意义,由于它们代表图像,我们可以使用matplotlib将它们可视化为实际的图像:
figure = plt.figure(figsize=(8, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
我们还可以使用属性来查看数据中的类别:classes
training_data.classes
['0 - zero',
'1 - one',
'2 - two',
'3 - three',
'4 - four',
'5 - five',
'6 - six',
'7 - seven',
'8 - eight',
'9 - nine']
当模型训练好后,它可以接收新的输入,然后将其分类为这些类别之一。
现在我们已经下载了数据,我们将使用DataLoader。这使我们能够以小批量的方式迭代数据集,而不是一次观察一个,同时在训练模型时对数据进行洗牌。以下是代码:
from torch.utils.data import DataLoader
loaded_train = DataLoader(training_data, batch_size=64, shuffle=True)
loaded_test = DataLoader(test_data, batch_size=64, shuffle=True)
神经网络
在深度学习中,神经网络是一种用于建模具有复杂模式数据的算法。神经网络试图通过多个由处理节点连接的层来模拟人脑的功能,这些处理节点的行为类似于人类神经元。这些由节点连接的层创建了一个复杂的网络,能够处理和理解大量复杂的数据。
在PyTorch中,与神经网络相关的所有内容都是使用模块构建的。网络本身是一个继承自的类,而在类内部,我们将使用来构建层。以下是从PyTorch文档中摘取的一个简单实现:
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
虽然本文的范围不包括深入讨论层是什么、它们如何工作以及如何实现它们,但让我们简要地了解一下上述代码的作用。
是负责将数据从多维转换为一维的模块。
容器在网络内部创建一系列层。
在容器内部,我们有层。每种类型的层以不同的方式转换数据,而在神经网络中实现层的方式有很多种。
前向函数是在执行模型时调用的函数;然而,我们不应该直接调用它。
以下行实例化了我们的模型:
model = NeuralNetwork()
print(model)
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
训练神经网络
现在我们已经定义了神经网络,可以开始使用它了。在开始训练之前,我们应该先设置损失函数。损失函数衡量模型与正确结果的差距,我们将在网络训练期间尝试最小化它。交叉熵是用于分类任务的常见损失函数,也是我们将使用的损失函数。我们应该初始化函数:
loss_function = nn.CrossEntropyLoss()
训练之前的最后一步是设置优化算法。这样的算法将负责在训练过程中调整模型,以最小化我们选择的损失函数测量的误差。这种任务的常见选择是随机梯度下降算法。然而,PyTorch有几种其他可能性,你可以在这里了解。以下是代码:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
参数是学习率,它表示在每次迭代训练中更新模型参数的速度。
最后,是时候训练和测试网络了。对于这些任务,我们将实现一个函数。训练函数包括循环遍历数据,使用优化器调整模型,并计算预测和损失。这是PyTorch的标准实现:
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 1000 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
请注意,对于每次迭代,我们获取用于馈送模型的数据,但还跟踪批次的编号,以便我们可以在每100次迭代时打印损失和当前批次。
然后是测试函数,它计算准确性和损失,这次使用测试集:
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
然后我们设置要训练模型的时期数。一个时期包括对数据集的迭代。例如,如果我们设置为5,则表示我们将使用神经网络训练和测试整个数据集5次。我们训练次数越多,结果就越好。
epochs=5
这是PyTorch的实现和这样一个循环的输出:
epochs = 5
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(loaded_train, model, loss_function, optimizer)
test(loaded_test, model, loss_function)
print("Done!")
Epoch 1
-------------------------------
loss: 2.296232 [ 0/60000]
Test Error:
Accuracy: 47.3%, Avg loss: 2.254638
Epoch 2
-------------------------------
loss: 2.260034 [ 0/60000]
Test Error:
Accuracy: 63.2%, Avg loss: 2.183432
Epoch 3
-------------------------------
loss: 2.173747 [ 0/60000]
Test Error:
Accuracy: 66.9%, Avg loss: 2.062604
Epoch 4
-------------------------------
loss: 2.078938 [ 0/60000]
Test Error:
Accuracy: 72.4%, Avg loss: 1.859960
Epoch 5
-------------------------------
loss: 1.871736 [ 0/60000]
Test Error:
Accuracy: 75.8%, Avg loss: 1.562622
请注意,在每个时期中,我们在训练循环中每100批次打印损失函数,它继续降低。此外,在每个时期之后,我们可以看到准确性随着平均损失的降低而提高。
如果我们设置更多的时期——比如10、50或甚至100,很可能会看到更好的结果,但输出将更长,更难以可视化和理解。
最后,我们的模型训练完毕后,保存和加载它非常容易:
torch.save(model, “model.pth”)
model = torch.load(“model.pth”)
结论
在本文中,我们介绍了使用PyTorch进行深度学习的基础知识,包括:
-
张量及其使用方法
-
如何加载和准备数据
-
神经网络以及如何在PyTorch中定义它们
-
如何训练你的第一个图像分类模型文章来源:https://www.toymoban.com/news/detail-456310.html
-
在PyTorch上保存和加载模型。文章来源地址https://www.toymoban.com/news/detail-456310.html
到了这里,关于2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!