PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

这篇具有很好参考价值的文章主要介绍了PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

0. 前言

生成对抗网络 (Generative Adversarial Networks, GAN) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用广泛,且 GAN 的相关研究正在迅速发展,以生成与真实图像难以区分的逼真图像。在本节中,我们将学习 GAN 网络的原理并使用 PyTorch 实现 GAN

1. GAN

生成对抗网络 (Generative Adversarial Networks, GAN) 包含两个网络:生成网络( Generator,也称生成器)和判别网络( discriminator,也称判别器)。在 GAN 网络训练过程中,需要有一个合理的图像样本数据集,生成网络从图像样本中学习图像表示,然后生成与图像样本相似的图像。判别网络接收(由生成网络)生成的图像和原始图像样本作为输入,并将图像分类为原始(真实)图像或生成(伪造)图像。
生成网络的目标是生成逼真的伪造图像骗过判别网络,判别网络的目标是将生成的图像分类为伪造图像,将原始图像样本分类为真实图像。本质上,GAN 中的对抗表示两个网络的相反性质,生成网络生成图像来欺骗判别网络,判别网络通过判别图像是生成图像还是原始图像来对输入图像进行分类:

PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN),深度学习,pytorch,生成对抗网络

在上图中,生成网络根据输入随机噪声生成图像,判别网络接收生成网络生成的图像,并将它们与真实图像样本进行比较,以判断生成的图像是真实的还是伪造的。生成网络尝试生成尽可能逼真的图像,而判别网络尝试判定生成网络生成图像的真实性,从而学习生成尽可能逼真的图像。
GAN 的关键思想是生成网络和判别网络之间的竞争和动态平衡,通过不断的训练和迭代,生成网络和判别网络会逐渐提高性能,生成网络能够生成更加逼真的样本,而判别网络则能够更准确地区分真实和伪造的样本。
通常,生成网络和判别网络交替训练,将生成网络和判别网络视为博弈双方,并通过两者之间的对抗来推动模型性能的提升,直到生成网络生成的样本能够以假乱真,判别网络无法分辨真实样本和生成样本之间的差异:

  • 生成网络的训练过程:冻结判别网络权重,生成网络以噪声 z 作为输入,通过最小化生成网络与真实数据之间的差异来学习如何生成更好的样本,以便判别网络将图像分类为真实图像
  • 判别网络的训练过程:冻结生成网络权重,判别网络通过最小化真实样本和假样本之间的分类误差来更新判别网络,区分真实样本和生成样本,将生成网络生成的图像分类为伪造图像

重复训练生成网络与判别网络,直到达到平衡,当判别网络能够很好地检测到生成的图像时,生成网络对应的损失比判别网络对应的损失要高得多。通过不断训练生成网络和判别网络,直到生成网络可以生成逼真图像,而判别网络无法区分真实图像和生成图像。

2. GAN 模型分析

为了生成手写数字的图像,我们采取以下策略:

  • 导入 MNIST 数据
  • 初始化随机噪声
  • 定义生成网络模型
  • 定义判别网络模型
  • 使用生成网络生成伪造图像,生成网络在最初只能生成噪声图像,噪声图像是通过将一组噪声值通过权重随机的神经网络得到的图像
  • 交替训练两个模型
    • 将生成的图像与原始图像串联起来,判别网络预测每个图像是伪造图像还是真实图像,对判别网络进行训练,判别网络的损失是图像的预测值和实际值(标签)的二进制交叉熵,生成的伪造图像的实际值(标签)为 0,原始数据集中真实图像的实际值(标签)为 1
    • 训练生成网络利用输入噪声生成伪造图像,使其看起来更接近真实图像,从而使生成图像有可能欺骗判别网络
    • 输入噪声通过生成网络传递输出伪造图像,将生成网络生成的图像输入到判别网络中,此时,判别网络权重被冻结,因为生成网络的目标是欺骗判别网络,因此,假设生成的伪造图像实际值(标签)为 1,生成网络的损失是判别网络对输入图像的预测值和实际值 (1) 的二进制交叉熵

了解了 GAN 的基本原理后,在下一小节,我们实现 GAN 生成 MNIST 手写数字图像。

3. 利用 GAN 模型生成手写数字

(1) 导入相关库并定义设备:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
device = "cuda" if torch.cuda.is_available() else "cpu"

from torchvision.datasets import MNIST
from torchvision import transforms

(2) 导入 MNIST 数据,定义具有内置数据转换功能的数据加载器,以便缩放输入数据:

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])

data_loader = torch.utils.data.DataLoader(MNIST('MNIST/', train=True, download=True, transform=transform),batch_size=128, shuffle=True, drop_last=True)

(3) 定义判别网络模型类:

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential( 
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.model(x)

在以上代码中,使用 LeakyReLU 激活函数替换 ReLU。打印判别网络的简要信息:

from torchsummary import summary
discriminator = Discriminator().to(device)
print(summary(discriminator, (1,784)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1              [-1, 1, 1024]         803,840
         LeakyReLU-2              [-1, 1, 1024]               0
           Dropout-3              [-1, 1, 1024]               0
            Linear-4               [-1, 1, 512]         524,800
         LeakyReLU-5               [-1, 1, 512]               0
           Dropout-6               [-1, 1, 512]               0
            Linear-7               [-1, 1, 256]         131,328
         LeakyReLU-8               [-1, 1, 256]               0
           Dropout-9               [-1, 1, 256]               0
           Linear-10                 [-1, 1, 1]             257
          Sigmoid-11                 [-1, 1, 1]               0
================================================================
Total params: 1,460,225
Trainable params: 1,460,225
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.57
Estimated Total Size (MB): 5.61
----------------------------------------------------------------

(4) 定义生成网络模型类 Generator

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

生成网络根据 100 维随机噪声输入生成图像。打印生成网络模型的简要信息:

generator = Generator().to(device)
print(summary(generator, (1,100)))

模型简要信息输出结果如下所示:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [-1, 1, 256]          25,856
         LeakyReLU-2               [-1, 1, 256]               0
            Linear-3               [-1, 1, 512]         131,584
         LeakyReLU-4               [-1, 1, 512]               0
            Linear-5              [-1, 1, 1024]         525,312
         LeakyReLU-6              [-1, 1, 1024]               0
            Linear-7               [-1, 1, 784]         803,600
              Tanh-8               [-1, 1, 784]               0
================================================================
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 5.67
Estimated Total Size (MB): 5.71
----------------------------------------------------------------

(5) 定义函数生成随机噪声并将其注册到设备中:

def noise(size):
    n = torch.randn(size, 100)
    return n.to(device)

(6) 定义函数来训练判别网络。

判别网络训练函数 (discriminator_train_step) 将真实数据 (real_data) 和伪造数据 (fake_data) 作为输入:

def discriminator_train_step(real_data, fake_data, loss, d_optimizer):

重置优化器梯度:

    d_optimizer.zero_grad()

在对损失值执行反向传播之前,预测真实数据 (real_data) 并计算损失 (error_real):

    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, torch.ones(len(real_data), 1).to(device))
    error_real.backward()

在真实数据上计算判别网络损失时,我们期望判别网络预测输出为 1。因此,在判别网络的训练过程中,使用 torch.ones 作为标签,期望判别网络在真实数据上的输出为 1,从而计算判别网络在真实数据上的损失。

在对损失值执行反向传播之前,预测伪造数据 (fake_data) 并计算损失 (error_fake):

    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, torch.zeros(len(fake_data), 1).to(device))
    error_fake.backward()

在伪造数据上计算判别网络损失时,我们期望判别网络预测输出为 0。因此,在判别网络的训练过程中,使用 torch.zeros 作为标签,期望判别网络在伪造数据上的输出为 0,从而计算判别网络在伪造数据上的损失。

更新权重并返回整体损失(将模型在 real_dataerror_realfake_dataerror_fake 的损失值相加):

    d_optimizer.step()
    return error_real + error_fake

(7) 训练生成网络模型。

定义生成网络训练函数 generator_train_step 并传入伪造数据 fake_data 作为参数:

def generator_train_step(real_data, fake_data, loss, g_optimizer):

重置优化器梯度:

    g_optimizer.zero_grad()

预测判别网络对伪造数据 (fake_data) 的输出:

    prediction = discriminator(fake_data)

在计算生成网络的损失时,使用 torch.ones 作为标签,期望判别网络在伪造数据上的输出为 1,以在训练生成网络时欺骗判别网络输出值 1,以此来鼓励生成网络生成更加逼真的数据,并让判别网络无法区分其真伪:

    error = loss(prediction, torch.ones(len(real_data), 1).to(device))

执行反向传播,更新权重,并返回损失:

    error.backward()
    g_optimizer.step()
    return error

(8) 定义模型对象、生成网络和判别网络的优化器,以及损失函数:

discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()

(9) 训练模型。

循环训练模型 200epochs (num_epochs):

num_epochs = 200

d_loss_epoch = []
g_loss_epoch = []
for epoch in range(num_epochs):
    N = len(data_loader)
    d_loss_items = []
    g_loss_items = []
    for i, (images, _) in enumerate(data_loader):

加载真实数据 (real_data) 和伪造数据,其中,伪造数据是通过将大小与真实数据样本数相同的噪声数据 (batch_size = len(real_data)) 传入生成网络网络获得的。需要注意的是,必须调用 fake_data.detach(),否则训练无法正常进行。通过 detach() 函数分离出来一个新的张量,这样在 discriminator_train_step() 中调用 error.backward() 时,与生成网络相关的张量(生成 fake_data )不会受到影响。使用 discriminator_train_step 函数训练判别网络:

        real_data = images.view(len(images), -1).to(device)
        fake_data = generator(noise(len(real_data))).to(device)
        fake_data = fake_data.detach()

训练判别网络后,继续训练生成网络。从噪声数据生成一组新的伪造图像 (fake_data) 并使用 generator_train_step 函数训练生成网络:

        fake_data = generator(noise(len(real_data))).to(device)
        g_loss = generator_train_step(real_data, fake_data, loss, g_optimizer)

记录损失变化:

        d_loss_items.append(d_loss.item())
        g_loss_items.append(g_loss.item())
    d_loss_epoch.append(np.average(d_loss_items))
    g_loss_epoch.append(np.average(g_loss_items))

绘制判别网络和生成网络的损失随训练的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, d_loss_epoch, 'bo', label='Discriminator Training loss')
plt.plot(epochs, g_loss_epoch, 'r-', label='Generator Training loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN),深度学习,pytorch,生成对抗网络

(10) 可视化模型训练后生成的伪造数据:

z = torch.randn(64, 100).to(device)
sample_images = generator(z).data.cpu().view(64, 1, 28, 28)
grid = make_grid(sample_images, nrow=8, normalize=True)
plt.imshow(grid.cpu().detach().permute(1,2,0), cmap='gray')
plt.show()

PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN),深度学习,pytorch,生成对抗网络

在上图中,可以看到利用 GAN 生成逼真的图像,但仍有一定的改进空间,在之后的学习中,我们将介绍更多 GAN 的改进模型生成更逼真的图像。

小结

生成对抗网络是一种强大的深度学习模型,由生成器网络和判别器网络组成,通过彼此之间的竞争来提高性能,已经在图像生成、图像修复、图像转换和自然语言处理等领域取得了巨大的成功。其核心思想是通过生成器和判别器之间的博弈过程来实现真实样本的生成。生成器负责生成逼真的样本,而判别器则负责判断样本是真实还是伪造。通过不断的训练和迭代,生成器和判别器会相互竞争并逐渐提高性能。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)
PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(27)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(28)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(29)——神经风格迁移
PyTorch深度学习实战(30)——Deepfakes文章来源地址https://www.toymoban.com/news/detail-814448.html

到了这里,关于PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

    我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控制模型生成男性或女性的面部图像;我们可以从潜空间中随机采样一个点,但是不能预知给定潜变量能够生成什么样的图像

    2024年02月04日
    浏览(33)
  • PyTorch训练深度卷积生成对抗网络DCGAN

    将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN) DCGAN的生成器结构: 图片来源:https://arxiv.org/abs/1511.06434 model.py 训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录 图片如下图所

    2024年02月12日
    浏览(29)
  • Generative Adversarial Network(生成对抗网络)

    目录 Generative Adversarial Network(生成对抗网络) Basic Idea of GAN GAN as structured learning Can Generator learn by itself Can Discriminator generate Theory behind GAN Conditional GAN Generation (生成器)  Generation是一个neural network,它的输入是一个vector,它的输出是一个更高维的vector,以图片生成为例,输

    2024年02月09日
    浏览(46)
  • GAN(Generative Adversarial Nets (生成对抗网络))

    一、GAN 1、应用 GAN的应用十分广泛,如图像生成、图像转换、风格迁移、图像修复等等。 2、简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model,G)和判别模型(Discriminative Model,D)的互相

    2024年02月04日
    浏览(28)
  • 生成对抗网络 – Generative Adversarial Networks | GAN

    目录 生成对抗网络 GAN 的基本原理 非大白话版本 第一阶段:固定「判别器D」,训练「生成器G」

    2024年04月15日
    浏览(32)
  • 深度学习生成对抗网络(GAN)

    生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。生成对抗网络被认为是当

    2024年02月07日
    浏览(42)
  • 李宏毅 Generative Adversarial Network(GAN)生成对抗网络

    附课程提到的各式各样的GAN:https://github.com/hindupuravinash/the-gan-zoo 想要让机器做到的是生成东西。-训练出来一个generator。 假设要做图像生成,要做的是随便给一个输入(random sample一个vector,比如从gaussian distribution sample一个vector),generator产生一个image。丢不同的vector,就应

    2024年01月21日
    浏览(45)
  • 【深度学习】生成对抗网络理解和实现

            本篇说明GAN框架是个啥。并且以最基础的数据集为例,用代码说明Gan网络的原理;总的老说,所谓神经网络,宏观上看,就是万能函数,在这种函数下,任何可用数学表述的属性,都可以映射成另一种可表示属性。         生成对抗网络 (GAN) 是一种算法架

    2024年02月13日
    浏览(42)
  • 深度学习基础——GAN生成对抗网络

            生成对抗网络GAN(Generative adversarial networks)是Goodfellow等在2014年提出的一种生成式模型。GAN在结构上受博弈论中的二元零和博弈(即二元的利益之和为零,一方的所得正是另一方的所失)的启发,系统由一个生成器和一个判别器构成。         生成器和判别器均可以

    2024年02月22日
    浏览(48)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包