深度学习9:简单理解生成对抗网络原理

这篇具有很好参考价值的文章主要介绍了深度学习9:简单理解生成对抗网络原理。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

生成算法

生成对抗网络(GAN)

“生成”部分

“对抗性”部分

GAN如何运作?

培训GAN的技巧?

GAN代码示例

如何改善GAN?

结论


生成算法

您可以将生成算法分组到三个桶中的一个:

  1. 鉴于标签,他们预测相关的功能(朴素贝叶斯)
  2. 给定隐藏的表示,他们预测相关的特征(变分自动编码器,生成对抗网络)
  3. 鉴于一些功能,他们预测其余的(修复,插补)

我们将探索生成对抗网络的一些基础知识!GAN具有令人难以置信的潜力,因为他们可以学习模仿任何数据分布。也就是说,GAN可以学习在任何领域创造类似于我们自己的世界:图像,音乐,语音。

深度学习9:简单理解生成对抗网络原理,2023 AI,人工智能,深度学习

生成对抗网络(GAN)

“生成”部分

  • 叫做发电机
  • 给定某个标签,尝试预测功能
  • EX:鉴于电子邮件被标记为垃圾邮件,预测(生成)电子邮件的文本。
  • 生成模型学习各个类的分布。

“对抗性”部分

  • 称为判别者
  • 鉴于这些功能,尝试预测标签
  • EX:根据电子邮件的文本,预测(区分)垃圾邮件或非垃圾邮件。
  • 判别模型学习了类之间的界限。

GAN如何运作?

一个称为Generator的神经网络生成新的数据实例,而另一个神经网络Discriminator则评估它们的真实性。

您可以将GAN视为伪造者(发电机)和警察(Discriminator)之间的猫捉老鼠游戏。伪造者正在学习制造假钱,警察正在学习如何检测假钱。他们都在学习和提高。伪造者不断学习创造更好的假货,并且警察在检测它们时不断变得更好。最终的结果是,伪造者(发电机)现在接受了培训,可以创造出超现实的金钱!

让我们用MNIST手写数字数据集探索一个具体的例子:

深度学习9:简单理解生成对抗网络原理,2023 AI,人工智能,深度学习

我们将让Generator创建新的图像,如MNIST数据集中的图像,它取自现实世界。当从真实的MNIST数据集中显示实例时,Discriminator的目标是将它们识别为真实的。

同时,Generator正在创建传递给Discriminator的新图像。它是这样做的,希望它们也将被认为是真实的,即使它们是假的。Generator的目标是生成可通过的手写数字,以便在不被捕获的情况下进行说谎。Discriminator的目标是将来自Generator的图像分类为假的。

深度学习9:简单理解生成对抗网络原理,2023 AI,人工智能,深度学习

GAN步骤:

  1. 生成器接收随机数并返回图像。
  2. 将生成的图像与从实际数据集中获取的图像流一起馈送到鉴别器中。
  3. 鉴别器接收真实和假图像并返回概率,0到1之间的数字,1表示真实性的预测,0表示假

两个反馈循环:

  1. 鉴别器处于反馈循环中,具有图像的基本事实(它们是真实的还是假的),我们知道。
  2. 发生器与Discriminator处于反馈循环中(Discriminator将其标记为真实或伪造,无论事实如何)。

培训GAN的技巧?

在开始训练发生器之前预先识别鉴别器将建立更清晰的梯度。

训练Discriminator时,保持Generator值不变。训练发生器时,保持Discriminator值不变。这使网络能够更好地了解它必须学习的梯度。

GAN被制定为两个网络之间的游戏,重要:保持它们的平衡。如果发电机或鉴别器太好,GAN可能很难学习。

GAN需要很长时间才能训练。在单个GPU上,GAN可能需要数小时,在单个CPU上,GAN可能需要数天。

GAN代码示例

class GAN():
    def __init__(self):
        self.img_rows = 28 
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', 
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the generator
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        # The generator takes noise as input and generated imgs
        z = Input(shape=(100,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The valid takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity 
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        noise_shape = (100,)
        
        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        img_shape = (self.img_rows, self.img_cols, self.channels)
        
        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, save_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = int(batch_size / 2)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half batch of images
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (half_batch, 100))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)


            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, 100))

            # The generator wants the discriminator to label the generated samples
            # as valid (ones)
            valid_y = np.array([1] * batch_size)

            # Train the generator
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, 100))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("gan/images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, save_interval=200)

如何改善GAN?

GAN刚刚在2014年发明 – 它们非常新!GAN是一个很有前途的生成模型家族,因为与其他方法不同,它们可以生成非常干净和清晰的图像,并学习包含有关基础数据的有价值信息的权重。但是,如上所述,可能难以使Discriminator和Generator网络保持平衡。有很多正在进行的工作使GAN培训更加稳定。

除了生成漂亮的图片之外,还开发了一种利用GAN进行半监督学习的方法,该方法涉及鉴别器产生指示输入标签的附加输出。这种方法可以使用极少数标记示例在数据集上实现最前沿结果。例如,在MNIST上,通过完全连接的神经网络,每个类只有10个标记示例,实现了99.1%的准确度 – 这一结果非常接近使用所有60,000个标记示例的完全监督方法的最佳已知结果。这是非常有希望的,因为在实践中获得标记的示例可能非常昂贵。文章来源地址https://www.toymoban.com/news/detail-684422.html

结论

到了这里,关于深度学习9:简单理解生成对抗网络原理的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习(4)---生成式对抗网络(GAN)

     1. 生成式对抗网络(Generative Adversarial Network,GAN)是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。它启发自博弈论中的二人零和博弈(two-player game),两位博弈方分别由生成模型(generative model)和判别模型(discriminative model)充当。  2. 判别模

    2024年02月08日
    浏览(46)
  • PyTorch 深度学习实战 | 基于生成式对抗网络生成动漫人物

    生成式对抗网络(Generative Adversarial Network, GAN)是近些年计算机视觉领域非常常见的一类方法,其强大的从已有数据集中生成新数据的能力令人惊叹,甚至连人眼都无法进行分辨。本文将会介绍基于最原始的DCGAN的动漫人物生成任务,通过定义生成器和判别器,并让这两个网络

    2023年04月17日
    浏览(37)
  • 【计算机视觉|生成对抗】用深度卷积生成对抗网络进行无监督表示学习(DCGAN)

    本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题: Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks 链接:[1511.06434] Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks (arxiv.org) 近年来,卷积网络(CNNs)的监督学习

    2024年02月13日
    浏览(48)
  • 【Pytorch深度学习实战】(10)生成对抗网络(GAN)

     🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞👍 + 收藏⭐️ + 留言📝​ 📣系列专栏 - 机器学习【ML】 自然语言处理【NLP】  深度学习【DL】 ​  🖍foreword ✔说

    2023年04月08日
    浏览(42)
  • 【深度学习】生成对抗网络Generative Adversarial Nets

            本文是GAN网络的原始论文,发表于2014年,我们知道,对抗网络是深度学习中,CNN基础上的一大进步; 它最大的好处是,让网络摆脱训练成“死模型”到固定场所处去应用,而是对于变化的场景,网络有一个自己的策略; 这是非常值得研究的课题。 本文记录了原

    2024年02月15日
    浏览(44)
  • 深度学习7:生成对抗网络 – Generative Adversarial Networks | GAN

    生成对抗网络 – GAN 是最近2年很热门的一种无监督算法,他能生成出非常逼真的照片,图像甚至视频。我们手机里的照片处理软件中就会使用到它。 目录 生成对抗网络 GAN 的基本原理 大白话版本 非大白话版本 第一阶段:固定「判别器D」,训练「生成器G」 第二阶段:固定

    2024年02月11日
    浏览(43)
  • PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

    生成对抗网络 ( Generative Adversarial Networks , GAN ) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用

    2024年01月22日
    浏览(38)
  • 基于深度学习、机器学习,对抗生成网络,OpenCV,图像处理,卷积神经网络计算机毕业设计选题指导

    开发一个实时手势识别系统,使用卷积神经网络(CNN)和深度学习技术,能够识别用户的手势并将其映射到计算机操作,如控制游戏、音量调整等。这个项目需要涵盖图像处理、神经网络训练和实时计算等方面的知识。 利用深度学习模型,设计一个人脸识别系统,可以识别人

    2024年02月07日
    浏览(59)
  • 竞赛保研 基于生成对抗网络的照片上色动态算法设计与实现 - 深度学习 opencv python

    🔥 优质竞赛项目系列,今天要分享的是 🚩 基于生成对抗网络的照片上色动态算法设计与实现 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分 工作量:3分 创新点:4分 🧿 更多资料, 项目分享: http

    2024年01月17日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包