PyTorch训练简单的生成对抗网络GAN

这篇具有很好参考价值的文章主要介绍了PyTorch训练简单的生成对抗网络GAN。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

PyTorch训练简单的生成对抗网络GAN,Deep Learning,DL Papers,pytorch,生成对抗网络,人工智能
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim): # z_dim 噪声的维度
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim), # 28x28 -> 784
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.gen(x)
    
# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)
    [transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)

dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()

# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device) # view相当于reshape
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise) # G(z)
        disc_real = disc(real).view(-1) # flatten
        # BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]

        # max log(D(real)) 相当于 min -log(D(real))
        # ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项
        # 得到 min -lnx,这里的x就是我们的real图片
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        # max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))
        # zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项
        # 得到 min -ln(1-x),这里的x就是我们的fake噪声
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))
        # 依然可使用BCEloss来做
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] \ "
                f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )

                step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

PyTorch训练简单的生成对抗网络GAN,Deep Learning,DL Papers,pytorch,生成对抗网络,人工智能
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661文章来源地址https://www.toymoban.com/news/detail-653583.html

到了这里,关于PyTorch训练简单的生成对抗网络GAN的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包赞助服务器费用

相关文章

  • PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

    PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

    生成对抗网络 ( Generative Adversarial Networks , GAN ) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用

    2024年01月22日
    浏览(12)
  • 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

    使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

    前面的博客讲了如何基于PyTorch使用神经网络识别手写数字 使用PyTorch构建神经网络 下面在此基础上构建一个生成对抗网络,生成对抗网络可以模拟出新的手写数字数据集。 生成对抗网络(GAN)是一种用于生成新的照片,文本或音频的模型。它由两部分组成:生成器和判别器

    2024年02月02日
    浏览(8)
  • 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将

    2024年02月08日
    浏览(44)
  • PyTorch训练深度卷积生成对抗网络DCGAN

    PyTorch训练深度卷积生成对抗网络DCGAN

    将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN) DCGAN的生成器结构: 图片来源:https://arxiv.org/abs/1511.06434 model.py 训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录 图片如下图所

    2024年02月12日
    浏览(14)
  • 深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决

    深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决

    【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等 专栏详细介绍:【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、

    2024年02月08日
    浏览(12)
  • 【计算机视觉|生成对抗】生成对抗网络(GAN)

    【计算机视觉|生成对抗】生成对抗网络(GAN)

    本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题: Generative Adversarial Nets 链接:Generative Adversarial Nets (nips.cc) 我们提出了一个通过**对抗(adversarial)**过程估计生成模型的新框架,在其中我们同时训练两个模型: 一个生成模型G,捕获数据分布 一个判别模型

    2024年02月12日
    浏览(12)
  • GAN生成对抗网络介绍

    GAN生成对抗网络介绍

    GAN 全称是Generative Adversarial Networks,即 生成对抗网络 。 “生成”表示它是一个生成模型,而“对抗”代表它的训练是处于一种对抗博弈状态中的。 一个可以自己创造数据的网络! 判别模型与生成模型 判别模型(Discriminative Model) 主要目标是对给定输入数据直接进行建模,

    2024年01月17日
    浏览(6)
  • 生成对抗网络----GAN

    生成对抗网络----GAN

    ` GAN (Generative Adversarial Network) : 通过两个神经网络,即生成器(Generator)和判别器(Discriminator),相互竞争来学习数据分布。 { 生成器 ( G e n e r a t o r ) : 负责从随机噪声中学习生成与真实数据相似的数据。 判别器 ( D i s c r i m i n a t o r ) : 尝试区分生成的数据和真实数据。

    2024年02月20日
    浏览(8)
  • 生成对抗网络 (GAN)

    生成对抗网络 (GAN)

    生成对抗网络(Generative Adversarial Networks,GAN)是由Ian Goodfellow等人在2014年提出的一种深度学习模型。GAN由两部分组成:一个生成器(Generator)和一个判别器(Discriminator),它们通过对抗过程来训练,从而能够生成非常逼真的数据。 生成器(Generator) 生成器的任务是创建尽可

    2024年03月10日
    浏览(7)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包