WGAN-gp模型——pytorch实现

这篇具有很好参考价值的文章主要介绍了WGAN-gp模型——pytorch实现。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

论文传送门:https://arxiv.org/pdf/1704.00028.pdf

WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c,c]的区间限制,使得网络参数分布极端,参数均接近于-c或c。

WGAN-gp的目的:解决WGAN参数分布极端的问题。 

WGAN-gp的方法:在判别器D的loss中增加梯度惩罚项,代替WGAN中对判别器D的参数区间限制,同样能保证D(x)满足Lipschitz连续条件。(证明过程见论文补充材料)

WGAN-gp模型——pytorch实现

WGAN-gp模型——pytorch实现

红框部分:与WGAN不同之处,即判别器D的loss增加梯度惩罚项和优化器选择Adam

梯度惩罚项的计算实现见代码70-87行,判别器D的损失函数修改见代码156行。文章来源地址https://www.toymoban.com/news/detail-513272.html

import os
import torch
from torch.utils.data import DataLoader

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm


class Discriminator(nn.Module):  # 定义判别器(WS-divergence)
    def __init__(self, img_shape=(1, 28, 28)):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
        self.img_shape = img_shape  # 图片形状

        self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512)  # linear映射
        self.linear2 = nn.Linear(512, 256)  # linear映射
        self.linear3 = nn.Linear(256, 1)  # linear映射
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数

    def forward(self, x):  # 前传函数
        x = torch.flatten(x, 1)  # 输入图片从三维压缩至一维特征向量,(n,1,28,28)-->(n,784)
        x = self.linear1(x)  # linear映射,(n,784)-->(n,512)
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.linear2(x)  # linear映射,(n,512)-->(n,256)
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.linear3(x)  # linear映射,(n,256)-->(n,1)

        return x  # 返回近似拟合的Wasserstein距离


class Generator(nn.Module):  # 定义生成器
    def __init__(self, img_shape=(1, 28, 28), latent_dim=100):  # 初始化方法
        super(Generator, self).__init__()
        self.img_shape = img_shape  # 图片形状
        self.latent_dim = latent_dim  # 噪声z的长度

        self.linear1 = nn.Linear(self.latent_dim, 128)  # linear映射
        self.linear2 = nn.Linear(128, 256)  # linear映射
        self.bn2 = nn.BatchNorm1d(256, 0.8)  # bn操作
        self.linear3 = nn.Linear(256, 512)  # linear映射
        self.bn3 = nn.BatchNorm1d(512, 0.8)  # bn操作
        self.linear4 = nn.Linear(512, 1024)  # linear映射
        self.bn4 = nn.BatchNorm1d(1024, 0.8)  # bn操作
        self.linear5 = nn.Linear(1024, self.img_shape[0] * self.img_shape[1] * self.img_shape[2])  # linear映射
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数
        self.tanh = nn.Tanh()  # tanh激活函数,将输出压缩至(-1.1)

    def forward(self, z):  # 前传函数
        z = self.linear1(z)  # linear映射,(n,100)-->(n,128)
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear2(z)  # linear映射,(n,128)-->(n,256)
        z = self.bn2(z)  # 一维bn操作
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear3(z)  # linear映射,(n,256)-->(n,512)
        z = self.bn3(z)  # 一维bn操作
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear4(z)  # linear映射,(n,512)-->(n,1024)
        z = self.bn4(z)  # 一维bn操作
        z = self.leakyrelu(z)  # leakyrelu激活函数
        z = self.linear5(z)  # linear映射,(n,1024)-->(n,784)
        z = self.tanh(z)  # tanh激活函数
        z = z.view(-1, self.img_shape[0], self.img_shape[1], self.img_shape[2])  # 从一维特征向量扩展至三维图片,(n,784)-->(n,1,28,28)

        return z  # 返回生成的图片


def cal_gp(D, real_imgs, fake_imgs, cuda):  # 定义函数,计算梯度惩罚项gp
    r = torch.rand(size=(real_imgs.shape[0], 1, 1, 1))  # 真假样本的采样比例r,batch size个随机数,服从区间[0,1)的均匀分布
    if cuda:  # 如果使用cuda
        r = r.cuda()  # r加载到GPU
    x = (r * real_imgs + (1 - r) * fake_imgs).requires_grad_(True)  # 输入样本x,由真假样本按照比例产生,需要计算梯度
    d = D(x)  # 判别网络D对输入样本x的判别结果D(x)
    fake = torch.ones_like(d)  # 定义与d形状相同的张量,代表梯度计算时每一个元素的权重
    if cuda:  # 如果使用cuda
        fake = fake.cuda()  # fake加载到GPU
    g = torch.autograd.grad(  # 进行梯度计算
        outputs=d,  # 计算梯度的函数d,即D(x)
        inputs=x,  # 计算梯度的变量x
        grad_outputs=fake,  # 梯度计算权重
        create_graph=True,  # 创建计算图
        retain_graph=True  # 保留计算图
    )[0]  # 返回元组的第一个元素为梯度计算结果
    gp = ((g.norm(2, dim=1) - 1) ** 2).mean()  # (||grad(D(x))||2-1)^2 的均值
    return gp  # 返回梯度惩罚项gp


if __name__ == "__main__":
    # 训练参数
    total_epochs = 100  # 训练轮次
    batch_size = 64  # 批大小
    lr_D = 4e-3  # 判别网络D学习率
    lr_G = 1e-3  # 生成网络G学习率
    num_workers = 8  # 数据加载线程数
    latent_dim = 100  # 噪声z长度
    image_size = 28  # 图片尺寸
    channel = 1  # 图片通道
    a = 10  # 梯度惩罚项系数
    clip_value = 0.01  # 判别器参数限定范围
    dataset_dir = "dataset/mnist"  # 训练数据集路径
    gen_images_dir = "gen_images"  # 生成样例图片路径
    cuda = True if torch.cuda.is_available() else False  # 设置是否使用cuda
    os.makedirs(dataset_dir, exist_ok=True)  # 创建训练数据集路径
    os.makedirs(gen_images_dir, exist_ok=True)  # 创建样例图片路径
    image_shape = (channel, image_size, image_size)  # 图片形状

    # 模型
    D = Discriminator(image_shape)  # 实例化判别器
    G = Generator(image_shape, latent_dim)  # 实例化生成器
    if cuda:  # 如果使用cuda
        D = D.cuda()  # 模型加载到GPU
        G = G.cuda()  # 模型加载到GPU

    # 数据集
    transform = transforms.Compose(  # 数据预处理方法
        [transforms.Resize(image_size),  # resize
         transforms.ToTensor(),  # 转为tensor
         transforms.Normalize([0.5], [0.5])]  # 标准化
    )
    dataloader = DataLoader(  # dataloader
        dataset=datasets.MNIST(  # 数据集选取MNIST手写体数据集
            root=dataset_dir,  # 数据集存放路径
            train=True,  # 使用训练集
            download=True,  # 自动下载
            transform=transform  # 应用数据预处理方法
        ),
        batch_size=batch_size,  # 设置batch size
        num_workers=num_workers,  # 设置读取数据线程数
        shuffle=True  # 设置打乱数据
    )

    # 优化器
    optimizer_D = torch.optim.Adam(D.parameters(), lr=lr_D)  # 定义判别网络Adam优化器,传入学习率lr_D
    optimizer_G = torch.optim.Adam(G.parameters(), lr=lr_G)  # 定义生成网络Adam优化器,传入学习率lr_G

    # 训练循环
    for epoch in range(total_epochs):  # 循环epoch
        pbar = tqdm(total=len(dataloader), desc=f'Epoch {epoch + 1}/{total_epochs}', postfix=dict,
                    mininterval=0.3)  # 设置当前epoch显示进度
        LD = 0
        LG = 0
        for i, (real_imgs, _) in enumerate(dataloader):  # 循环iter
            if cuda:  # 如果使用cuda
                real_imgs = real_imgs.cuda()  # 数据加载到GPU
            bs = real_imgs.shape[0]  # batchsize

            # 开始训练判别网络D
            optimizer_D.zero_grad()  # 判别网络D清零梯度
            z = torch.randn((bs, latent_dim))  # 生成输入噪声z,服从标准正态分布,长度为latent_dim
            if cuda:  # 如果使用cuda
                z = z.cuda()  # 噪声z加载到GPU
            fake_imgs = G(z).detach()  # 噪声z输入生成网络G,得到生成图片,并阻止其反向梯度传播
            gp = cal_gp(D, real_imgs, fake_imgs, cuda)
            loss_D = -torch.mean(D(real_imgs)) + torch.mean(D(fake_imgs)) + a * gp  # 判别网络D的损失函数,相较于WGAN,增加了梯度惩罚项a*gp
            loss_D.backward()  # 反向传播,计算当前梯度
            optimizer_D.step()  # 根据梯度,更新网络参数
            LD += loss_D.item()  # 累计判别网络D的loss

            # 开始训练生成网络G
            optimizer_G.zero_grad()  # 生成网络G清零梯度
            gen_imgs = G(z)  # 噪声z输入生成网络G,得到生成图片
            loss_G = -torch.mean(D(gen_imgs))  # 生成网络G的损失函数
            loss_G.backward()  # 反向传播,计算当前梯度
            optimizer_G.step()  # 根据梯度,更新网络参数
            LG += loss_G.item()  # 累计生成网络G的loss

            pbar.set_postfix(**{'D_loss': loss_D.item(), 'G_loss': loss_G.item()})  # 显示判别网络D和生成网络G的损失
            pbar.update(1)  # 步进长度
        pbar.close()  # 关闭当前epoch显示进度
        print("total_D_loss:%.4f,total_G_loss:%.4f" % (
        LD / len(dataloader), LG / len(dataloader)))  # 显示当前epoch训练完成后,判别网络D和生成网络G的总损失
        save_image(gen_imgs.data[:25], "%s/ep%d.png" % (gen_images_dir, (epoch + 1)), nrow=5,
                   normalize=True)  # 保存生成图片样例(5x5)

到了这里,关于WGAN-gp模型——pytorch实现的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 数据生成 | MATLAB实现WGAN生成对抗网络数据生成

    生成效果 基本描述 1.WGAN生成对抗网络,数据生成,样本生成程序,MATLAB程序; 2.适用于MATLAB 2020版及以上版本; 3.基于Wasserstein生成对抗网络(Wasserstein Generative Adversarial Network,WGAN)的数据生成模型引入了梯度惩罚(Gradient Penalty)来改善训练的稳定性和生成样本的质量。W

    2024年02月12日
    浏览(36)
  • ResNet论文解读及代码实现(pytorch)

    又重新看了一遍何凯明大神的残差网络,之前懵懵懂懂的知识豁然开朗了起来。然后,虽然现在CSDN和知乎的风气不是太好,都是一些复制粘贴别人的作品来给自己的博客提高阅读量的人,但是也可以从其中汲取到很多有用的知识,我们要取其精华,弃其糟粕。 我只是大概的

    2024年02月04日
    浏览(52)
  • 人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型9-pytorch搭建一个ELMo模型,实现训练过程,本文将介绍如何使用PyTorch搭建ELMo模型,包括ELMo模型的原理、数据样例、模型训练、损失值和准确率的打印以及预测。文章将提供完整的代码实现。 ELMo模型简介 数据

    2024年02月07日
    浏览(67)
  • Transformer模型 | Python实现TransformerCPI模型(pytorch)

    效果一览 文章概述 Python实现TransformerCPI模型(tensorflow) Dependencies: python 3.6 pytorch = 1.2.0 numpy RDkit = 2019.03.3.0 pandas Gensim =3.4.0 程序设计

    2024年02月07日
    浏览(35)
  • 用Pytorch实现线性回归模型

    前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。 这节课主要是利用Pytorch实现线性模型。 学习器训练: 确定模型(函数) 定义损失函数 优化器优化(SGD) 之前用过Pytorch的Tensor进行Forward、Backward计算。 现在利用Pytorch框架来实现。 准备数据集

    2024年01月19日
    浏览(51)
  • 人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测。孪生网络是一种用于度量学习(Metric Learning)和比较学习(Comparison Learning)的深度神经网络模型。它主要用于学习将两个输入样本映射到一个

    2024年02月11日
    浏览(143)
  • 推荐系统 | 基础推荐模型 | 矩阵分解模型 | 隐语义模型 | PyTorch实现

    基础推荐模型——传送门 : 推荐系统 | 基础推荐模型 | 协同过滤 | UserCF与ItemCF的Python实现及优化 推荐系统 | 基础推荐模型 | 矩阵分解模型 | 隐语义模型 | PyTorch实现 推荐系统 | 基础推荐模型 | 逻辑回归模型 | LS-PLM | PyTorch实现 推荐系统 | 基础推荐模型 | 特征交叉 | FM | FFM |

    2023年04月09日
    浏览(52)
  • 【youcans动手学模型】SENet 模型及 PyTorch 实现

    欢迎关注『youcans动手学模型』系列 本专栏内容和资源同步到 GitHub/youcans 本文用 PyTorch 实现 SENet 网络模型,使用 CIFAR10 数据集训练模型,进行图像分类。 胡杰团队(Momenta)在 2017 年发表论文 “Squeeze and Excitation Networks”,提出一种深度学习神经网络模型,称为 SENet。该论文

    2024年02月12日
    浏览(30)
  • 实践教程|基于 pytorch 实现模型剪枝

    PyTorch剪枝方法详解,附详细代码。 一,剪枝分类 1.1,非结构化剪枝 1.2,结构化剪枝 1.3,本地与全局修剪 二,PyTorch 的剪枝 2.1,pytorch 剪枝工作原理 2.2,局部剪枝 2.3,全局非结构化剪枝 三,总结 参考资料 所谓模型剪枝,其实是一种从神经网络中移除\\\"不必要\\\"权重或偏差(

    2024年02月12日
    浏览(40)
  • 使用PyTorch实现混合专家(MoE)模型

    Mixtral 8x7B 的推出在开放 AI 领域引发了广泛关注,特别是混合专家(Mixture-of-Experts:MoEs)这一概念被大家所认知。混合专家(MoE)概念是协作智能的象征,体现了“整体大于部分之和”的说法。MoE模型汇集了各种专家模型的优势,以提供更好的预测。它是围绕一个门控网络和一

    2024年01月17日
    浏览(43)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包