GAN实现mnist生成

这篇具有很好参考价值的文章主要介绍了GAN实现mnist生成。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

GAN参考,他写的超好

# 导入包
%matplotlib inline

import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision
from torchvision import models
from torchvision import transforms

# 如果有gpu就用gpu,如果没有就用cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 导入数据集
batch_size=32

# Compose定义了一系列transform,此操作相当于将多个transform一并执行
transform = transforms.Compose([
    transforms.ToTensor(),    
    # mnist是灰度图,此处只将一个通道标准化
    transforms.Normalize(mean=(0.5), 
                         std=(0.5))
    ])
                         
# 设定数据集
mnist_data = torchvision.datasets.MNIST("./mnist_data", train=True, download=True, transform=transform)

# 加载数据集,按照上述要求,shuffle本意为洗牌,这里指打乱顺序,很形象
dataloader = torch.utils.data.DataLoader(dataset=mnist_data,
                                         batch_size=batch_size,
                                         shuffle=True)
                                         
# 在线下载MNIST时如果下载速度特别慢可以更改源码,改为本地。

定义模型

image_size = 784
hidden_size = 256

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid() # sigmoid结果为(0,1)
)

# Generator
latent_size = 64 # latent_size,相当于初始噪声的维数
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh() # 转换至(-1,1)
)

# 放到gpu上计算(如果有的话)
D = D.to(device)
G = G.to(device)

# 定义损失函数、优化器、学习率
loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


开始训练

# 先定义一个梯度清零的函数,方便后续使用
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

# 迭代次数与计时
total_step = len(dataloader)
num_epochs = 200
start = time.perf_counter() # 开始时间

# 开始训练
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader): # 当前step
        batch_size = images.size(0) # 变成一维向量
        images = images.reshape(batch_size, image_size).to(device)
        
        # 定义真假label,用作评分
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        # 对D进行训练,D的损失函数包含两部分
        # 第一部分,D对真图的判断能力
        outputs = D(images) # 将真图送入D,输出(0,1),应该是越接近1越好
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs # 真图的分数,越大越好
        
        # 第二部分,D对假图的判断能力
        z = torch.randn(batch_size, latent_size).to(device) # 开始生成一组fake images即32*784的噪声经过G的假图
        fake_images = G(z)
        outputs = D(fake_images.detach()) # 将假图片给D,detach表示不作用于求grad
        d_loss_fake = loss_fn(outputs, fake_labels)
        fake_score = outputs # 假图的分数,越小越好
        
        # 开始优化discriminator
        d_loss = d_loss_real + d_loss_fake # 总的损失就是以上两部分相加,越小越好
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # 对G进行训练,G的损失函数包含一部分
        # 可以用前面的z,也可以新生成,因为模型没有改变,事实上是一样的
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs, real_labels) # G想骗过D,故让其越接近1越好
        
        # 开始优化generator
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        # 优化完成,下面进行一些反馈,展示学习进度
        if i % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}"
                  .format(epoch, num_epochs, i, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))

# 训练结束,跳出循环,检验成果
end = time.perf_counter() # 结束时间
total = end - start
minutes = total//60
seconds = total - minutes*60
print("利用GPU总用时:{:.2f}分钟{:.2f}秒".format(minutes, seconds))


在上面的代码中,优化 Discriminator(D)和 Generator(G)是分开进行的。当优化 Discriminator 时,只有 Discriminator 的参数会被更新。这是通过执行 d_loss.backward() 和 d_optimizer.step() 来实现的。在这一步,Generator 的参数不会被更新。

同样地,当优化 Generator 时,只有 Generator 的参数会被更新。这是通过执行 g_loss.backward() 和 g_optimizer.step() 来实现的。在这一步,Discriminator 的参数不会被更新。

特别要注意的是,当计算 Discriminator 的损失时,用到了 fake_images.detach()。这是为了确保在更新 Discriminator 的时候不会影响到 Generator 的参数。

总的来说,每一步优化都只影响到一个模型的参数,不会同时更新两者。这样做是为了实现两个模型的对抗训练,其中 Discriminator 尽量变得擅长于区分真实图片和生成图片,而 Generator 尽量变得擅长于生成越来越逼真的图片。

在计算生成器的损失时,代码中没有明确地使用类似 .detach() 的操作来断开与判别器(Discriminator)参数的关联。这是因为在这个特定的训练步骤中,目的就是要更新生成器(Generator)的参数以使得由生成器生成的假图像能更好地欺骗判别器。

当执行 g_loss.backward() 和 g_optimizer.step() 时,只有生成器的参数会被更新。这是因为优化器 g_optimizer 只管理生成器的参数。因此,即使损失 g_loss 是基于判别器的输出计算的,判别器的参数在这一步也不会被更新。

简而言之,代码中通过使用不同的优化器(d_optimizer 和 g_optimizer)来分别管理判别器和生成器的参数,从而确保在各自的更新步骤中只更新一个模型的参数。这样,在更新生成器参数的时候,判别器的参数不会受到影响。文章来源地址https://www.toymoban.com/news/detail-732719.html

到了这里,关于GAN实现mnist生成的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【ChatGPT】人工智能生成内容的综合调查(AIGC):从 GAN 到 ChatGPT 的生成人工智能历史

      :AIGC,Artificial Intelligence Generated Content   【禅与计算机程序设计艺术:导读】 2022年,可以说是生成式AI的元年。近日,俞士纶团队发表了一篇关于AIGC全面调查,介绍了从GAN到ChatGPT的发展史。 论文地址: https://arxiv.org/pdf/2303.04226.pdf 刚刚过去的2022年,无疑是生成式

    2023年04月18日
    浏览(88)
  • 【人工智能概论】 构建神经网络——以用InceptionNet解决MNIST任务为例

    两条原则,四个步骤。 从宏观到微观 把握数据形状 准备数据 构建模型 确定优化策略 完善训练与测试代码 InceptionNet的设计思路是通过增加网络宽度来获得更好的模型性能。 其核心在于基本单元Inception结构块,如下图: 通过纵向堆叠Inception块构建完整网络。 MNIST是入门级的

    2023年04月20日
    浏览(52)
  • 人工智能|深度学习——基于对抗网络的室内定位系统

    基于CSI的工业互联网深度学习定位.zip资源-CSDN文库 室内定位技术是工业互联网相关技术的关键一环。 该技术旨在解决于室外定位且取得良好效果的GPS由于建筑物阻挡无法应用于室内的问题 。实现室内定位技术,能够在真实工业场景下实时追踪和调配人员并做到对自动化生产

    2024年02月20日
    浏览(45)
  • 【人工智能】实验五 采用卷积神经网络分类MNIST数据集与基础知识

    熟悉和掌握 卷积神经网络的定义,了解网络中卷积层、池化层等各层的特点,并利用卷积神经网络对MNIST数据集进行分类。 编写卷积神经网络分类软件,编程语言不限,如Python等,以MNIST数据集为数据,实现对MNIST数据集分类操作,其中MNIST数据集共10类,分别为手写0—9。

    2024年02月04日
    浏览(64)
  • 89 | Python人工智能篇 —— 深度学习算法 Keras 实现 MNIST分类

    本教程将带您深入探索Keras,一个开源的深度学习框架,用于构建人工神经网络模型。我们将一步步引导您掌握Keras的核心概念和基本用法,学习如何构建和训练深度学习模型,以及如何将其应用于实际问题中。

    2024年02月13日
    浏览(59)
  • AI技术在网络攻击中的滥用与对抗 - 人工智能恶意攻击

    随着人工智能技术的迅猛发展,我们享受到了许多便利,但同时也面临着新的安全威胁。本文将探讨人工智能技术在网络攻击中的滥用,并提出一些防御机制。 人工智能技术的先进性和灵活性使其成为恶意攻击者的有力工具。以下是一些常见的人工智能滥用案例: 欺骗和钓

    2024年02月12日
    浏览(43)
  • 数据生成 | MATLAB实现GAN生成对抗网络结合SVM支持向量机的数据生成

    生成效果 基本描述 数据生成 | MATLAB实现1-DGAN生成对抗网络的数据生成 1.Matlab实现1-DGAN生成对抗网络数据生成,运行环境Matlab2021b及以上; 2.基于生成数据训练SVM分类模型; 3.计算生成数据在SVM模型上的分类准确率,同时测试原始数据在生成数据训练SVM模型上的分类准确率;

    2024年02月10日
    浏览(77)
  • 【计算机视觉|生成对抗】生成对抗网络(GAN)

    本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题: Generative Adversarial Nets 链接:Generative Adversarial Nets (nips.cc) 我们提出了一个通过**对抗(adversarial)**过程估计生成模型的新框架,在其中我们同时训练两个模型: 一个生成模型G,捕获数据分布 一个判别模型

    2024年02月12日
    浏览(61)
  • GAN实现mnist生成

    GAN参考,他写的超好 在上面的代码中,优化 Discriminator(D)和 Generator(G)是分开进行的。当优化 Discriminator 时,只有 Discriminator 的参数会被更新。这是通过执行 d_loss.backward() 和 d_optimizer.step() 来实现的。在这一步,Generator 的参数不会被更新。 同样地,当优化 Generator 时,只

    2024年02月07日
    浏览(51)
  • GAN-对抗生成网络

    generator:

    2024年02月09日
    浏览(40)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包