AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)

这篇具有很好参考价值的文章主要介绍了AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

0. 前言

我们已经学习了如何构建生成对抗网络 (Generative Adversarial Net, GAN) 以从给定的训练集中生成逼真图像。但是,我们无法控制想要生成的图像类型,例如控制模型生成男性或女性的面部图像;我们可以从潜空间中随机采样一个点,但是不能预知给定潜变量能够生成什么样的图像。在本节中,我们将构建一个能够控制输出的 GAN,即条件生成对抗网络 (Conditional Generative Adversarial Net, GAN)。该模型最早由 MirzaOsindero2014 年提出,是对 GAN 架构的简单改进。

1. CGAN架构

在本节中,我们将使用面部数据集中的头发颜色属性来设置 CGAN 的条件。也就是说,我们将能够明确指定是否要生成带有金发的图像。头发颜色标签作为 CelebA 数据集的一部分已在数据集中提供,CGAN 的架构如下图所示。

AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN),AIGC,生成对抗网络,人工智能

标准 GANCGAN 之间的关键区别在于:在 CGAN 中,我们需要向生成器和判别器传递与标签相关的额外信息。在生成器中,标签信息转化为独热编码 (one-hot) 向量后附加在潜空间样本之后。在判别器中,通过重复独热编码向量填充得到与输入图像相同形状的通道,将标签信息添加为 RGB 图像的额外通道。
CGAN 之所以能够生成指定类型的图像,是因为其判别器可以获得关于图像内容的额外信息,因此生成器必须确保其输出与提供的标签一致,以继续欺骗判别器。如果生成器生成了与图像标签不一致的图像,即使图像非常逼真,判别器会将它们判定为伪造图像,因为图像和标签并不匹配。
在本节所构建的 CGAN 中,因为有两个类别(金发和非金发),独热编码标签的长度是 2。但是,我们也可以根据需要拥有使用多个标签。例如,在 Fashion-MNIST 数据集上训练 CGAN 时,为了输出 10 种不同类型的 Fashion-MNIST 图像,可以通过将长度为 10 的独热编码标签向量并入生成器的输入,并将 10 个额外的独热编码标签通道并入判别器的输入。
综上,我们需要对标准 GAN 架构所进行的修改是,将标签信息与生成器和判别器的现有输入连接起来:

# 图像通道和标签通道分别传递给判别器,并进行连接
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
label_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CLASSES))
x = layers.Concatenate(axis=-1)([critic_input, label_input])
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)

critic = models.Model([critic_input, label_input], critic_output)
print(critic.summary())
# 潜向量和标签类别分别传递给生成器,并在调整形状之前进行连接
generator_input = layers.Input(shape=(Z_DIM,))
label_input = layers.Input(shape=(CLASSES,))
x = layers.Concatenate(axis=-1)([generator_input, label_input])
x = layers.Reshape((1, 1, Z_DIM + CLASSES))(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    128, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2DTranspose(
    64, kernel_size=4, strides=2, padding="same", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)
generator_output = layers.Conv2DTranspose(
    CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model([generator_input, label_input], generator_output)
print(generator.summary())

2. 模型训练

调整 CGANtrain_step 方法,以令生成器和判别器适应新的输入格式:

    def train_step(self, data):
        # 从数据集中提取图像和标签
        real_images, one_hot_labels = data
        # 将独热编码向量扩展为具有与输入图像相同空间尺寸 (64×64) 的独热编码图像
        image_one_hot_labels = one_hot_labels[:, None, None, :]
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=1)
        image_one_hot_labels = tf.repeat(image_one_hot_labels, repeats=IMAGE_SIZE, axis=2)

        batch_size = tf.shape(real_images)[0]

        for i in range(self.critic_steps):
            random_latent_vectors = tf.random.normal( shape=(batch_size, self.latent_dim))

            with tf.GradientTape() as tape:
                # 生成器接受包含两个输入的列表——随机潜向量和独热编码的标签向量
                fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)
                # 判别器接受包含两个输入的列表——真实/生成图像和独热编码的标签通道
                fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)
                real_predictions = self.critic([real_images, image_one_hot_labels], training=True)

                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(real_predictions)
                c_gp = self.gradient_penalty(batch_size, real_images, fake_images, image_one_hot_labels)
                # 梯度惩罚函数还需要通过独热编码的标签通道传递(由于其流经判别器)
                c_loss = c_wass_loss + c_gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))

        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )

        with tf.GradientTape() as tape:
            # 生成器训练过程的修改与判别器训练步骤的修改相同
            fake_images = self.generator([random_latent_vectors, one_hot_labels], training=True)
            fake_predictions = self.critic([fake_images, image_one_hot_labels], training=True)
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))

        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)
        return {m.name: m.result() for m in self.metrics}

3. CGAN 分析

我们可以通过将特定的独热编码标签传递到生成器的输入中来控制 CGAN 的输出。例如,要生成一张非金发的人脸图像,我们传入向量 [1, 0];要生成一张金发的人脸图像,我们传入向量 [0, 1]
CGAN 的输出如下图所示。可以看到,在保持随机潜向量不变的情况下,只改变条件标签向量,显然 CGAN 已经学会使用标签向量来控制图像的头发颜色属性,且图像的其余部分几乎没有改变。这证明了 GAN 能够以这种方式组织潜空间中的点,使得各个特征可以相互解耦。

AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN),AIGC,生成对抗网络,人工智能

如果数据集中有标签可用,即使不一定需要将生成的输出与标签相关联,将它们作为 GAN 的输入通常也可以提高生成图像的质量,我们可以把标签看作是像素输入的信息扩展。

小结

在本节中,构建了一个条件生成对抗网络 (Conditional Generative Adversarial Net, CGAN),通过将标签作为输入传递给判别器和生成器,能够生成可控类别的图像,这是由于标签为网络提供了额外的信息,以便使生成的输出与给定的标签相关联。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)文章来源地址https://www.toymoban.com/news/detail-760234.html

到了这里,关于AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • GAN(Generative Adversarial Nets (生成对抗网络))

    一、GAN 1、应用 GAN的应用十分广泛,如图像生成、图像转换、风格迁移、图像修复等等。 2、简介 生成式对抗网络是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model,G)和判别模型(Discriminative Model,D)的互相

    2024年02月04日
    浏览(37)
  • 生成对抗网络 – Generative Adversarial Networks | GAN

    目录 生成对抗网络 GAN 的基本原理 非大白话版本 第一阶段:固定「判别器D」,训练「生成器G」

    2024年04月15日
    浏览(45)
  • 李宏毅 Generative Adversarial Network(GAN)生成对抗网络

    附课程提到的各式各样的GAN:https://github.com/hindupuravinash/the-gan-zoo 想要让机器做到的是生成东西。-训练出来一个generator。 假设要做图像生成,要做的是随便给一个输入(random sample一个vector,比如从gaussian distribution sample一个vector),generator产生一个image。丢不同的vector,就应

    2024年01月21日
    浏览(56)
  • 深度学习7:生成对抗网络 – Generative Adversarial Networks | GAN

    生成对抗网络 – GAN 是最近2年很热门的一种无监督算法,他能生成出非常逼真的照片,图像甚至视频。我们手机里的照片处理软件中就会使用到它。 目录 生成对抗网络 GAN 的基本原理 大白话版本 非大白话版本 第一阶段:固定「判别器D」,训练「生成器G」 第二阶段:固定

    2024年02月11日
    浏览(57)
  • 【计算机视觉|生成对抗】条件生成对抗网络(CGAN)

    本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题: Conditional Generative Adversarial Nets 链接:[1411.1784] Conditional Generative Adversarial Nets (arxiv.org) 生成对抗网络(Generative Adversarial Nets) [8] 最近被引入为训练生成模型的一种新颖方法。 在这项工作中,我们介绍了生

    2024年02月13日
    浏览(42)
  • 生成对抗网络cGAN(条件GAN)

    论文: Conditional Generative Adversarial Nets 论文地址: https://arxiv.org/abs/1411.1784 针对原始GAN的缺点: 生成的图像是随机的,不可预测的,无法控制网络输出特定的图片,生成目标不明确,可控性不强等问题。 改进方法: cGAN的中心思想是希望可以控制 GAN 生成的图片,而不是单纯

    2024年01月18日
    浏览(43)
  • 条件生成对抗网络(cGAN)在AI去衣技术中的应用探索

    随着深度学习技术的飞速发展,生成对抗网络(GAN)作为其中的一个重要分支,在图像生成、图像修复等领域展现出了强大的能力。其中,条件生成对抗网络(cGAN)通过引入条件变量来控制生成模型的输出,进一步提高了GAN的灵活性和实用性。本文将深入探讨cGAN在AI去衣技术

    2024年04月27日
    浏览(50)
  • 产品经理看AIGC--GAN(生成对抗网络)白话原理

    《AIGC:智能创作时代》的阅读随笔(推荐单独阅读第二章,其余章节快速略过),期待从业务角度而非推导角度更好的理解,为产品从业人员提供更好的了解沟通渠道。 如何从白话角度理解生成对抗网络,核心在于如何理解“对抗”,通俗的字面理解“对抗”,我们会联想

    2024年02月15日
    浏览(38)
  • 【CVPR 2023的AIGC应用汇总(8)】3D相关(编辑/重建/生成) diffusion扩散/GAN生成对抗网络方法...

    【CVPR 2023的AIGC应用汇总(5)】语义布局可控生成,基于diffusion扩散/GAN生成对抗 【CVPR 2023的AIGC应用汇总(4)】图像恢复,基于GAN生成对抗/diffusion扩散模型 【CVPR 2023的AIGC应用汇总(3)】GAN改进/可控生成的方法10篇 【CVPR 2023的AIGC应用汇总(2)】可控文生图,基于diffusion扩散模型/G

    2024年02月10日
    浏览(52)
  • PyTorch 深度学习实战 | 基于生成式对抗网络生成动漫人物

    生成式对抗网络(Generative Adversarial Network, GAN)是近些年计算机视觉领域非常常见的一类方法,其强大的从已有数据集中生成新数据的能力令人惊叹,甚至连人眼都无法进行分辨。本文将会介绍基于最原始的DCGAN的动漫人物生成任务,通过定义生成器和判别器,并让这两个网络

    2023年04月17日
    浏览(50)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包