WGAN基本原理及Pytorch实现WGAN

这篇具有很好参考价值的文章主要介绍了WGAN基本原理及Pytorch实现WGAN。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

目录

1.WGAN产生背景

(1)超参数敏感

(2)模型崩塌

2.WGAN主要解决的问题

3.不同距离的度量方式

(1)方式一

(2)方式二

(3)方式三

(4)方式四

4.WGAN原理

(1)p和q分布下的距离计算 

(2)EM距离转换优化目标推导

(3)判别器和生成器的优化目标

5.WGAN训练算法 

6.WGAN网络结构

7.数据集下载

8.WGAN代码实现 

9.mainWindow窗口显示生成器生成的图片

10.模型下载 


GAN原理及Pytorch框架实现GAN(比较容易理解)

Pytorch框架实现DCGAN(比较容易理解)

CycleGAN的基本原理以及Pytorch框架实现

1.WGAN产生背景

        之所以会产生WGAN,主要是因为GAN网络模型训练困难的问题,其中主要体现在GAN模型对超参数比较敏感,需要精心挑选才能使模型训练起来,并且也会出现模式崩塌的现象。

(1)超参数敏感

        超参数敏感是指网络的结构设定,学习率,初始化状态等超参数对网络的训练过程影响比较大,微量的超参数调整将可能导致网络的训练结果截然不同

wgan,pytorch,python,计算机视觉,pytorch

左图:表示使用WGAN算法训练的结果;

右图:表示标准的GAN在不使用Batch Normalization层导致网络训练不稳定,无法收敛,生成的样本与真实样本之间差距很大。

        为了更好的训练GAN网络,DCGAN论文的作者提出了不使用Pooling层,多使用Batch Normalization层,不使用全连接层,生成网络中激活函数应使用ReLU,最后一层使用tanh激活函数,判别网络激活函数应使用LeakReLU等一系列经验性的训练技巧

        但是上面的技巧仅仅能在一定程度上避免出现训练不稳定的现象,并没有从理论上解释为什么会出现训练困难以及如何解决训练不稳定的问题。

(2)模型崩塌

        模型崩塌(Mode Collapse)是指模型生成的样本单一,多样性很差的现象。

        由于判别器只能鉴别单个样本是否为真实样本分布,并没有对多样性进行显式约束,导致生成模型可能倾向于生成真实分布的部分区间中的少量高质量样本,以此来在判别器中获得较高的概率值,而不会学习到全部的真实分布。

        模式崩塌在GAN的训练过程中比较常见。在训练过程中,通过可视化生成网络的样本,可以看到,生成的图片种类非常单一,生成网络总是倾向于生成某一种单一风格的样本图像。

2.WGAN主要解决的问题

  • 引入了一种新的分布距离度量方法:Wasserstein距离,也称为(Earth-Mover Distance)简称EM距离,表示从一个分布变换到另一个分布的最小代价。
  • 定义了一种称为Wasserstein GAN的GAN形式,该形式使EM距离的合理有效近似最小化,并且本文从理论上证明了相应的优化问题是合理的。
  • WGAN解决了GANs的主要训练问题。特别是,训练WGAN不需要维护在鉴别器和生成器的训练中保持谨慎的平衡,并且也不需要对网络架构进行仔细的设计。模式在GANs中典型的下降现象也显著减少。WGAN最引人注目的实际好处之一是能够通过训练鉴别器进行运算来连续地估计EM距离。绘制这些学习曲线不仅对调试和超参数搜索,但也与观察到的样品质量。

3.不同距离的度量方式

提示:下面的一些公式可能看起来很枯燥无味,但是如果读者可以坚持读完,将是不小的收获,而且下面给出的公式还是只是论文中推导公式的冰山一角。

(1)方式一

wgan,pytorch,python,计算机视觉,pytorch

(2)方式二

wgan,pytorch,python,计算机视觉,pytorch 

(3)方式三

wgan,pytorch,python,计算机视觉,pytorch 

(4)方式四

wgan,pytorch,python,计算机视觉,pytorch

 

4.WGAN原理

(1)p和q分布下的距离计算 

        导致GAN训练不稳定的原因是因为JS散度在不重叠的分布p和q上的梯度曲面是恒定为0,的。当分布p和q不重叠时,JS散度始终为0,从而导致此时GAN的训练梯度出现梯度弥散现象(或者梯度消失),参数长时间得不到更新,网络无法收敛。

wgan,pytorch,python,计算机视觉,pytorch 

        可以看到上面结果给出,当两个分布完全不重叠时,无论分布之间的距离远近,JS散度为恒定值log2,此时JS散度将无法产生有效的梯度信息;当两个分布出现重叠时,JS散度才会平滑变动,产生有效梯度信息;当完全重叠之后,JS散度最小值为0.

wgan,pytorch,python,计算机视觉,pytorch

wgan,pytorch,python,计算机视觉,pytorch

wgan,pytorch,python,计算机视觉,pytorch

        学习区分两个高斯时的最佳判别器(Discriminator)和critic。正如本文所看到的,极小极大GAN的鉴别器饱和并导致梯度消失。本文的WGANcritic在空间的所有部分都提供了非常平滑的渐变。

(2)EM距离转换优化目标推导

 wgan,pytorch,python,计算机视觉,pytorch

(3)判别器和生成器的优化目标

 wgan,pytorch,python,计算机视觉,pytorch

wgan,pytorch,python,计算机视觉,pytorch 

5.WGAN训练算法 

wgan,pytorch,python,计算机视觉,pytorch

        wgan,pytorch,python,计算机视觉,pytorch

具体实现代码如下:

for epoch in range(NUM_EPOCHS):
    for batch_idx,(data,_) in enumerate(dataLoader):
        data = data.to(device)
        cur_batch_size = data.shape[0]

        #Train: Critic : max[critic(real)] - E[critic(fake)]
        loss_critic = 0
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(size = (cur_batch_size,Z_DIM,1,1),device=device)
            fake_img = gen(noise)
            #使用reshape主要是将最后的维度从[1,1,1,1]=>[1]
            critic_real = critic(data).reshape(-1)
            critic_fake = critic(fake_img).reshape(-1)

            loss_critic = (torch.mean(critic_real)- torch.mean(critic_fake))
            opt_critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

            #clip critic weight between -0.01 , 0.01
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP,WEIGHT_CLIP)

        #将维度从[1,1,1,1]=>[1]
        gen_fake = critic(fake_img).reshape(-1)
        #max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        loss_gen = -torch.mean(gen_fake)
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

6.WGAN网络结构

Pytorch框架实现DCGAN(比较容易理解)

 

7.数据集下载

链接:https://pan.baidu.com/s/1i_VU3aQpLkCx4Z5fhDVKHA 
提取码:79y3

 

8.WGAN代码实现 

提示:代码放在了Github上,本文的代码是参考下面这位博主写的,但是自己其中只是做了一下修改,并且其中加了一个mainWindows界面代码,方便后面训练的模型进行图像风格的转换。

参考博主的代码:https://b23.tv/QUc0CNb

本文的代码下载:https://github.com/KeepTryingTo/Pytorch-GAN
wgan,pytorch,python,计算机视觉,pytorch

 

9.mainWindow窗口显示生成器生成的图片

提示:这里编写了一个显示生成器显示图片的程序(mainWindow.py),加载之前训练之后保存的生成器模型,之后可使用该模型进行随机生成图片,如下:

(1)运行mainWindow.py 初始界面如下

wgan,pytorch,python,计算机视觉,pytorch

 点击随机生成图片:

wgan,pytorch,python,计算机视觉,pytorch

 

wgan,pytorch,python,计算机视觉,pytorch

 

 

10.模型下载 

 链接:https://pan.baidu.com/s/1dBbz6yyaRHMHl6Dl5Q24Dg 
提取码:6t7u

参考文章:

参考博主的代码:https://b23.tv/QUc0CNb

《TensorFlow深度学习》文章来源地址https://www.toymoban.com/news/detail-689220.html

到了这里,关于WGAN基本原理及Pytorch实现WGAN的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 单张图像3D重建:原理与PyTorch实现

    近年来,深度学习(DL)在解决图像分类、目标检测、语义分割等 2D 图像任务方面表现出了出色的能力。DL 也不例外,在将其应用于 3D 图形问题方面也取得了巨大进展。 在这篇文章中,我们将探讨最近将深度学习扩展到单图像 3D 重建任务的尝试,这是 3D 计算机图形领域最重

    2024年02月04日
    浏览(23)
  • 模型的权值平均的原理和Pytorch的实现

    模型权值平均是一种用于改善深度神经网络泛化性能的技术。通过对训练过程中不同时间步的模型权值进行平均,可以得到更宽的极值点(optima)并提高模型的泛化能力。 在PyTorch中,官方提供了实现模型权值平均的方法。 这里我们首先介绍指数移动平均(EMA)方法,它使用

    2024年01月20日
    浏览(25)
  • Actor-Critic(A2C)算法 原理讲解+pytorch程序实现

    强化学习在人工智能领域中具有广泛的应用,它可以通过与环境互动来学习如何做出最佳决策。本文将介绍一种常用的强化学习算法:Actor-Critic并且附上基于pytorch实现的代码。 Actor-Critic算法是一种基于策略梯度(Policy Gradient)和价值函数(Value Function)的强化学习方法,通常

    2024年02月11日
    浏览(34)
  • 自监督去噪: self2self 原理及实现(Pytorch)

    文章地址:https://ieeexplore.ieee.org/document/9157420 原始代码:https://github.com/scut-mingqinchen/self2self 本文参考代码: https://github.com/JinYize/self2self_pytorch 本文参考博客: https://zhuanlan.zhihu.com/p/361472663 website:https://csyhquan.github.io/ 1. 原理简介 噪声图片 y 可以表示为 干净图片 x 和噪声 n的叠

    2024年02月15日
    浏览(27)
  • 自监督去噪:Noise2Noise原理及实现(Pytorch)

    文章地址:https://arxiv.org/abs/1803.04189 ICML github 代码: https://github.com/NVlabs/noise2noise 本文整理和参考代码: https://github.com/shivamsaboo17/Deep-Restore-PyTorch 文章核心句子: ‘learn to turn bad images into good images by only looking at bad images, and do this just as well, sometimes even better.’ 1. 理论背景 如果有

    2024年02月14日
    浏览(22)
  • 卷积神经网络CNN原理+代码(pytorch实现MNIST集手写数字分类任务)

    前言 若将图像数据输入全连接层,可能会导致丧失一些位置信息 卷积神经网络将图像按照原有的空间结构保存,不会丧失位置信息。 卷积运算: 1.以单通道为例: 将将input中选中的部分与kernel进行数乘 : 以上图为例对应元素相乘结果为211,并将结果填入output矩阵的左上角

    2024年02月04日
    浏览(46)
  • 自监督去噪:Noise2Self原理分析及实现 (Pytorch)

    文章地址 :https://arxiv.org/abs/1901.11365 代码地址 : https://github.com/czbiohub-sf/noise2self 要点   Noise2Self方法不需要信号先验信息、噪声估计信息和干净的训练数据。唯一的 假设 就是噪声在测量的不同维度上表现出的统计独立性,而真实信号表现出一定的相关性。Noiser2Self根据J-in

    2024年02月14日
    浏览(24)
  • Pytorch+Python实现人体关键点检测

    用Python+Pytorch工程代码对人体进行关键点检测和骨架提取,并实现可视化。 物体检测为许多视觉任务提供动力,如实例分割、姿态估计、跟踪和动作识别。它在监控、自动驾驶和视觉答疑中有下游应用。当前的对象检测器通过紧密包围对象的轴向包围框来表示每个对象。然后

    2024年02月09日
    浏览(34)
  • PyTorch深度学习实战(5)——计算机视觉

    计算机视觉是指通过计算机系统对图像和视频进行处理和分析,利用计算机算法和方法,使计算机能够模拟和理解人类的视觉系统。通过计算机视觉技术,计算机可以从图像和视频中提取有用的信息,实现对环境的感知和理解,从而帮助人们解决各种问题和提高效率。本节中

    2024年02月15日
    浏览(34)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包