GAN原理 & 代码解读

这篇具有很好参考价值的文章主要介绍了GAN原理 & 代码解读。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

模型架构

GAN原理 & 代码解读,机器学习&深度学习笔记,生成对抗网络,人工智能,神经网络
GAN原理 & 代码解读,机器学习&深度学习笔记,生成对抗网络,人工智能,神经网络

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch

# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)
    transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transform,
                               download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''
    输入:正态分布随机数噪声(长度为100)
    输出:生成的图片,(1,28,28)
    中间过程:
        linear1: 100 -> 256
        linear2: 256 -> 512
        linear3: 512 -> 28*28
        reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数
        self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),
                                   nn.Linear(256,512),nn.ReLU(),
                                    # 最后一层用tanh激活,将数据压缩到-1到1
                                   nn.Linear(512,28*28),nn.Tanh())
    def forward(self,x):
        img = self.model(x)
        img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)
        return img

定义判别器

'''
    判别器
    输入:(1,28,28)的图片
    输出:二分类的概率值 用sigmoid压缩到0-1之间
    内容:
    判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28,512),nn.LeakyReLU(),
            nn.Linear(512,256),nn.LeakyReLU(),
            nn.Linear(256,1),nn.Sigmoid(),
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.model(x)
        return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):
    prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpy
    prediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

显示图片的函数

def img_plot(img):
    img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

定义训练函数


def train(num_epoch,test_input):
    D_loss = []
    G_loss = []
    # 训练循环
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader) # 返回批次数
        for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)
            # print(f'step={step},img.shape={img.shape}')
            # img_plot(img)
            img = img.to(device)
            size = img.size(0) # 得到一个批次的图片
            random_noise = torch.randn(size,100,device=device) # 生成器的输入

            '''一. 训练判别器'''
            '''用真实图片训练判别器'''
            dis_optim.zero_grad()
            real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果
            # 判别器在真实图像上的损失
            d_real_loss = bce_loss(real_output,
                                   # torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。
                                   torch.ones_like(real_output))
            d_real_loss.backward()

            '''用生成的图片训练判别器'''
            gen_img = gen(random_noise)
            # 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensor
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output,
                                   torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss+d_fake_loss
            dis_optim.step() # 对参数进行优化

            '''二.训练生成器'''
            gen_optim.zero_grad()
            # 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别
            fake_output = dis(gen_img)
            # 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。
            # 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。
            g_loss = bce_loss(fake_output,
                              torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            # 计算一个epoch的损失
            with torch.no_grad(): #  禁止梯度计算和参数更新
                d_epoch_loss +=d_loss
                g_epoch_loss +=g_loss
        # 计算整体loss每个epoch的平均Loss
        with torch.no_grad(): #  禁止梯度计算和参数更新
            d_epoch_loss /= count
            g_epoch_loss /= count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)
            print('Epoch:', epoch+1)
            print(f'd_epoch_loss={d_epoch_loss}')
            print(f'g_epoch_loss={g_epoch_loss}')
            # 将16个长度为100的噪音输入到生成器并画图
            gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()

'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')

'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

结果可视化

GAN原理 & 代码解读,机器学习&amp;深度学习笔记,生成对抗网络,人工智能,神经网络

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

GAN原理 & 代码解读,机器学习&amp;深度学习笔记,生成对抗网络,人工智能,神经网络
GAN的生成是随机的,不同的噪声,生成不同的数字文章来源地址https://www.toymoban.com/news/detail-671195.html

到了这里,关于GAN原理 & 代码解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 深度学习体系结构——CNN, RNN, GAN, Transformers, Encoder-Decoder Architectures算法原理与应用

    卷积神经网络(CNN)是一种特别适用于处理具有网格结构的数据,如图像和视频的人工神经网络。可以将其视作一个由多层过滤器构成的系统,这些过滤器能够处理图像并从中提取出有助于进行预测的有意义特征。 设想你手头有一张手写数字的照片,你希望计算机能够识别出

    2024年04月28日
    浏览(50)
  • 计算机毕业设计--基于深度学习技术(Transformer、GAN)的破损图像修复算法(含有Github代码)

    本篇文章是针对破损照片的修复。如果你想对老照片做一些色彩增强,清晰化,划痕修复,划痕检测,请参考我的另一篇CSDN作品 老照片(灰白照片)划痕修复+清晰化+色彩增强的深度学学习算法设计与实现 Abstract 在图像获取和传输过程中,往往 伴随着各种形式的损坏 ,降低

    2024年04月23日
    浏览(72)
  • 大数据机器学习深度解读ROC曲线:技术解析与实战应用

    机器学习和数据科学在解决复杂问题时,经常需要评估模型的性能。其中,ROC(Receiver Operating Characteristic)曲线是一种非常有用的工具,被广泛应用于分类问题中。该工具不仅在医学检测、信号处理中有着悠久的历史,而且在近年来的机器学习应用中也显得尤为关键。 ROC曲线

    2024年02月04日
    浏览(38)
  • 大数据机器学习深度解读决策树算法:技术全解与案例实战

    本文深入探讨了机器学习中的决策树算法,从基础概念到高级研究进展,再到实战案例应用,全面解析了决策树的理论及其在现实世界问题中的实际效能。通过技术细节和案例实践,揭示了决策树在提供可解释预测中的独特价值。 决策树算法是机器学习领域的基石之一,其强

    2024年02月04日
    浏览(48)
  • 机器视觉 多模态学习11篇经典论文代码以及解读

    此处整理了深度学习-机器视觉,最新的发展方向-多模态学习,中的11篇经典论文,整理了相关解读博客和对应的Github代码,看完此系列论文和博客,相信你能快速切入这个方向。每篇论文、博客或代码都有相关标签,一目了然,整理到这里了 webhub123 机器视觉 多模态学习

    2024年02月13日
    浏览(38)
  • 机器学习 & 深度学习编程笔记

    如果不加噪音就成了正常的线性函数了,所以要加噪音。 torch.normal(0, 0.01, y.shape)是一个用于生成服从正态分布的张量的函数。其中,0代表均值,0.01代表标准差,y.shape表示生成的张量的形状与y相同。具体而言,该函数会生成一个张量,其元素值是从均值为0、标准差为0.01的正

    2024年02月16日
    浏览(101)
  • 适合小白学习的GAN(生成对抗网络)算法超详细解读

    “GANs are \\\'the coolest idea in deep learning in the last 20 years.\\\' ” --Yann LeCunn, Facebook’s AI chief   今天我们就来认识一下这个传说中被誉为过去20年来深度学习中最酷的想法——GAN。  GAN之父的主页: http://www.iangoodfellow.com/  GAN论文地址: https://arxiv.org/pdf/1406.2661.pdf 目录 前言  📢一、

    2024年02月02日
    浏览(46)
  • 深度学习分割任务——Unet++分割网络代码详细解读(文末附带作者所用code)

    ​ 分成语义分割和实例分割 语义分割:语义分割就是把每个像素都打上标签(这个像素点是人,树,背景等)(语义分割只区分类别,不区分类别中具体单位)[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 实例分割:实例分割不光要区别类别,还

    2024年02月04日
    浏览(46)
  • 李宏毅《机器学习 深度学习》简要笔记(一)

    一、线性回归中的模型选择 上图所示: 五个模型,一个比一个复杂,其中所包含的function就越多,这样就有更大几率找到一个合适的参数集来更好的拟合训练集。所以,随着模型的复杂度提高,train error呈下降趋势。 上图所示: 右上角的表格中分别体现了在train和test中的损

    2024年01月25日
    浏览(42)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包