2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》

这篇具有很好参考价值的文章主要介绍了2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

21个项目玩转PyTorch实战

通过经典项目入门 PyTorch,通过前沿项目提升 PyTorch,基于PyTorch玩转深度学习,本书适合人工智能、机器学习、深度学习方面的人员阅读,也适合其他 IT 方面从业者,另外,还可以作为相关专业的教材。

京东自营购买链接:https://item.jd.com/13522327.html

PyTorch 是基于 Torch 库的开源机器学习库,它主要由 Meta(原 Facebook)的人工智能研究实验室开发,在自然语言处理和计算机视觉领域都具有广泛的应用。本书介绍了简单且经典的入门项目,方便快速上手,如 MNIST数字识别,读者在完成项目的过程中可以了解数据集、模型和训练等基础概念。本书还介绍了一些实用且经典的模型,如 R-CNN 模型,通过这个模型的学习,读者可以对目标检测任务有一个基本的认识,对于基本的网络结构原理有一定的了解。另外,本书对于当前比较热门的生成对抗网络和强化学习也有一定的介绍,方便读者拓宽视野,掌握前沿方向。

正文

在使用Python进行机器学习时,你有多个选项可供选择使用哪个库或框架。但是,如果你正在向深度学习迈进,那么你应该使用TensorFlowPyTorch这两个最著名的深度学习框架之一。

在本文中,我们将快速介绍PyTorch框架,从最初的概念一直到第一张图像分类模型的训练和测试。

我们不会深入学习复杂的概念和数学,因为本文旨在成为一个更加实用的PyTorch工具入门文章,而不是一个深度学习概念的入门文章。

因此,我们假设你具有一些中级Python知识-包括类和面向对象编程-并且熟悉深度学习的主要概念。

PyTorch

PyTorch是一个功能强大、易于使用的Python深度学习库,主要用于计算机视觉和自然语言处理等应用。

虽然TensorFlow是由谷歌开发的,但PyTorch是由Facebook的AI研究小组开发的,该小组最近将该框架的管理转移到了新创建的PyTorch基金会下,该基金会受Linux基金会的监督。

PyTorch的灵活性允许轻松集成新的数据类型和算法,而且该框架也是高效且可扩展的,因为它被设计为最小化所需的计算数量,并与各种硬件架构兼容。

张量

在深度学习中,张量是一种基本数据结构,非常类似于数组和矩阵,可以在大量数据集上高效执行数学运算。张量可以表示为矩阵,也可以表示为向量、标量或高维数组。

为了更容易地可视化,你可以将张量视为包含标量或其他数组的简单数组。在PyTorch上,张量是一个非常类似于ndarray的结构,但它们能够在GPU上运行,从而大大加快了计算过程。

NumPy创建张量很简单:

import torch
import numpy as np

ndarray = np.array([0, 1, 2])
t = torch.from_numpy(ndarray)
print(t)
tensor([0, 1, 2])

PyTorch上的张量具有三个属性:

形状:张量的大小
数据类型:张量中存储的数据类型
设备:张量存储的设备

如果我们从创建的张量中打印属性,我们将得到以下结果:

print(t.shape)
print(t.dtype)
print(t.device)
torch.Size([3])
torch.int64
cpu

这意味着我们有一个包含整数的一维张量,大小为3,存储在CPU中。

我们也可以从Python列表实例化一个张量:

t = torch.tensor([0, 1, 2])
print(t)
tensor([0, 1, 2])

张量也可以是多维的:

ndarray = np.array([[0, 1, 2], [3, 4, 5]])
t = torch.from_numpy(ndarray)
print(t)
tensor([[0, 1, 2],
        [3, 4, 5]])

还可以从另一个张量创建张量。在这种情况下,新张量继承初始张量的特性。下面的示例根据先前创建的张量创建具有随机数字的张量:

new_t = torch.rand_like(t, dtype=torch.float)
print(new_t)
tensor([[0.1366, 0.5994, 0.3963],
        [0.1126, 0.8860, 0.8233]])

请注意,该函数创建一个形状为(2,2)的新张量。但是,由于函数返回值从0到1,因此我们必须将数据类型覆盖为float。

我们还可以仅从我们预期的形状创建一个随机张量:

my_shape = (3, 3)
rand_t = torch.rand(my_shape)
print(rand_t)
tensor([[0.8099, 0.8816, 0.3071],
        [0.1003, 0.3190, 0.3503],
        [0.9088, 0.0844, 0.0547]])

张量操作

与NumPy一样,我们可以使用张量执行多种可能的操作-例如切片、转置和矩阵乘法等。

切片张量与Python中的任何其他数组结构完全相同。考虑下面的张量:

zeros_tensor = torch.zeros((2, 3))
print(zeros_tensor)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

我们可以轻松地索引第一行或第一列:

print(zeros_tensor[1])
print(zeros_tensor[:, 0])
tensor([0., 0., 0.])
tensor([0., 0.])

我们还可以将此张量转置:

transposed = zeros_tensor.T
print(transposed)
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

最后,我们可以将张量相乘:

ones_tensor = torch.ones(3, 3)
product = torch.matmul(zeros_tensor, ones_tensor)
print(product)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

请注意,我们使用了and函数创建一个仅包含零和一的张量,其形状与我们传递的形状相同。

这些操作只是PyTorch可以执行的一小部分。但是,本文的目的不是涵盖它们中的每一个,而是提供对它们如何工作的一般概念。如果你想了解更多信息,PyTorch有完整的文档。

加载数据

PyTorch自带一个内置模块,为许多深度学习应用提供了现成的数据集,例如计算机视觉、语音识别和自然语言处理。这意味着可以构建自己的神经网络,而无需收集和处理数据。

作为示例,我们将下载MNIST数据集。MNIST是一个手写数字图像数据集,包含6万个样本和一组包含1万张图像的测试集。

我们将使用来自torchvision的模块来下载数据集:

from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.MNIST(root=".", train=True, download=True, transform=ToTensor())

test_data = datasets.MNIST(root=".", train=False, download=True, transform=ToTensor())
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

在下载函数内部,我们有以下参数:

root:数据将保存的目录。你可以传递一个带有目录路径的字符串。点(如示例中所示)将文件保存在你所在的相同目录中。

train:用于告知PyTorch你是否正在下载训练集还是测试集。

download:如果指定的路径上已经没有数据,是否下载数据。

transform:转换数据。在我们的代码中,我们选择张量。

如果我们打印训练集的第一个元素,我们将看到以下内容:

training_data[0]
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
           0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
           0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
           0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
           0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
           0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
           0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
           0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
           0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
           0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
           0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
           0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
           0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
           0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
           0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
           0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
           0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
           0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
           0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]), 5)
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],

上述张量只是整个元素的一小部分,因为它太大了无法显示。

这一串数字对我们可能没有什么意义,由于它们代表图像,我们可以使用matplotlib将它们可视化为实际的图像:

figure = plt.figure(figsize=(8, 8))
cols, rows = 5, 5

for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》

我们还可以使用属性来查看数据中的类别:classes

training_data.classes
['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

当模型训练好后,它可以接收新的输入,然后将其分类为这些类别之一。

现在我们已经下载了数据,我们将使用DataLoader。这使我们能够以小批量的方式迭代数据集,而不是一次观察一个,同时在训练模型时对数据进行洗牌。以下是代码:

from torch.utils.data import DataLoader

loaded_train = DataLoader(training_data, batch_size=64, shuffle=True)
loaded_test = DataLoader(test_data, batch_size=64, shuffle=True)

神经网络

在深度学习中,神经网络是一种用于建模具有复杂模式数据的算法。神经网络试图通过多个由处理节点连接的层来模拟人脑的功能,这些处理节点的行为类似于人类神经元。这些由节点连接的层创建了一个复杂的网络,能够处理和理解大量复杂的数据。

在PyTorch中,与神经网络相关的所有内容都是使用模块构建的。网络本身是一个继承自的类,而在类内部,我们将使用来构建层。以下是从PyTorch文档中摘取的一个简单实现:

from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

 def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

虽然本文的范围不包括深入讨论层是什么、它们如何工作以及如何实现它们,但让我们简要地了解一下上述代码的作用。

是负责将数据从多维转换为一维的模块。

容器在网络内部创建一系列层。

在容器内部,我们有层。每种类型的层以不同的方式转换数据,而在神经网络中实现层的方式有很多种。

前向函数是在执行模型时调用的函数;然而,我们不应该直接调用它。

以下行实例化了我们的模型:

model = NeuralNetwork()
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


训练神经网络

现在我们已经定义了神经网络,可以开始使用它了。在开始训练之前,我们应该先设置损失函数。损失函数衡量模型与正确结果的差距,我们将在网络训练期间尝试最小化它。交叉熵是用于分类任务的常见损失函数,也是我们将使用的损失函数。我们应该初始化函数:

loss_function = nn.CrossEntropyLoss()


训练之前的最后一步是设置优化算法。这样的算法将负责在训练过程中调整模型,以最小化我们选择的损失函数测量的误差。这种任务的常见选择是随机梯度下降算法。然而,PyTorch有几种其他可能性,你可以在这里了解。以下是代码:

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

参数是学习率,它表示在每次迭代训练中更新模型参数的速度。

最后,是时候训练和测试网络了。对于这些任务,我们将实现一个函数。训练函数包括循环遍历数据,使用优化器调整模型,并计算预测和损失。这是PyTorch的标准实现:

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)
		optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 1000 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

请注意,对于每次迭代,我们获取用于馈送模型的数据,但还跟踪批次的编号,以便我们可以在每100次迭代时打印损失和当前批次。

然后是测试函数,它计算准确性和损失,这次使用测试集:

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

然后我们设置要训练模型的时期数。一个时期包括对数据集的迭代。例如,如果我们设置为5,则表示我们将使用神经网络训练和测试整个数据集5次。我们训练次数越多,结果就越好。

epochs=5

这是PyTorch的实现和这样一个循环的输出:

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(loaded_train, model, loss_function, optimizer)
    test(loaded_test, model, loss_function)
print("Done!")
Epoch 1
-------------------------------
loss: 2.296232  [    0/60000]
Test Error:
 Accuracy: 47.3%, Avg loss: 2.254638

Epoch 2
-------------------------------
loss: 2.260034  [    0/60000]
Test Error:
 Accuracy: 63.2%, Avg loss: 2.183432

Epoch 3
-------------------------------
loss: 2.173747  [    0/60000]
Test Error:
 Accuracy: 66.9%, Avg loss: 2.062604

Epoch 4
-------------------------------
loss: 2.078938  [    0/60000]
Test Error:
 Accuracy: 72.4%, Avg loss: 1.859960

Epoch 5
-------------------------------
loss: 1.871736  [    0/60000]
Test Error:
 Accuracy: 75.8%, Avg loss: 1.562622

请注意,在每个时期中,我们在训练循环中每100批次打印损失函数,它继续降低。此外,在每个时期之后,我们可以看到准确性随着平均损失的降低而提高。

如果我们设置更多的时期——比如10、50或甚至100,很可能会看到更好的结果,但输出将更长,更难以可视化和理解。

最后,我们的模型训练完毕后,保存和加载它非常容易:

torch.save(model, “model.pth”)
model = torch.load(“model.pth”)

结论

在本文中,我们介绍了使用PyTorch进行深度学习的基础知识,包括:

  • 张量及其使用方法

  • 如何加载和准备数据

  • 神经网络以及如何在PyTorch中定义它们

  • 如何训练你的第一个图像分类模型

  • 在PyTorch上保存和加载模型。文章来源地址https://www.toymoban.com/news/detail-456310.html

到了这里,关于2023初学者如何玩转玩转PyTorch?《21个项目玩转PyTorch实战》的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 超详细的的PyTorch安装教程,成功率高,适合初学者,亲测可用。

    啰嗦几句: 网上的教程很多,安装的方法多种多样,操作复杂,成功率还不高。 小编在淘宝专门帮助不会安装的小伙伴远程配置环境,这方法都是测试过了,适用大部分人的 ,完全按照文章来操作,基本都是可以安装成功的。 如果你不想再折腾了 ,可能联系 技术客服344

    2024年02月02日
    浏览(46)
  • 初学者该如何入手云计算

    妥妥的适合零基础入门云计算专业的学习路径,请收好。 我们将云计算的学习划分为4个阶段,基础阶段、初级阶段、应用阶段、进阶阶段。 (1)基础阶段 在基础阶段需要掌握通用的知识,有了扎实的基础后面才能走的更远,比如计算机组成原理、计算机网络、操作系统、

    2024年02月02日
    浏览(102)
  • systemd:初学者如何理解其中的争议

    导读 对于什么是 systemd,以及为什么它经常成为 Linux 世界争议的焦点,你可能仍然感到困惑。我将尝试用简单的语言来回答。 在 Linux 世界中,很少有争议能像传统的 System V 初始化 系统(通常称为 SysVinit)和较新的 systemd 之间的斗争那样引起如此大的争议。 在这篇文章中

    2024年02月12日
    浏览(46)
  • “初学者必看:如何从零开始学习人工智能?

    当我初次接触人工智能(AI)时,正值 AlphaGo 战胜围棋世界冠军李世石成为全球焦点,那一刻,人工智能这项技术首次闯入我的视线。我对此产生了浓厚兴趣,决心探究其背后的原理以及这些技术能为我们带来何种益处。于是我开始搜集资料,观看视频,深入了解相关知识。

    2024年01月24日
    浏览(62)
  • chatgpt赋能python:Python初学者必知:如何正确安装pandas模块

    如果你是一名初学者,或者只是想学习数据分析的人,你可能已经听说过 pandas 这个模块。Pandas 是一个 Python 的数据处理库,它提供了各种数据结构,可以使用户轻松地处理大量数据。但是,在安装 Pandas 的时候,可能会遇到一些问题。下面,我们将给大家介绍一些方法,来确

    2024年02月07日
    浏览(63)
  • 使用AI制作 3d 模型初学者指南,如何在 Blender 3d 中使用stable diffusion

    安装 Stability for Blender 只需这些简单的步骤即可快速简便: 在这里,前往Addon Releases页面,然后单击“stability-blender-addon”链接下载 ZIP 文件(而不是源代码链接) 或者,您可以从我们的 Blender Market 页面免费下载最新版本。

    2024年02月03日
    浏览(101)
  • 爬虫,初学者指南

    1.想目标地址发起请求,携带heards和不携带heards的区别 request模块用于测速发送数据的连通性,通过回复可以看出418,Connection:close表示未获取到服务器的返回值,需要添加heards信息,此服务器拒绝非浏览器发送的请求。 上图可以看出添加了头信息headers之后成功获取了返回值

    2024年02月07日
    浏览(61)
  • 守护进程(初学者必备)

    目录 一.进程组和会话 二.守护进程的概念 三.守护线程的特点 四.守护进程创建的基本步骤 1.进程组的相关概念: 进程除了有进程的PID之外还有一个进程组,进程组是由一个进程或者多个进程组成。通常他们与同一作业相关联可以收到同一终端的信号 每个进程组有唯一的进程

    2024年02月08日
    浏览(60)
  • Groovy初学者指南

    本文已收录至Github,推荐阅读 👉 Java随想录 微信公众号:Java随想录 目录 摘要 Groovy与Java的联系和区别 Groovy的语法 动态类型 元编程 处理集合的便捷方法 闭包 运算符重载 控制流 条件语句 循环语句 字符串处理 字符串插值 多行字符串 集合与迭代 列表(List) 映射(Map) 迭代器

    2024年02月05日
    浏览(62)
  • linux初学者小命令

    进程 :进程是一个具有一定独立功能的程序在一个数据集上的一次动态执行的过程,是操作系统进行资源分配和调度的一个独立单位,是应用程序运行的载体。 bash执行命令的过程,以’ls’命令为例: 第一步. 读取输入信息 :shell通过STDIN(标准输入)的getline()函数得到用户的输入

    2024年02月13日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包