生成对抗网络cGAN(条件GAN)

这篇具有很好参考价值的文章主要介绍了生成对抗网络cGAN(条件GAN)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1.介绍

论文:Conditional Generative Adversarial Nets

论文地址:https://arxiv.org/abs/1411.1784

针对原始GAN的缺点:生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强等问题。

改进方法:cGAN的中心思想是希望可以控制 GAN 生成的图片,而不是单纯的随机生成图片。 Conditional GAN 在生成器和判别器的输入中增加了额外的条件信息,生成器生成的图片只有足够真实且与条件相符,才能够通过判别器。其核心在于将属性信息融入生成器G和判别器D中,属性可以是任何标签信息, 例如图像的类别、人脸图像的面部表情等。

2.模型结构

cgan,生成对抗网络,深度学习,人工智能,python,计算机视觉

在判别器和生成器中都添加了额外信息y,y可以是类别标签或者是其他类型的数据,可以将y作为一个额外的输入层引入判别器和生成器。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
import tqdm

ROOT_TRAIN = r'D:\CNN\AlexNet\data1\train'

def one_hot(x, num_class=2): #转化为独热标签
    return torch.eye(num_class)[x, :]

train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform, target_transform=one_hot)  # 加载训练集
dataloader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=64,
                                           shuffle=True,
                                           num_workers=0)

# print(train_dataset[0]) #返回数据和标签, 引入one_hot编码后,标签就为长度为num_class的tensor tensor([1., 0.]

# 定义生成器,输入是长度为100的噪声(正态分布随机数),和标签独热编码(condition)
# 输出为3*224*224的图片(tensor)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(100, 128*56*56)
        self.bn1 = nn.BatchNorm1d(128*56*56)
        self.linear2 = nn.Linear(2, 128*56*56)
        self.bn2 = nn.BatchNorm1d(128*56*56)

        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          stride=1,
                                          padding=1)  #128*56*56
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # 64*112*112
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # 3*224*224

    def forward(self, x1, x2): #x1为噪声输入,x2标签独热编码输入(condition)
        x1 = F.relu(self.linear1(x1)) #100 -- 128*56*56
        x1 = self.bn1(x1)
        x2 = F.relu(self.linear2(x2)) #num_class -- 128*56*56
        x2 = self.bn2(x2)
        x1 = x1.view(-1, 128, 56, 56)
        x2 = x2.view(-1, 128, 56, 56)
        x = torch.cat([x1, x2], dim=1) #256*56*56
        x = F.relu(self.deconv1(x)) #256*56*56 -- 128*56*56
        x = self.bn3(x)
        x = F.relu(self.deconv2(x)) #128*56*56 -- 64*112*112
        x = self.bn4(x)
        x = torch.tanh(self.deconv3(x)) #64*112*112 -- 3*224*224 生成器的输出不使用bn层
        return x

# 定义判别器,输入为3*224*224的图片,输出为二分类概率值
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(2, 3*224*224)
        self.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*55*55, 1)

    def forward(self, x1, x2): #x1为真实图像输入,x2标签独热编码输入(condition)
        x2 = self.linear(x2)
        x2 = x2.view(-1, 3, 224, 224)
        x = torch.cat([x1, x2], dim=1) #batchsize, 6, 224, 224
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.3)  #64*111*111 判别器的输入不使用bn层
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.3)  #128*55*55
        x = self.bn(x)
        x = x.view(-1, 128*55*55) #展平
        x = torch.sigmoid(self.fc(x))
        return x


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
dis = Discriminator().to(device)

# 判别器优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4) #通过减小判别器的学习率降低其能力
# 生成器优化器
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-3)

loss_fn = torch.nn.BCELoss() # 二元交叉熵损失

# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, noise_input, label_input): # model为Generator,test_input代表生成器输入的随机数,label_input为标签输入
    # prediction = np.squeeze(model(test_input).detach().cpu().numpy()) #squeeze为去掉通道维度
    prediction = model(noise_input, label_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
    plt.figure(figsize=(10, 10))
    for i in range(prediction.shape[0]): #prediction.shape[0]=noise_input的batchsize
        plt.subplot(2, 2, i + 1)
        plt.imshow((prediction[i]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    plt.savefig('./CGAN_img/image_CGAN_{}.png'.format(epoch))
    # if epoch == 99:
    #     plt.show()

# 设置生成绘图图片的随机张量,这里可视化4张图片
noise_input = torch.randn(4, 100, device=device) #测试输入:16个长度为100的随机数
# print(noise_input)
label_input0 = torch.randint(0, 1, size=(4, )) #生成4个从0到1的随机整数
# print(label_input)
label_input_onehot = one_hot(label_input0).to(device) #将tensor转化为独热编码形式
# print(label_input_onehot)


# CGAN训练
D_loss = []
G_loss = []

for epoch in range(500):
    d_epoch_loss = 0 #判别器损失
    g_epoch_loss = 0 #生成器损失
    count = len(dataloader) #len(dataloader)返回批次数
    count1 = len(train_dataset) #len(train_dataset)返回样本数
    for step, (img, label) in enumerate(tqdm.tqdm(dataloader)): #此时返回的label已经是独热标签
        img = img.to(device)
        label = label.to(device)
        size = img.size(0) #该批次包含多少张图片
        random_noise = torch.randn(size, 100, device=device) #创建生成器的噪声输入

        d_optim.zero_grad() #判别器梯度清0
        real_output = dis(img, label) #将真实图像放到判别器上进行判断,得到对真实图像的预测结果
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) #real_output应该被判定为1(真),得到判别器在真实图像上的损失
        d_real_loss.backward() #计算梯度

        gen_img = gen(random_noise, label) #得到生成图像
        fake_output = dis(gen_img.detach(), label) #将生成图像和对应的标签同时放到判别器上进行判断,得到对生成图像的预测结果,detach()为截断梯度
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) #fake_output应该被判定为0(假),得到判别器在生成图像上的损失
        d_fake_loss.backward()  # 计算梯度

        d_loss = d_real_loss + d_fake_loss #判别器的损失包含两部分
        d_optim.step() #判别器优化

        # 生成器
        g_optim.zero_grad() #生成器梯度清零
        fake_output = dis(gen_img, label) #将生成图像放到判别器上进行判断
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) #此处希望生成的图像能被判定为1
        g_loss.backward()  # 计算梯度
        g_optim.step() #生成器优化

        with torch.no_grad(): # loss累加的过程不需要计算梯度
            d_epoch_loss += d_loss.item() #将每一个批次的损失累加
            g_epoch_loss += g_loss.item() #将每一个批次的损失累加

    with torch.no_grad():  # loss累加的过程不需要计算梯度
        g_epoch_loss /= count
        d_epoch_loss /= count
        D_loss.append(d_epoch_loss) #保存每一个epoch的平均loss
        G_loss.append(g_epoch_loss) #保存每一个epoch的平均loss
        print('Epoch:', epoch)
        gen_img_plot(gen, epoch, noise_input, label_input_onehot) #每个epoch会生成一张图

    plt.figure(figsize=(10, 10))
    plt.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
    plt.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
    plt.xlabel('epoch')  # 横轴名称
    plt.legend()
    plt.savefig('loss.png')  # 保存图片

cGAN生成的图像虽有很多缺陷,如图像边缘模糊,生成的图像分辨率太低等,但是它为后面的pix2pixGAN和CycleGAN开拓了道路!!!

最后放上我训练的结果(数据量不大,只有四百张狗的图片,效果不太明显!!!) 

 cgan,生成对抗网络,深度学习,人工智能,python,计算机视觉cgan,生成对抗网络,深度学习,人工智能,python,计算机视觉

cgan,生成对抗网络,深度学习,人工智能,python,计算机视觉

 文章来源地址https://www.toymoban.com/news/detail-802832.html

到了这里,关于生成对抗网络cGAN(条件GAN)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习生成对抗网络(GAN)

    生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。生成对抗网络被认为是当

    2024年02月07日
    浏览(61)
  • 深度学习基础——GAN生成对抗网络

            生成对抗网络GAN(Generative adversarial networks)是Goodfellow等在2014年提出的一种生成式模型。GAN在结构上受博弈论中的二元零和博弈(即二元的利益之和为零,一方的所得正是另一方的所失)的启发,系统由一个生成器和一个判别器构成。         生成器和判别器均可以

    2024年02月22日
    浏览(66)
  • 深度学习(4)---生成式对抗网络(GAN)

     1. 生成式对抗网络(Generative Adversarial Network,GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。它启发自博弈论中的二人零和博弈(two-player game),两位博弈方分别由生成模型(generative model)和判别模型(discriminative model)充当。  2. 判别模

    2024年02月08日
    浏览(58)
  • 【Pytorch深度学习实战】(10)生成对抗网络(GAN)

     🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞👍 + 收藏⭐️ + 留言📝​ 📣系列专栏 - 机器学习【ML】 自然语言处理【NLP】  深度学习【DL】 ​  🖍foreword ✔说

    2023年04月08日
    浏览(55)
  • 深度学习7:生成对抗网络 – Generative Adversarial Networks | GAN

    生成对抗网络 – GAN 是最近2年很热门的一种无监督算法,他能生成出非常逼真的照片,图像甚至视频。我们手机里的照片处理软件中就会使用到它。 目录 生成对抗网络 GAN 的基本原理 大白话版本 非大白话版本 第一阶段:固定「判别器D」,训练「生成器G」 第二阶段:固定

    2024年02月11日
    浏览(57)
  • PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

    生成对抗网络 ( Generative Adversarial Networks , GAN ) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用

    2024年01月22日
    浏览(50)
  • 人工智能基础部分20-生成对抗网络(GAN)的实现应用

    大家好,我是微学AI,今天给大家介绍一下人工智能基础部分20-生成对抗网络(GAN)的原理与简单应用。生成对抗网络是一种由深度学习模型构成的神经网络系统,由一个生成器和一个判别器相互博弈来提升模型的能力。本文将从以下几个方面进行阐述:生成对抗网络的概念、

    2024年02月09日
    浏览(117)
  • 深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决

    【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、序列模型、预训练模型、对抗神经网络等 专栏详细介绍:【深度学习入门到进阶】必看系列,含激活函数、优化策略、损失函数、模型调优、归一化算法、卷积模型、

    2024年02月08日
    浏览(102)
  • 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将

    2024年02月08日
    浏览(73)
  • 大数据机器学习GAN:生成对抗网络GAN全维度介绍与实战

    本文为生成对抗网络GAN的研究者和实践者提供全面、深入和实用的指导。通过本文的理论解释和实际操作指南,读者能够掌握GAN的核心概念,理解其工作原理,学会设计和训练自己的GAN模型,并能够对结果进行有效的分析和评估。 生成对抗网络(GAN)是深度学习的一种创新架

    2024年02月03日
    浏览(41)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包