使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别

这篇具有很好参考价值的文章主要介绍了使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。


前面的博客讲了如何基于PyTorch使用神经网络识别手写数字
使用PyTorch构建神经网络
下面在此基础上构建一个生成对抗网络,生成对抗网络可以模拟出新的手写数字数据集。

1 生成对抗网络基本概念

生成对抗网络(GAN)是一种用于生成新的照片,文本或音频的模型。它由两部分组成:生成器和判别器。生成器的作用是生成新的样本,而判别器的作用是识别这些样本是真实的还是假的。两个模型相互博弈,通过不断调整自己的参数来提高自己的能力。生成器希望判别器错误地认为其生成的样本是真实的,而判别器希望能正确地识别生成器生成的样本是假的。最终,生成器会学到如何生成逼真的样本,而判别器会学到如何区分真假样本。

一个非常形象的例子,目前的数据集是人民币,生成器是造假币的,判别器是银行。刚开始造假币的只是粗略模仿人民币的印制,银行由于没有经验也分辨不好真钱还是假币。但随着时间推移,银行对鉴别假币越来越有经验,造假币的水平也变得越来越逼真,二者不断进步,这就是GAN网络。

生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python

2 生成对抗网络建模

2.1 建立MnistDataset类

对于非GAN独有的建模部分,讲解不会细化到每一行代码,如有阅读困难可参考本博客使用PyTorch构建神经网络部分的文章。但基本上具备Python的基础知识即可顺利阅读本篇文章。
与神经网络建模相同,我们首先构建一个MnistDataset类,这个类具备getitem功能,可以返回每条数据相应的数据标签label,image_values, target。这些变量的含义分别是:

  1. label:获得了指定数据的第一个数值,也就是这个数据的标签;
  2. target:制作了一个维度为10的张量,标签对应的项是1,其他是0。比如,某个手写数据的标签是2,则这个张量是[0, 0, 1, 0, 0, 0, 0, 0, 0, 0]。
  3. image_values:像素输入的值是0-255,这里对像素数据做了标准化,是值位于0-1之间。

同样,我们定义了一个绘制的功能,这个功能在建模中并没有实际作用,但是会很方便我们快速查看数据是否成功导入。MnistDataset类的全部代码如下:

class MnistDataset:
    def __init__(self, csv_file):
        self.data = pandas.read_csv(csv_file)
        pass

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # 预期输出的张量制作
        label = self.data.iloc[index, 0]
        target = torch.zeros(10)
        target[label] = 1.0
        # 图像数据标准化
        image_values = torch.FloatTensor(self.data.iloc[index, 1:].values) / 255.0
        return label, image_values, target

    # 制图
    def plot_image(self, index):
        arr = self.data.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label=" + str(self.data.iloc[index, 0]))
        plt.imshow(arr, interpolation='none', cmap='Blues')
        plt.show()

2.2 建立鉴别器

此处的鉴别器与基于PyTorch建立神经网络一文中的鉴别器基本相同。主要不同的是网络的输出层:本鉴别器的的网格为784-200-1。网格的输出层只有一个节点,这是因为鉴别器只需要判断这是真实数据还是虚假数据即可。真实数据为1,虚假数据为0。
鉴别器的主要函数包括:

# 鉴别器类
class Discriminator(nn.Module):

    def __init__(self):
        # 初始化父类
        super().__init__()

        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )


        # 创造损失函数
        self.loss_function = nn.MSELoss()

        # 创造优化器
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # 创造进程计数器
        self.counter = 0
        self.progress = []

对类的初始化中:继承父类nn.Module的初始化属性;并建立784-200-1的神经网络,神经网络的激活函数使用最经典的Sigmoid函数;建立损失函数与优化器,损失函数选择MSE方法(均方误差)。

    def forward(self, inputs):
        # 执行模型
        return self.model(inputs)

简单的执行功能,能够基于input输出预测结果,即0或1。

    def train(self, inputs, targets):
        # 计算输出
        outputs = self.forward(inputs)

        # 计算损失
        loss = self.loss_function(outputs, targets)

        # 赋值进程计数器
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 10000 == 0:
            print("counter = ", self.counter)

        # 计算损失梯度,优化权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

训练模块,可以实现基于模型实际输出与与其输出,不断更新网络的权重。并每隔10次训练计算此时模型的损失,每隔10000次训练打印一次训练次数,方便掌握训练进度。

    # 绘制损失与训练过程的关系
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))

对前面每10条保存一次的模型损失函数结果进行绘图。

2.3 测试鉴别器

此处我们还没有编写生成器,但是可以创建一个随机数据集,看看鉴别器是否可以分辨出真实的mnist数据和随机数据。
首先建立一个用于生成随机数据的生成器,size是生成数据的特征数。

def generate_random(size):
    random_data = torch.rand(size)
    return random_data

接下来我们用真是数据与随机数据训练模型

for label, image_data_tensor, target_tensor in mnist_dataset:
    # 真实数据
    D.train(image_data_tensor, torch.FloatTensor([1.0]))
    # 随机数据
    D.train(generate_random(784), torch.FloatTensor([0.0]))

其中真是数据我们希望输出节点的数据输出是1,而随机数据我们希望的输出是0。
在训练完成后,可以使用我们在鉴别器类中定义的绘图功能,查看模型损失的变化情况。同时,也可以再传入4组随机真假数据,来更清晰的查看此时模型的训练情况。

for i in range(4):
  image_data_tensor = mnist_dataset[random.randint(0,60000)][1]
  print( D.forward( image_data_tensor ).item() )
  pass

for i in range(4):
  print( D.forward( generate_random(784) ).item() )
  pass

基于这个运行结果也可以判断出,模型是可以有效的区分真实数据与随机数据的。
生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python

2.4 Mnist生成器制作

生成器与判别器都是神经网络模型,所以代码基本相同,这里主要讲一下不同的地方。与判别器相比,Mnist生成器应该与判别器的网格结构刚好相反。因为判别器是输入图像输出判别结果,而生成器应该是输入判别结果,输出图像。所以网络的结构可以是1-200-784。事实上,此处我们只要保证输出的格式是784个数据即可,为了让输出的数据更加多元,我们也可以增加输入层的节点数量。这里节点数量使用1,10,甚至是100都是可以的。此处我们以100个输入节点为例。

class Generator(nn.Module):

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(100, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 784),
            nn.Sigmoid()
        )

除此之外,生成器的训练过程也稍有不同。在使用生成器生成数据后,我们需要将这个数据传入判别器,并使用判别器返回的损失作为这个生成器的损失。在Python中,在一个类调用类一个类的功能是完全可以的因此这一步骤变得简单了很多。

class Generator(nn.Module):
    def train(self, D, inputs, targets):
        # 生成器生成数据
        g_output = self.forward(inputs)

        # 将生成的数据传入判别器
        d_output = D.forward(g_output)

        loss = D.loss_function(d_output, targets)

        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

除了以上两项,这个生成器都与鉴别器完全相同,大家按此更改或者直接在文末下载完整版代码均可。
在训练GAN之前,可以检查一下生成器的输出是否正确。方法还是让生成器生成一个数据,然后使用plt包绘制出来

G = Generator()
output = G.forward(generate_random(100))
img = output.detach().numpy().reshape(28,28)
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()

生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python
现在,我们的模型中就具备了生成对抗网络的三要素:真实数据、生成器与对抗器。

3 模型的训练

对于生成器,其输入是由我们使用随机数据生成器来产生的。之前我们使用torch.rand进行随机数据的生成,这次可以尝试使用torch.randn。两者的区别是:randn是从标准正态分布中返回一个或多个样本值。

# 生成器使用的随即输入
def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data

同2.3的过程一样,在训练过程中,我们将真是数据与生成器产出的数据交替传入鉴别器,只是此处增加了对生成器的训练。

for label, image_data_tensor, target_tensor in mnist_dataset:
    # 使用真实数据训练判别器
    D.train(image_data_tensor, torch.FloatTensor([1.0]))

    # 使用生成器数据训练判别器
    # 使用 detach() 截断梯度计算
    D.train(G.forward(generate_random_seed(100)).detach(), torch.FloatTensor([0.0]))

    # 训练生成器
    G.train(D, generate_random_seed(100), torch.FloatTensor([1.0]))

值得注意的是,在这里使用生成器数据训练判别器时,我们使用detach进行了截断。这个作用是在计算梯度时,对下图红叉所示地方进行切断,使梯度计算到这里就截止了,也就是此次计算只对生成器有效。这一操作的功能是降低模型的计算量
生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python
同样此处也可以引入time模块对训练进行计时。

4 模型表现的判断

前面在定义类时,我们已经内置好了绘制损失随训练变化的功能,这里直接调用即可。

D.plot_progress()
G.plot_progress()

生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python
鉴别器的损失基本看不到明显的变化,这是因为尽管鉴别器的能力不断提升,生成器的能力却也在不断提升。
生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python

生成器稍有不同,在前期出现了下降的趋势,在一定程度上骗过了鉴别器,但后期随着鉴别器能力的提升,生成器的随时也趋于稳定。

我们也可以依据生成器输出的图像,来更直观的判断生成器的表现。

f, axarr = plt.subplots(2,3, figsize=(16,8))
for i in range(2):
    for j in range(3):
        output = G.forward(generate_random_seed(100))
        img = output.detach().numpy().reshape(28,28)
        axarr[i,j].imshow(img, interpolation='none', cmap='Blues')
plt.show()

生成对抗网络模型代码pytorch,深度学习,pytorch,生成对抗网络,python
我们使用plt建立了一个2行3列的画布,并向生成器传入了随机参数,可以看到生成器的输出已经和手写图像很像了。

以上内容的全部代码,可以直接打包下载文章来源地址https://www.toymoban.com/news/detail-781216.html

到了这里,关于使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)01 手写字体识别的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 学习笔记:Pytorch利用MNIST数据集训练生成对抗网络(GAN)

    2023.8.27        在进行深度学习的进阶的时候,我发了生成对抗网络是一个很神奇的东西,为什么它可以“将一堆随机噪声经过生成器变成一张图片”,特此记录一下学习心得。         2014年,还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》(网址

    2024年02月10日
    浏览(42)
  • pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)

    这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码。数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G。生成器与判别器模型均采用简单的卷积结构,代码参考了pytorch官网。 建议对pytorch和神经网络原理还不熟悉的同学,可

    2024年02月15日
    浏览(43)
  • PyTorch深度学习实战(31)——生成对抗网络(Generative Adversarial Network, GAN)

    生成对抗网络 ( Generative Adversarial Networks , GAN ) 是一种由两个相互竞争的神经网络组成的深度学习模型,它由一个生成网络和一个判别网络组成,通过彼此之间的博弈来提高生成网络的性能。生成对抗网络使用神经网络生成与原始图像集非常相似的新图像,它在图像生成中应用

    2024年01月22日
    浏览(50)
  • 适合小白学习的GAN(生成对抗网络)算法超详细解读

    “GANs are \\\'the coolest idea in deep learning in the last 20 years.\\\' ” --Yann LeCunn, Facebook’s AI chief   今天我们就来认识一下这个传说中被誉为过去20年来深度学习中最酷的想法——GAN。  GAN之父的主页: http://www.iangoodfellow.com/  GAN论文地址: https://arxiv.org/pdf/1406.2661.pdf 目录 前言  📢一、

    2024年02月02日
    浏览(47)
  • 人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

    大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将

    2024年02月08日
    浏览(73)
  • 【计算机视觉|生成对抗】生成对抗网络(GAN)

    本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题: Generative Adversarial Nets 链接:Generative Adversarial Nets (nips.cc) 我们提出了一个通过**对抗(adversarial)**过程估计生成模型的新框架,在其中我们同时训练两个模型: 一个生成模型G,捕获数据分布 一个判别模型

    2024年02月12日
    浏览(61)
  • 生成对抗网络 (GAN)

    生成对抗网络(Generative Adversarial Networks,GAN)是由Ian Goodfellow等人在2014年提出的一种深度学习模型。GAN由两部分组成:一个生成器(Generator)和一个判别器(Discriminator),它们通过对抗过程来训练,从而能够生成非常逼真的数据。 生成器(Generator) 生成器的任务是创建尽可

    2024年03月10日
    浏览(65)
  • GAN(生成对抗网络)

    简介:GAN生成对抗网络本质上是一种思想,其依靠神经网络能够拟合任意函数的能力,设计了一种架构来实现数据的生成。 原理:GAN的原理就是最小化生成器Generator的损失,但是在最小化损失的过程中加入了一个约束,这个约束就是使Generator生成的数据满足我们指定数据的分

    2024年02月11日
    浏览(48)
  • 生成对抗网络----GAN

    ` GAN (Generative Adversarial Network) : 通过两个神经网络,即生成器(Generator)和判别器(Discriminator),相互竞争来学习数据分布。 { 生成器 ( G e n e r a t o r ) : 负责从随机噪声中学习生成与真实数据相似的数据。 判别器 ( D i s c r i m i n a t o r ) : 尝试区分生成的数据和真实数据。

    2024年02月20日
    浏览(48)
  • GAN-对抗生成网络

    generator:

    2024年02月09日
    浏览(40)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包