PyTorch 中的批量规范化

这篇具有很好参考价值的文章主要介绍了PyTorch 中的批量规范化。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

批量规范化(Batch Normalization)是深度学习中一种常用的技术,用于加速训练过程并提高模型的稳定性和泛化能力。以下是PyTorch中批量规范化的一些关键知识点:

1.nn.BatchNorm1d 和 nn.BatchNorm2d:

2.PyTorch提供了nn.BatchNorm1d用于在全连接层后应用批量规范化,以及nn.BatchNorm2d用于在卷积层后应用批量规范化。
3.通过这两个模块,可以轻松在模型中的不同层应用批量规范化。

4.训练和测试模式:

5.批量规范化层在训练和测试阶段的计算方式不同。
6.在训练时,它使用当前批次的均值和方差进行归一化。
7.在测试时,通常使用保存的移动平均值和方差进行归一化。

8.model.train() 和 model.eval() 方法:

9.在使用批量规范化的模型中,需要在训练和测试阶段切换模型的状态,使用model.train()将模型切换到训练模式,而使用model.eval()切换到测试模式。

10.批量规范化的计算过程:

11.对于每个输入特征,批量规范化执行以下步骤:
12.计算当前批次的均值和方差。
13.使用批次的均值和方差对输入进行标准化。
14.通过缩放和平移操作,将标准化后的值映射回新的分布,从而引入可学习的参数。

15.批量规范化的参数:

16.批量规范化引入了可学习的参数,包括缩放参数(gamma)和平移参数(beta)。
17.这些参数允许模型学习适应数据的最佳标准化。

18.移动平均:

19.在测试时,批量规范化层通常使用移动平均来估计整个训练数据集的均值和方差。
20.这有助于提高模型在测试集上的泛化性能。

21.批量规范化的位置:

22.批量规范化可以在模型的不同层中使用,包括全连接层和卷积层。
23.它通常在激活函数之前应用,即在全连接或卷积操作之后,激活函数之前。

24.对学习率和权重初始化的影响:

25.批量规范化有助于减小对学习率和权重初始化的敏感性,使得在更大范围内选择合适的学习率和初始化值更为容易。

这些知识点涵盖了在PyTorch中使用批量规范化时的一些关键概念和实践。使用批量规范化可以显著改善训练深度神经网络的效果,特别是在深层网络中。文章来源地址https://www.toymoban.com/news/detail-791946.html

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
# 设置环境变量以避免 OpenMP 问题
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

class NeuralNetWithBatchNorm(nn.Module):
    def __init__(self, use_batch_norm, input_size=784, hidden_dim=256, output_size=10):
        super(NeuralNetWithBatchNorm, self).__init__()

        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.output_size = output_size
        self.use_batch_norm = use_batch_norm
        # 第一个全连接层。如果使用批量归一化,则在其后添加批量归一化层
        self.fc1 = nn.Linear(input_size, hidden_dim, bias=not use_batch_norm)
        if use_batch_norm:
            self.batch_norm1 = nn.BatchNorm1d(hidden_dim)
        # 第二层全连接层。与 fc1类似,如果 use_batch_norm是True ,则添加批处理归一化层 ()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=not use_batch_norm)
        if use_batch_norm:
            self.batch_norm2 = nn.BatchNorm1d(hidden_dim)
        # 最终的全连接层。
        self.fc3 = nn.Linear(hidden_dim, output_size)

    def forward(self, x):
        # 将输入调整为第一维的大小
        x = x.view(-1, 28 * 28)
        # 按顺序应用的全连接层。
        x = self.fc1(x)
        # 如果 use_batch_norm是True ,则在全连接层之后应用批量归一化。
        if self.use_batch_norm:
            x = self.batch_norm1(x)
        # 在全连接层之间应用整流线性单元 (ReLU) 激活功能。
        x = F.relu(x)
        x = self.fc2(x)
        if self.use_batch_norm:
            x = self.batch_norm2(x)
        x = self.fc3(x)
        return x

# model:要训练的神经网络模型。
# train_loader:训练数据集的 PyTorch 数据加载器。
# n_epochs:训练周期数(默认设置为 10)
def train(model, train_loader, n_epochs=10):
    # 将 epoch 的数量分配给局部变量,并创建一个空列表来存储每个 epoch 期间的训练损失。
    n_epochs = n_epochs
    losses = []
    # nn.CrossEntropyLoss():这是交叉熵损失,通常用于分类问题。
    # optim.SGD():随机梯度下降优化器的使用学习率为 0.01。它优化了神经网络的参数()。
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    # 这会将模型设置为训练模式,如果模型包含在训练和评估期间表现不同的辍学或批量归一化等层,则这是必需的。
    model.train()
    # 该函数遍历每个纪元。
    for epoch in range(1, n_epochs + 1):
        train_loss = 0.0
        batch_count = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            batch_count += 1
        losses.append(train_loss / batch_count)
        print('Epoch:{}\t Training Loss:{:6f}'.format(epoch, train_loss / batch_count))
    # train_loss:当前时期训练损失的累加器。
    # batch_count:处理的批次数的计数器。
    # optimizer.zero_grad():清除在上一次迭代中计算的梯度。
    # output = model(data):前向通过模型。
    # loss = criterion(output, target):根据模型的输出和目标标签计算损失。
    # loss.backward():向后传递以计算参数的梯度。
    # optimizer.step():使用优化器更新模型参数。
    # train_loss += loss.item():累积当前批次的训练损失。
    # batch_count += 1:递增批次计数。
    # losses.append(train_loss / batch_count):记录当前 epoch 的平均训练损失。
    # print(...):显示纪元编号和相应的训练损失。
    return losses

# num_workers:用于数据加载的子进程数。此处设置为 0,表示数据加载将在主进程中进行。
# batch_size:每批样品数。
# transform:的实例,用于将 PIL 图像或 numpy 数组转换为火炬张量。transforms.ToTensor()
num_workers = 0
batch_size = 64
transform = transforms.ToTensor()
# datasets.MNIST:加载 MNIST 数据集。
# root='data':指定存储数据集的目录。
# train=True:表示这是数据集的训练拆分。
# download=True:下载数据集(如果尚不存在)。
# transform=transform:将指定的转换(在本例中为)应用于数据。
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
# torch.utils.data.DataLoader:创建用于遍历数据集的数据加载器。
# train_data:要加载的数据集。
# batch_size:每批样品数。
# num_workers:用于数据加载的子进程数。
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers)

# NeuralNetWithBatchNorm:假定存在一个名为 的神经网络类。
# use_batch_norm=True:表示具有批量归一化的网络应使用批量归一化层进行实例化。
# use_batch_norm=False:表示没有批量归一化的网络应该在没有批量归一化层的情况下实例化。
net_batchnorm = NeuralNetWithBatchNorm(use_batch_norm=True)
net_no_norm = NeuralNetWithBatchNorm(use_batch_norm=False)

# 打印实例
print(net_batchnorm)
print(net_no_norm)
print()

# Train the models
losses_batchnorm = train(net_batchnorm, train_loader)
losses_no_norm = train(net_no_norm, train_loader)

# Plot the training losses
fig, ax = plt.subplots(figsize=(12, 8))
plt.plot(losses_batchnorm, label='Using batchnorm', alpha=0.5)
plt.plot(losses_no_norm, label='No norm', alpha=0.5)
plt.title("Training Losses")
plt.legend()
plt.show()

# 测试功能
def test(model, train_loader, train=True):
    # 初始化:
    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    test_loss = 0.0
    # 将模型设置为训练或评估模式
    # 只是为了看到行为上的差异
    if train:
        model.train()
    else:
        model.eval()
    # 损失标准
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        batch_size = data.size(0)
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item() * batch_size
        _, pred = torch.max(output, 1)
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        for i in range(batch_size):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # 打印测试损失
    print('Test Loss: {:.6f}\n'.format(test_loss / len(train_loader.dataset)))

    # 每个类别的打印精度
    for i in range(10):
        if class_total[i] > 0:
            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
                str(i), 100 * class_correct[i] / class_total[i],
                np.sum(class_correct[i]), np.sum(class_total[i])))
        else:
            print('Test Accuracy of %5s: N/A (no training examples)' % (i))

    # 打印整体精度
    print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
        100. * np.sum(class_correct) / np.sum(class_total),
        np.sum(class_correct), np.sum(class_total)))

# 测试模型
test(net_batchnorm, train_loader, train=True)
test(net_batchnorm, train_loader, train=False)
test(net_no_norm, train_loader, train=False)







到了这里,关于PyTorch 中的批量规范化的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【第一章 先导篇】1. 规范化的学习模型

    举例:什么是编码?

    2024年04月25日
    浏览(38)
  • idea的git的规范化提交插件

    1.在idea中安装git的插件git commit Template插件 打开IDEA-选择菜单栏的File-Settings,选择Plugins-MarkPlace输入Git Commit Template进行搜索,点击apply, 2.在日常commit的 时候按照如下操作进行:在commit的页面,点击下图的图标后,根据实际情况选择或者输入相关内容,该插件会根据其填入的内

    2024年02月12日
    浏览(62)
  • 数据库期末复习(10)数据库规范化理论

     函数依赖(概念):FD 范式分解(评估准则): 模式分解(工具): 如何衡量一个数据库好不好:准确 高效 如果一个数据库设计的不好的话的,会带来哪些问题 删除异常 数据冗余 为什么会导致出现上方的问题:数据依赖 数据依赖的分类:完全依赖,部分依赖,传递依赖和相应的定义 A

    2024年02月08日
    浏览(58)
  • 【数据库原理 • 四】数据库设计和规范化理论

    前言 数据库技术是计算机科学技术中发展最快,应用最广的技术之一,它是专门研究如何科学的组织和存储数据,如何高效地获取和处理数据的技术。它已成为各行各业存储数据、管理信息、共享资源和决策支持的最先进,最常用的技术。 当前互联网+与大数据,一切都建立

    2023年04月12日
    浏览(47)
  • Git Commit 之道:规范化 Commit Message 写作指南

    commit message格式都包括三部分:Header,Body和Footer Header是必需的,Body和Footer则可以省略 Type(必需) type用于说明 git commit 的类别,允许使用下面几个标识。 feat :新功能(Feature) \\\"feat\\\"用于表示引入新功能或特性的变动。这种变动通常是在代码库中新增的功能,而不仅仅是修

    2024年02月03日
    浏览(59)
  • vsCode配置Eslint+Prettier结合使用详细配置步骤,规范化开发

            eslint它规范的是代码偏向语法层面上的风格 。本篇文章以一个基本的vue项目,来说明eslint+prettier+husky配置项目代码规范,为了更好的描述本文,我恢复了vscode的默认设置(即未安装eslint,prettier等插件,setting中也没有相关配置) 1、新建vue3.0脚手架项目 2、项目安装

    2023年04月17日
    浏览(48)
  • 干翻Dubbo系列第十五篇:Rest协议基于SpringBoot的规范化开发

    文章目录 一:Rest协议 1:协议概念 2:协议作用 二:搭建开发环境 1:父项目里边引入的新的版本内容 2:Api中的操作 3:Provider模块 4:Consumer模块 三:编码 1:API模块 2:Provider模块 3:Consumer模块         Rest协议就是我们我们一开始基于SpringBoot或者是SpringMVC开发说的Re

    2024年02月10日
    浏览(53)
  • 项目git commit时卡主不良代码:husky让Git检查代码规范化工作

    看完 《前端规范之Git工作流规范(Husky + Commitlint + Lint-staged) 前端规范之Git工作流规范(Husky + Commitlint + Lint-staged) - Yellow_ice - 博客园》,再次修改本文 团队人一多,提交一多,还是要对备注加以区分,好快速找到变更点。这时候就需要对每次提交,需要输入message,对提交

    2024年02月03日
    浏览(133)
  • 系统架构设计师---计算机基础知识之数据库系统结构与规范化

    目录 一、基本概念  二、 数据库的结构  三、常用的数据模型         概念数据模型        基本数据模型        面向对象模型 四、数据的规范化      函数依赖       范式   1. 数据库 (DataBase, DB) : 是指长期储存在计算机内的、有组织的、可共享的数据集合。   

    2024年02月12日
    浏览(54)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包