pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)

这篇具有很好参考价值的文章主要介绍了pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1.简介

这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码。数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G。生成器与判别器模型均采用简单的卷积结构,代码参考了pytorch官网。

建议对pytorch和神经网络原理还不熟悉的同学,可以先看下之前的文章了解下基础:pytorch基础教学简单实例(附代码)_Lizhi_Tech的博客-CSDN博客_pytorch实例


2.GAN原理

简而言之,生成对抗网络可以归纳为以下几个步骤:

  1. 随机噪声输入进生成器,生成虚假图片。

  1. 将带标签的虚假图片和真实图片输入进判别器进行更新,最大化 log(D(x)) + log(1 - D(G(z)))。

  1. 根据判别器的输出结果更新生成器,最大化 log(D(G(z)))。


3.代码

from __future__ import print_function
import random
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt

# 设置随机算子
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# 数据集位置
dataroot = "data/celeba"

# dataloader的核数
workers = 2

# Batch大小
batch_size = 128

# 图像缩放大小
image_size = 64

# 图像通道数
nc = 3

# 隐向量维度
nz = 100

# 生成器特征维度
ngf = 64

# 判别器特征维度
ndf = 64

# 训练轮数
num_epochs = 5

# 学习率
lr = 0.0002

# Adam优化器的beta系数
beta1 = 0.5

# gpu个数
ngpu = 1

# 加载数据集
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# 创建dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# 使用cpu还是gpu
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# 初始化权重
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# 生成器
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# 实例化生成器并初始化权重
netG = Generator(ngpu).to(device)
netG.apply(weights_init)

# 判别器
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

# 实例化判别器并初始化权重
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)


# 损失函数
criterion = nn.BCELoss()

# 随机输入噪声
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# 真实标签与虚假标签
real_label = 1.
fake_label = 0.

# 创建优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# 开始训练
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) 更新D: 最大化 log(D(x)) + log(1 - D(G(z)))
        ###########################
        # 使用真实标签的batch训练
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # 使用虚假标签的batch训练
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        # 更新D
        optimizerD.step()

        ############################
        # (2) 更新G: 最大化 log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        # 更新G
        optimizerG.step()

        # 输出训练状态
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # 保存每轮loss
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # 记录生成的结果
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

# loss曲线
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()


# 生成效果图
real_batch = next(iter(dataloader))

# 真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 生成的虚假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

4.结果

生成对抗网络代码pytorch,深度学习,机器视觉,pytorch,深度学习,计算机视觉,GAN,Powered by 金山文档

真实图像与生成图像

生成对抗网络代码pytorch,深度学习,机器视觉,pytorch,深度学习,计算机视觉,GAN,Powered by 金山文档

loss曲线


业务合作/学习交流+v:lizhiTechnology 文章来源地址https://www.toymoban.com/news/detail-550323.html

到了这里,关于pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

    生成对抗网络 ( Generative Adversarial Networks , GAN ) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用

    2024年01月22日
    浏览(48)
  • 深度学习基础——GAN生成对抗网络

            生成对抗网络GAN(Generative adversarial networks)是Goodfellow等在2014年提出的一种生成式模型。GAN在结构上受博弈论中的二元零和博弈(即二元的利益之和为零,一方的所得正是另一方的所失)的启发,系统由一个生成器和一个判别器构成。         生成器和判别器均可以

    2024年02月22日
    浏览(62)
  • 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

    前面的博客讲了如何基于PyTorch使用神经网络识别手写数字 使用PyTorch构建神经网络 下面在此基础上构建一个生成对抗网络,生成对抗网络可以模拟出新的手写数字数据集。 生成对抗网络(GAN)是一种用于生成新的照片,文本或音频的模型。它由两部分组成:生成器和判别器

    2024年02月02日
    浏览(31)
  • 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将

    2024年02月08日
    浏览(68)
  • 人工智能基础部分20-生成对抗网络(GAN)的实现应用

    大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的原理与简单应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述:生成对抗网络的概念、

    2024年02月09日
    浏览(109)
  • 【计算机视觉|生成对抗】生成对抗网络(GAN)

    本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题: Generative Adversarial Nets 链接:Generative Adversarial Nets (nips.cc) 我们提出了一个通过**对抗(adversarial)**过程估计生成模型的新框架,在其中我们同时训练两个模型: 一个生成模型G,捕获数据分布 一个判别模型

    2024年02月12日
    浏览(58)
  • 生成对抗网络----GAN

    ` GAN (Generative Adversarial Network) : 通过两个神经网络,即生成器(Generator)和判别器(Discriminator),相互竞争来学习数据分布。 { 生成器 ( G e n e r a t o r ) : 负责从随机噪声中学习生成与真实数据相似的数据。 判别器 ( D i s c r i m i n a t o r ) : 尝试区分生成的数据和真实数据。

    2024年02月20日
    浏览(46)
  • GAN(生成对抗网络)

    简介:GAN生成对抗网络本质上是一种思想,其依靠神经网络能够拟合任意函数的能力,设计了一种架构来实现数据的生成。 原理:GAN的原理就是最小化生成器Generator的损失,但是在最小化损失的过程中加入了一个约束,这个约束就是使Generator生成的数据满足我们指定数据的分

    2024年02月11日
    浏览(47)
  • GAN生成对抗网络介绍

    GAN 全称是Generative Adversarial Networks,即 生成对抗网络 。 “生成”表示它是一个生成模型,而“对抗”代表它的训练是处于一种对抗博弈状态中的。 一个可以自己创造数据的网络! 判别模型与生成模型 判别模型(Discriminative Model) 主要目标是对给定输入数据直接进行建模,

    2024年01月17日
    浏览(52)
  • 了解生成对抗网络 (GAN)

            Yann LeCun将其描述为“过去10年来机器学习中最有趣的想法”。当然,来自深度学习领域如此杰出的研究人员的赞美总是对我们谈论的主题的一个很好的广告!事实上,生成对抗网络(简称GAN)自2014年由Ian J. Goodfellow和共同作者在《

    2024年02月12日
    浏览(39)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包