2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》

这篇具有很好参考价值的文章主要介绍了2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

21个项目玩转PyTorch实战

通过经典项目入门 PyTorch,通过前沿项目提升 PyTorch,基于PyTorch玩转深度学习,本书适合人工智能、机器学习、深度学习方面的人员阅读,也适合其他 IT 方面从业者,另外,还可以作为相关专业的教材。

京东自营购买链接:https://item.jd.com/13522327.html

PyTorch 是基于 Torch 库的开源机器学习库,它主要由 Meta(原 Facebook)的人工智能研究实验室开发,在自然语言处理和计算机视觉领域都具有广泛的应用。本书介绍了简单且经典的入门项目,方便快速上手,如 MNIST数字识别,读者在完成项目的过程中可以了解数据集、模型和训练等基础概念。本书还介绍了一些实用且经典的模型,如 R-CNN 模型,通过这个模型的学习,读者可以对目标检测任务有一个基本的认识,对于基本的网络结构原理有一定的了解。另外,本书对于当前比较热门的生成对抗网络和强化学习也有一定的介绍,方便读者拓宽视野,掌握前沿方向。

正文

在使用Python进行机器学习时,你有多个选项可供选择使用哪个库或框架。但是,如果你正在向深度学习迈进,那么你应该使用TensorFlowPyTorch这两个最著名的深度学习框架之一。

在本文中,我们将快速介绍PyTorch框架,从最初的概念一直到第一张图像分类模型的训练和测试。

我们不会深入学习复杂的概念和数学,因为本文旨在成为一个更加实用的PyTorch工具入门文章,而不是一个深度学习概念的入门文章。

因此,我们假设你具有一些中级Python知识-包括类和面向对象编程-并且熟悉深度学习的主要概念。

PyTorch

PyTorch是一个功能强大、易于使用的Python深度学习库,主要用于计算机视觉和自然语言处理等应用。

虽然TensorFlow是由谷歌开发的,但PyTorch是由Facebook的AI研究小组开发的,该小组最近将该框架的管理转移到了新创建的PyTorch基金会下,该基金会受Linux基金会的监督。

PyTorch的灵活性允许轻松集成新的数据类型和算法,而且该框架也是高效且可扩展的,因为它被设计为最小化所需的计算数量,并与各种硬件架构兼容。

张量

在深度学习中,张量是一种基本数据结构,非常类似于数组和矩阵,可以在大量数据集上高效执行数学运算。张量可以表示为矩阵,也可以表示为向量、标量或高维数组。

为了更容易地可视化,你可以将张量视为包含标量或其他数组的简单数组。在PyTorch上,张量是一个非常类似于ndarray的结构,但它们能够在GPU上运行,从而大大加快了计算过程。

NumPy创建张量很简单:

import torch
import numpy as np

ndarray = np.array([0, 1, 2])
t = torch.from_numpy(ndarray)
print(t)
tensor([0, 1, 2])

PyTorch上的张量具有三个属性:

形状:张量的大小
数据类型:张量中存储的数据类型
设备:张量存储的设备

如果我们从创建的张量中打印属性,我们将得到以下结果:

print(t.shape)
print(t.dtype)
print(t.device)
torch.Size([3])
torch.int64
cpu

这意味着我们有一个包含整数的一维张量,大小为3,存储在CPU中。

我们也可以从Python列表实例化一个张量:

t = torch.tensor([0, 1, 2])
print(t)
tensor([0, 1, 2])

张量也可以是多维的:

ndarray = np.array([[0, 1, 2], [3, 4, 5]])
t = torch.from_numpy(ndarray)
print(t)
tensor([[0, 1, 2],
        [3, 4, 5]])

还可以从另一个张量创建张量。在这种情况下,新张量继承初始张量的特性。下面的示例根据先前创建的张量创建具有随机数字的张量:

new_t = torch.rand_like(t, dtype=torch.float)
print(new_t)
tensor([[0.1366, 0.5994, 0.3963],
        [0.1126, 0.8860, 0.8233]])

请注意,该函数创建一个形状为(2,2)的新张量。但是,由于函数返回值从0到1,因此我们必须将数据类型覆盖为float。

我们还可以仅从我们预期的形状创建一个随机张量:

my_shape = (3, 3)
rand_t = torch.rand(my_shape)
print(rand_t)
tensor([[0.8099, 0.8816, 0.3071],
        [0.1003, 0.3190, 0.3503],
        [0.9088, 0.0844, 0.0547]])

张量操作

与NumPy一样,我们可以使用张量执行多种可能的操作-例如切片、转置和矩阵乘法等。

切片张量与Python中的任何其他数组结构完全相同。考虑下面的张量:

zeros_tensor = torch.zeros((2, 3))
print(zeros_tensor)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

我们可以轻松地索引第一行或第一列:

print(zeros_tensor[1])
print(zeros_tensor[:, 0])
tensor([0., 0., 0.])
tensor([0., 0.])

我们还可以将此张量转置:

transposed = zeros_tensor.T
print(transposed)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

最后,我们可以将张量相乘:

ones_tensor = torch.ones(3, 3)
product = torch.matmul(zeros_tensor, ones_tensor)
print(product)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

请注意,我们使用了and函数创建一个仅包含零和一的张量,其形状与我们传递的形状相同。

这些操作只是PyTorch可以执行的一小部分。但是,本文的目的不是涵盖它们中的每一个,而是提供对它们如何工作的一般概念。如果你想了解更多信息,PyTorch有完整的文档。

加载数据

PyTorch自带一个内置模块,为许多深度学习应用提供了现成的数据集,例如计算机视觉、语音识别和自然语言处理。这意味着可以构建自己的神经网络,而无需收集和处理数据。

作为示例,我们将下载MNIST数据集。MNIST是一个手写数字图像数据集,包含6万个样本和一组包含1万张图像的测试集。

我们将使用来自torchvision的模块来下载数据集:

from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.MNIST(root=".", train=True, download=True, transform=ToTensor())

test_data = datasets.MNIST(root=".", train=False, download=True, transform=ToTensor())
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

在下载函数内部,我们有以下参数:

root:数据将保存的目录。你可以传递一个带有目录路径的字符串。点(如示例中所示)将文件保存在你所在的相同目录中。

train:用于告知PyTorch你是否正在下载训练集还是测试集。

download:如果指定的路径上已经没有数据,是否下载数据。

transform:转换数据。在我们的代码中,我们选择张量。

如果我们打印训练集的第一个元素,我们将看到以下内容:

training_data[0]
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
           0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
           0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
           0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
           0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
           0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
           0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
           0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
           0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
           0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
           0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
           0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
           0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
           0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
           0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
           0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
           0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
           0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
           0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]), 5)
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],

上述张量只是整个元素的一小部分,因为它太大了无法显示。

这一串数字对我们可能没有什么意义,由于它们代表图像,我们可以使用matplotlib将它们可视化为实际的图像:

figure = plt.figure(figsize=(8, 8))
cols, rows = 5, 5

for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》

我们还可以使用属性来查看数据中的类别:classes

training_data.classes
['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

当模型训练好后,它可以接收新的输入,然后将其分类为这些类别之一。

现在我们已经下载了数据,我们将使用DataLoader。这使我们能够以小批量的方式迭代数据集,而不是一次观察一个,同时在训练模型时对数据进行洗牌。以下是代码:

from torch.utils.data import DataLoader

loaded_train = DataLoader(training_data, batch_size=64, shuffle=True)
loaded_test = DataLoader(test_data, batch_size=64, shuffle=True)

神经网络

在深度学习中,神经网络是一种用于建模具有复杂模式数据的算法。神经网络试图通过多个由处理节点连接的层来模拟人脑的功能,这些处理节点的行为类似于人类神经元。这些由节点连接的层创建了一个复杂的网络,能够处理和理解大量复杂的数据。

在PyTorch中,与神经网络相关的所有内容都是使用模块构建的。网络本身是一个继承自的类,而在类内部,我们将使用来构建层。以下是从PyTorch文档中摘取的一个简单实现:

from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

 def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

虽然本文的范围不包括深入讨论层是什么、它们如何工作以及如何实现它们,但让我们简要地了解一下上述代码的作用。

是负责将数据从多维转换为一维的模块。

容器在网络内部创建一系列层。

在容器内部,我们有层。每种类型的层以不同的方式转换数据,而在神经网络中实现层的方式有很多种。

前向函数是在执行模型时调用的函数;然而,我们不应该直接调用它。

以下行实例化了我们的模型:

model = NeuralNetwork()
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


训练神经网络

现在我们已经定义了神经网络,可以开始使用它了。在开始训练之前,我们应该先设置损失函数。损失函数衡量模型与正确结果的差距,我们将在网络训练期间尝试最小化它。交叉熵是用于分类任务的常见损失函数,也是我们将使用的损失函数。我们应该初始化函数:

loss_function = nn.CrossEntropyLoss()


训练之前的最后一步是设置优化算法。这样的算法将负责在训练过程中调整模型,以最小化我们选择的损失函数测量的误差。这种任务的常见选择是随机梯度下降算法。然而,PyTorch有几种其他可能性,你可以在这里了解。以下是代码:

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

参数是学习率,它表示在每次迭代训练中更新模型参数的速度。

最后,是时候训练和测试网络了。对于这些任务,我们将实现一个函数。训练函数包括循环遍历数据,使用优化器调整模型,并计算预测和损失。这是PyTorch的标准实现:

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)
		optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 1000 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

请注意,对于每次迭代,我们获取用于馈送模型的数据,但还跟踪批次的编号,以便我们可以在每100次迭代时打印损失和当前批次。

然后是测试函数,它计算准确性和损失,这次使用测试集:

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

然后我们设置要训练模型的时期数。一个时期包括对数据集的迭代。例如,如果我们设置为5,则表示我们将使用神经网络训练和测试整个数据集5次。我们训练次数越多,结果就越好。

epochs=5

这是PyTorch的实现和这样一个循环的输出:

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(loaded_train, model, loss_function, optimizer)
    test(loaded_test, model, loss_function)
print("Done!")
Epoch 1
-------------------------------
loss: 2.296232  [    0/60000]
Test Error:
 Accuracy: 47.3%, Avg loss: 2.254638

Epoch 2
-------------------------------
loss: 2.260034  [    0/60000]
Test Error:
 Accuracy: 63.2%, Avg loss: 2.183432

Epoch 3
-------------------------------
loss: 2.173747  [    0/60000]
Test Error:
 Accuracy: 66.9%, Avg loss: 2.062604

Epoch 4
-------------------------------
loss: 2.078938  [    0/60000]
Test Error:
 Accuracy: 72.4%, Avg loss: 1.859960

Epoch 5
-------------------------------
loss: 1.871736  [    0/60000]
Test Error:
 Accuracy: 75.8%, Avg loss: 1.562622

请注意,在每个时期中,我们在训练循环中每100批次打印损失函数,它继续降低。此外,在每个时期之后,我们可以看到准确性随着平均损失的降低而提高。

如果我们设置更多的时期——比如10、50或甚至100,很可能会看到更好的结果,但输出将更长,更难以可视化和理解。

最后,我们的模型训练完毕后,保存和加载它非常容易:

torch.save(model, “model.pth”)
model = torch.load(“model.pth”)

结论

在本文中,我们介绍了使用PyTorch进行深度学习的基础知识,包括:

  • 张量及其使用方法

  • 如何加载和准备数据

  • 神经网络以及如何在PyTorch中定义它们

  • 如何训练你的第一个图像分类模型

  • 在PyTorch上保存和加载模型。文章来源地址https://www.toymoban.com/news/detail-456310.html

到了这里,关于2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 超详细的的PyTorch安装教程,成功率高,适合初学者,亲测可用。

    啰嗦几句: 网上的教程很多,安装的方法多种多样,操作复杂,成功率还不高。 小编在淘宝专门帮助不会安装的小伙伴远程配置环境,这方法都是测试过了,适用大部分人的 ,完全按照文章来操作,基本都是可以安装成功的。 如果你不想再折腾了 ,可能联系 技术客服344

    2024年02月02日
    浏览(46)
  • 初学者该如何入手云计算

    妥妥的适合零基础入门云计算专业的学习路径,请收好。 我们将云计算的学习划分为4个阶段,基础阶段、初级阶段、应用阶段、进阶阶段。 (1)基础阶段 在基础阶段需要掌握通用的知识,有了扎实的基础后面才能走的更远,比如计算机组成原理、计算机网络、操作系统、

    2024年02月02日
    浏览(101)
  • systemd:初学者如何理解其中的争议

    导读 对于什么是 systemd,以及为什么它经常成为 Linux 世界争议的焦点,你可能仍然感到困惑。我将尝试用简单的语言来回答。 在 Linux 世界中,很少有争议能像传统的 System V 初始化 系统(通常称为 SysVinit)和较新的 systemd 之间的斗争那样引起如此大的争议。 在这篇文章中

    2024年02月12日
    浏览(46)
  • “初学者必看:如何从零开始学习人工智能?

    当我初次接触人工智能(AI)时,正值 AlphaGo 战胜围棋世界冠军李世石成为全球焦点,那一刻,人工智能这项技术首次闯入我的视线。我对此产生了浓厚兴趣,决心探究其背后的原理以及这些技术能为我们带来何种益处。于是我开始搜集资料,观看视频,深入了解相关知识。

    2024年01月24日
    浏览(57)
  • chatgpt赋能python:Python初学者必知:如何正确安装pandas模块

    如果你是一名初学者,或者只是想学习数据分析的人,你可能已经听说过 pandas 这个模块。Pandas 是一个 Python 的数据处理库,它提供了各种数据结构,可以使用户轻松地处理大量数据。但是,在安装 Pandas 的时候,可能会遇到一些问题。下面,我们将给大家介绍一些方法,来确

    2024年02月07日
    浏览(60)
  • 使用AI制作 3d 模型初学者指南,如何在 Blender 3d 中使用stable diffusion

    安装 Stability for Blender 只需这些简单的步骤即可快速简便: 在这里,前往Addon Releases页面,然后单击“stability-blender-addon”链接下载 ZIP 文件(而不是源代码链接) 或者,您可以从我们的 Blender Market 页面免费下载最新版本。

    2024年02月03日
    浏览(100)
  • 爬虫,初学者指南

    1.想目标地址发起请求,携带heards和不携带heards的区别 request模块用于测速发送数据的连通性,通过回复可以看出418,Connection:close表示未获取到服务器的返回值,需要添加heards信息,此服务器拒绝非浏览器发送的请求。 上图可以看出添加了头信息headers之后成功获取了返回值

    2024年02月07日
    浏览(59)
  • 守护进程(初学者必备)

    目录 一.进程组和会话 二.守护进程的概念 三.守护线程的特点 四.守护进程创建的基本步骤 1.进程组的相关概念: 进程除了有进程的PID之外还有一个进程组,进程组是由一个进程或者多个进程组成。通常他们与同一作业相关联可以收到同一终端的信号 每个进程组有唯一的进程

    2024年02月08日
    浏览(60)
  • Groovy初学者指南

    本文已收录至Github,推荐阅读 👉 Java随想录 微信公众号:Java随想录 目录 摘要 Groovy与Java的联系和区别 Groovy的语法 动态类型 元编程 处理集合的便捷方法 闭包 运算符重载 控制流 条件语句 循环语句 字符串处理 字符串插值 多行字符串 集合与迭代 列表(List) 映射(Map) 迭代器

    2024年02月05日
    浏览(62)
  • ChatGPT初学者最佳实践

    2022年11月底,ChatGPT引爆了新一轮AI的革命,也让人们意识到AI真的能够大幅度提高人们的工作效率,甚至有人担心自己的工作会因为AI不保。这种居安思危的意识是正确的,但是正如锛凿斧锯的出现,并没有让木匠这个行业消失,而是让这个行业以更高效的方式工作。所以作为

    2024年02月05日
    浏览(51)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包