cycleGAN算法解读

这篇具有很好参考价值的文章主要介绍了cycleGAN算法解读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

本文参考:https://blog.csdn.net/Mr_health/article/details/112545671

1 CycleGAN概述 

CycleGAN:循环生成对抗神经网络,是一种非监督学习模型。

Pix2pix方法适用于成对数据的风格迁移,而大多数情况下对于A风格的图像,并没有与之相对应的B风格图像。获取严格意义上的成对数据是非常困难的,所以不依赖成对数据的算法具有非常重要的实际意义。我们所拥有的是一群处于风格A(源域)的图像和一群处于风格B(目标域)的图像,这样pix2pix方法就不管用了。

cyclegan,神经网络,计算机视觉,人工智能

CycleGAN的创新点在于能够在源域和目标域之间,无须建立训练数据间一对一的映射,就可以实现这种迁移。

2 CycleGAN基本架构

(1)输入

X:源域,风格A的图像

Y:目标域,风格B的图像

(2)两个生成器:

G:用于将风格A的图像x转换为风格B的图像

F:用于将风格B的图像y转换为风格A的图像

(3)Cycle解释

通过G将风格A的图像x转换为风格B的图像Y‘,之后再将Y’通过F后仍然能够转换回风格A,并能保证图像中的内容一致。

通过F将风格B的图像y转换为风格A的图像X‘,之后再将X’通过G后仍然能够转换回风格B,并能保证图像中的内容一致。

也就是训练好G和F就可以自由地完成风格A、B的转换了。

cyclegan,神经网络,计算机视觉,人工智能

3 损失函数

在训练中引入两个判别器:

Dy:区分真实的风格B的图像与通过G转换而来的假的风格B的图像

Dx:区分真实的风格A的图像与通过G转换而来的假的风格A的图像

损失函数主要由以下几个部分构成:

(1)Dy处的GAN损失:

cyclegan,神经网络,计算机视觉,人工智能 (2)Dx处的GAN损失:

cyclegan,神经网络,计算机视觉,人工智能

(3)循环一致性损失,即cycle解释那块逻辑

 cyclegan,神经网络,计算机视觉,人工智能

(4)Identity loss

cyclegan,神经网络,计算机视觉,人工智能

这个loss的含义是:生成器G用来生成y风格的图像,那么把y送入G应该仍然生成y,只有这样才能证明G具有生成y风格的能力。因此G(y)和y应该尽可能接近。根据论文中的解释,如果不加入该loss,那么生成器可能会自主地修改图像的色调,使得整体的颜色发生变化。

4 CycleGAN网络结构解读

GAN由生成网络(Generator)和辨别(Discriminator)网络两部分组成

Generator网络有2个,分别支持A->B和B->A的转化,其输入输出不会改变维度信息,如下图所示:

cyclegan,神经网络,计算机视觉,人工智能

Discriminator网络也有2个:

D_A:G_A(A) vs B

D_B:G_B(B) vs A

输入后会改变维度大小,输出channel从3变为1,特征为30*30。如果为真,则与30*30的1进行比较;如果为假,则与30*30的0进行比较。如下图所示:

cyclegan,神经网络,计算机视觉,人工智能

Generator和Discriminator网络结构如下:

cyclegan,神经网络,计算机视觉,人工智能

Loss组成:

由Generator的loss和Discriminator的loss两部分组成。

Generator部分的loss:

Loss_G_A = D_A(G_A(A)), #从G的角度生成的B要让D尽量判断为1

Loss_G_B = D_B(G_B(B)), #从G的角度生成的A要让D尽量判断为1

Loss_cycle_A = || G_B(G_A(A)) - A||

Loss_cycle_B = || G_A(G_B(B)) - B||

Loss_idt_A = ||G_A(B) - B||

Loss_idt_B = ||G_B(A) - A||

Loss_G = Loss_G_A + Loss_G_B + Loss_cycle_A + Loss_cycle_B + Loss_idt_A + Loss_idt_B

Discriminator部分的loss:

Loss_D = criterionGAN (netD(real), true) + criterionGAN(netD(fake), false)

从D的角度,G生成的要尽量判断为0,真实的要尽量判断为1。

 5 代码解读

(1)前向传播部分:

NetG_A就是G,完成A->B的风格转换(源域到目标域)

NetG_B就是F,完成B->A的风格转换(目标域到源域)   

 def forward(self):

        """Run forward pass; called by both functions <optimize_parameters> and <test>."""

        self.fake_B = self.netG_A(self.real_A)  # G_A(A)

        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))

        self.fake_A = self.netG_B(self.real_B)  # G_B(B)

        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

(2)更新G

在if lambda_idt > 0:这个分支内,实现的就是identity loss。

后面就是GAN损失(loss_G_A、 loss_G_B)以及循环一致性损失(loss_cycle_A、loss_cycle_B)

代码里面的判别器netD_A判断的是真实B风格和生成B风格的真假,相当于论文中的Dy。

同理netD_B判断的是真实A风格和生成A风格的真假,相当于论文中的Dx。   

 def backward_G(self):

        """Calculate the loss for generators G_A and G_B"""

        lambda_idt = self.opt.lambda_identity

        lambda_A = self.opt.lambda_A

        lambda_B = self.opt.lambda_B

        # Identity loss

        if lambda_idt > 0:

            # G_A should be identity if real_B is fed: ||G_A(B) - B||

            self.idt_A = self.netG_A(self.real_B)  #将真实的B送入netG_A(A->B风格生成器)生成的应该还是B风格

            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt

            # G_B should be identity if real_A is fed: ||G_B(A) - A||

            self.idt_B = self.netG_B(self.real_A) #将真实的A送入netG_B(B->A风格生成器)生成的应该还是A风格

            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt

        else:

            self.loss_idt_A = 0

            self.loss_idt_B = 0



        # GAN loss D_A(G_A(A))

        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)

        # GAN loss D_B(G_B(B))

        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Forward cycle loss || G_B(G_A(A)) - A||

        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A

        # Backward cycle loss || G_A(G_B(B)) - B||

        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # combined loss and calculate gradients

        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B

        self.loss_G.backward()

(3)更新D:

  

  def backward_D_A(self):

        """Calculate GAN loss for discriminator D_A"""

        fake_B = self.fake_B_pool.query(self.fake_B)

        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)



    def backward_D_B(self):

        """Calculate GAN loss for discriminator D_B"""

        fake_A = self.fake_A_pool.query(self.fake_A)

        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

(4)生成器结构

cyclegan,神经网络,计算机视觉,人工智能

一共由3个卷积层 + 5个残差块 + 3个卷积层构成。

这里没有用到池化等操作,在开始卷积层中(第二层、第三层)进行了下采样,在最后的3个卷积层中进行了上采样,这样最直接的就是减少了计算复杂度。另外还有一个好处是感受野增大,卷积下采样会增大有效区域。

5个残差块都是使用相同个数的(128)滤镜核,每个残差块中都有2个卷积层(3*3核),这里的卷积层中没有标准的0填充(padding),因为使用0填充会使生成出的图像的边界出现严重伪影。为了保证输入输出图像大小不改变,在图像初始输入部分加入了反射填充。

这里的残差网络不是使用何凯明的残差网络,卷积之后没有Relu,而是使用了Gross and Wilber的残差网络,后面这种方法验证在图像分类算法上面效果比较好。文章来源地址https://www.toymoban.com/news/detail-590974.html

到了这里,关于cycleGAN算法解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 计算机竞赛 题目:基于机器视觉opencv的手势检测 手势识别 算法 - 深度学习 卷积神经网络 opencv python

    🔥 优质竞赛项目系列,今天要分享的是 基于机器视觉opencv的手势检测 手势识别 算法 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/postgraduate 普通机器视觉手势检测的基本流程如下: 其中轮廓的提取,多边形

    2024年02月07日
    浏览(38)
  • 深度学习(32)——CycleGAN

    前几天被Ly问GAN,所以去学了学,之前只知道大概,现在稍微懂一点 一个随机向量经过生成器生成的一个图像作为fake image,然后在训练集上随机挑选一张图片real image,将两张image输入辨别器,让他判断照片是real 或者fake 注 generator生成的数据是fake,在做loss的时候要保证fake

    2024年02月14日
    浏览(20)
  • PyTorch 实现CycleGAN 风格迁移

    目录 一、前言 二、数据集 三、网络结构 四、代码      (一)net      (二)train      (三)test  五、结果      (一)loss      (二)训练可视化      (三)测试结果  六、完整代码         pix2pix对训练样本要求较高,需要成对的数据集,而这种样本的获取往往需

    2024年02月04日
    浏览(25)
  • 基于CycleGAN的山水风格画迁移

    1.1 研究背景及意义 绘画是人类重要的一种艺术形式,其中中国的山水画源远流长,具有丰富的美学内涵,沉淀着中国人的情思。游山玩水的大陆文化意识,以山为德、水为性的内在修为意识,咫尺天涯的视错觉意识,一直成为山水画演绎的中轴主线。从山水画中,我们可以

    2024年02月10日
    浏览(23)
  • 轻量级卷积神经网络MobileNets详细解读

    随着深度学习的飞速发展,计算机视觉领域内的卷积神经网络种类也层出不穷。从1998年的LeNet网络到2012引起深度学习热潮年的AlexNet网络,再到2014年的VGG网络,再到后来2015的ResNet网络,深度学习网络在图像处理上表现得越来越好。但是这些网络都在不断增加网络深度和宽度来

    2024年02月04日
    浏览(30)
  • 论文解读:在神经网络中提取知识(知识蒸馏)

    提高几乎所有机器学习算法性能的一种非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用整个模型集合进行预测是很麻烦的,并且可能在计算上过于昂贵,无法部署到大量用户,特别是如果单个模型是大型神经网络。Car

    2024年02月21日
    浏览(34)
  • CycleGAN的基本原理以及Pytorch框架实现

    目录 1.了解CycleGAN (1)什么是CycleGAN  (2)CycleGAN的应用场景   2 CycleGAN原理 (1)整个模型 (2)优化目标  (3)训练生成器和判别器 (1)训练生成器 (2)训练判别器 3.CycleGAN的网络结构  (1)生成器模型 (2)判别器模型 4.CycleGAN代码实现  5.mainWindow窗口显示转换之后风

    2024年02月07日
    浏览(22)
  • 计算机视觉-卷积神经网络

    目录 计算机视觉的发展历程 卷积神经网络 卷积(Convolution) 卷积计算 感受野(Receptive Field) 步幅(stride) 感受野(Receptive Field) 多输入通道、多输出通道和批量操作 卷积算子应用举例 计算机视觉作为一门让机器学会如何去“看”的学科,具体的说,就是让机器去识别摄

    2024年02月10日
    浏览(26)
  • 【图神经网络】GNNExplainer代码解读及其PyG实现

    接上一篇博客图神经网络的可解释性方法及GNNexplainer代码示例,我们这里简单分析GNNExplainer源码,并用PyTorch Geometric手动实现。 GNNExplainer的源码地址:https://github.com/RexYing/gnn-model-explainer (1)安装: 推荐使用python3.7以及创建虚拟环境: (2)训练一个GCN模型 其中EXPERIMENT_NAM

    2024年02月12日
    浏览(33)
  • 风格迁移CycleGAN开源项目代码运行步骤详细教程

       最近在学习Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks这篇论文,论文下载地址,想要复现一下文中的代码,过程中遇到了很多问题,因此记录下来。遇到其他问题欢迎在评论区留言,相互解答。 如果没有安装Anaconda或者MIniconda的可以先安装,并学一下

    2024年02月02日
    浏览(22)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包