Pytorch Advanced(二) Variational Auto-Encoder

这篇具有很好参考价值的文章主要介绍了Pytorch Advanced(二) Variational Auto-Encoder。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。

Pytorch Advanced(二) Variational Auto-Encoder,deep learning,pytorch,人工智能,python

自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的深层次特征再还原成为原来的数据特征。那么如何保证从压缩到解码两部分,原数据和解码数据保持一致呢?这就是要训练的过程。

如何理解降维?如果压缩的过程是卷积,维度可以根据核的个数变化,特征维度因此而改变。


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

dataset = torchvision.datasets.MNIST(root='../../data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, 
                                          shuffle=True)

模型搭建:这里搭建的是一个变分自编码,Variational Autoencoder

那么变分自编码是为了解决什么问题呢? ——- 其主要思想还是希望学习隐层变量,并将其用来表示原始数据,但是它加另一个条件, 即隐层变量能学习原始数据的分布, 并反过来生产一些和原始数据相似的数据(这有啥用?—-可用于图片修复,让图片按训练集的数据分布变化)。

变分自编码 (Variational Autoencoder) 为了让隐层抓住输入数据特性, 而不是简单的输出数据=输入数据,他在隐层中加入随机噪声(单位高斯噪声)(这个过程也叫reparametrize),以确保隐层能较好抽象输入数据特点。

代码中怎么做的呢?

1、编码过程中我们保存了第二层线性层的输出。其中第二层包含有fc2与fc3两部分,他们是并联的。

2、给隐藏层加入随机噪声,作为解码的输入

class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return F.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

训练:由于训练中加入了噪声,所以损失值的结构也因此改变。一部分来源于解码内容核原内容的相似度,另一部分是kl_div,具体是什么意义需查看论文。

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))
    
    with torch.no_grad():
        # Save the sampled images
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

模型训练完成了之后该如何使用这个模型呢?

model.decode()是一个解码的过程,我们给他一个随机的中间特征z就可以输出一个数字图片了。

z = torch.randn(1,z_dim).to(device)
out = model.decode(z)
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

Pytorch Advanced(二) Variational Auto-Encoder,deep learning,pytorch,人工智能,python

有了随机的一张图片之后,我们把他完整的放入模型中,生成了和输入相似的一张图片,也没看出来是修复了图像......

out,_,_ = model(out) 
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

Pytorch Advanced(二) Variational Auto-Encoder,deep learning,pytorch,人工智能,python文章来源地址https://www.toymoban.com/news/detail-733618.html

到了这里,关于Pytorch Advanced(二) Variational Auto-Encoder的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 变分自编码器(Variational AutoEncoder,VAE)

    说到编码器这块,不可避免地要讲起 AE (AutoEncoder)自编码器。它的结构下图所示: 据图可知,AE通过自监督的训练方式,能够将输入的原始特征通过编码encoder后得到潜在的特征编码,实现了自动化的特征工程,并且达到了降维和泛化的目的。而后通过对进行decoder后,我们

    2024年01月18日
    浏览(31)
  • 量子机器学习Variational Quantum Classifier (VQC)简介

    变分量子分类器(Variational Quantum Classifier,简称VQC)是一种利用量子计算技术进行分类任务的机器学习算法。它属于量子机器学习算法家族,旨在利用量子计算机的计算能力,潜在地提升经典机器学习方法的性能。 VQC的基本思想是使用一个量子电路,也称为变分量子电路,将

    2024年02月08日
    浏览(37)
  • 【论文导读】- Variational Graph Recurrent Neural Networks(VGRNN)

    Variational Graph Recurrent Neural Networks(VGRNN) 原文地址:Variational Graph Recurrent Neural Networks(VGRNN):https://arxiv.org/abs/1908.09710 源码: https://github.com/VGraphRNN/VGRNN Representation learning over graph structured data has been mostly studied in static graph settings while efforts for modeling dynamic graphs are still scant

    2024年02月08日
    浏览(47)
  • VARIATIONAL IMAGE COMPRESSION WITH A SCALE HYPERPRIOR

    VARIATIONAL IMAGE COMPRESSION WITH A SCALE HYPERPRIOR ABSTRACT We describe an end-to-end trainable model for image compression based on variational autoencoders .The model incorporates a hyperprior to effectively capture spatial dependencies in the latent representation.This hyperprior relates to side information, a concept universal to virtually all modern

    2024年02月13日
    浏览(29)
  • 使用VMD(Variational-Modal-Decomposition)分解多维信号

    很多博客都说的比较好了 vrcarva/vmdpy: Variational mode decomposition (VMD) in Python (github.com) https://github.com/vrcarva/vmdpy GitHub已经将matlab里的实现方法用Python写出来了 并且也已经有了大佬做解读 (10条消息) 变分模态分解(VMD)运算步骤及源码解读_comli_cn的博客-CSDN博客_变分模态分解 htt

    2024年02月01日
    浏览(33)
  • AIGC实战——变分自编码器(Variational Autoencoder, VAE)

    我们已经学习了如何实现自编码器,并了解了自编码器无法在潜空间中的空白位置处生成逼真的图像,且空间分布并不均匀,为了解决这些问题#

    2024年02月05日
    浏览(37)
  • Git advanced高级操作

    这篇文章是继Git概念介绍,常用命令与工作流程整理 配图_TranSad的博客-CSDN博客 之后的一些补充,学习总结一些额外Git操作中的比较常用的操作。所以这篇文章假设你已经有了前面的基础,我就直接说一些没有提到过的部分。 在Git中我们通常把HEAD当成是指向当前分支的指针

    2024年02月04日
    浏览(32)
  • 深度刨析指针Advanced 2

    作者主页 :paper jie的博客_CSDN博客-C语言,算法详解领域博主 本文作者 :大家好,我是paper jie,感谢你阅读本文,欢迎一建三连哦。 本文录入于 《系统解析C语言》专栏,本专栏是针对于大学生,编程小白精心打造的。笔者用重金(时间和精力)打造,将C语言基础知识一网打尽

    2024年02月09日
    浏览(46)
  • Advanced Solidity初学者教程

    目录 Advanced Solidity 引言: 1. 数学和算术 2. 时间和时间单位 3. 结构体 4. 修饰器 5. 枚举 6. 继承 7. 抽象合约 8. 接口 9. 库 10. 存储位置 Advanced Solidity(高级Solidity)是一种区块链编程语言Solidity的深入应用,通常用于构建智能合约和去中心化应用(DApps)。它涉及复杂的编程概念

    2024年04月28日
    浏览(55)
  • 西门子 PLCSim Advanced 初步入门

    PLCSim Advanced 是西门子为S7 1500推出的高级仿真模拟工具,支持 TCPIP网络通讯,4.0SP1 版本支持模拟S71500, S71500R/H,ET200SP ET200 PRO等CPU仿真。不支持S71200, SoftPLC 除基本编程运算外支持的仿真功能包括: WebServer, OPC UA, S7通讯, 开放式通讯, 与真实的CPU, 触控屏,WINCC等HMI设备; 不支

    2024年02月06日
    浏览(55)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包