手把手教你训练一个VAE生成模型一生成手写数字

这篇具有很好参考价值的文章主要介绍了手把手教你训练一个VAE生成模型一生成手写数字。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1 VAE简介

VAE(Variational Autoencoder)变分自编码器是一种使用变分推理的自编码器,其主要用于生成模型。 VAE 的编码器是模型的一部分,用于将输入数据压缩成潜在表示,即编码。

VAE 编码器包括两个子网络:一个是推断网络,另一个是生成网络。推断网络输入原始输入数据,并输出两个参数:均值和方差。这些参数用于描述编码的潜在分布。生成网络输入潜在编码并输出重构的输入数据。

为了从输入数据中学习潜在表示,VAE 采用变分推理的方法。变分推理是一种通过最大化对数似然来学习潜在分布的方法。首先,我们假设潜在分布为高斯分布,然后通过最大化对数似然估计参数。这些参数(均值和方差)由推断网络学习。

对于给定的输入数据,推断网络学习参数,然后使用这些参数计算潜在分布。我们从潜在分布中采样一个编码,然后将它输入生成网络。生成网络使用这个编码重构原始输入数据。最后,我们使用重构数据和原始数据之间的差异来计算损失。这个损失用来衡量 VAE 对原始输入数据的重构精度。

最后,VAE 编码器的目的是学习一种潜在表示,使得重构输入数据的损失最小。这个潜在表示可以用于生成新的数据,或者用于其他目的,如数据压缩或降维。
总的来说,VAE 编码器是一种使用变分推理的自编码器,用于学习潜在表示,并使用这个表示重构输入数据。

2 生成手写数字实践

VAE 生成模型的最简单例子可能是用于生成手写数字的模型。手写数字数据集通常被编码为 28x28 像素的灰度图像。我们可以使用 VAE 来学习生成新的手写数字图像。

# 加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist = datasets.MNIST(root='.', download=True, transform=transform)

首先,我们需要定义 VAE 的网络结构。这个 VAE 的编码器可能包括一个卷积层,用于提取图像特征,以及一个全连接层,用于将卷积层的输出压缩成潜在表示。编码器的输出是两个参数:均值和方差。

# 定义 VAE 编码器
class VAEEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, latent_size * 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        mu, log_var = x.split(latent_size, dim=1)
        return mu, log_var

然后,我们可以使用这些参数计算潜在分布,并从中采样潜在编码。潜在编码是我们用于生成新图像的输入。我们的 VAE 还包括一个解码器,用于将潜在编码解码为图像。解码器可能包括一个全连接层和一个卷积层,用于将潜在编码转换为图像。

# 定义 VAE 解码器
class VAEDecoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

最后,我们使用重构图像和原始图像之间的差异来计算 VAE 的损失。我们可以使用这个损失来训练 VAE,以使得重构图像尽可能接近原始图像。当我们的 VAE 训练完成后,我们就可以使用它来生成新的手写数字图像。

# 定义 VAE 损失函数
def vae_loss(recon, x, mu, log_var):
    recon_loss = nn.BCELoss(reduction='sum')(recon, x)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss

为了生成新的图像,我们可以从 VAE 的潜在分布中采样一个潜在编码,然后将它输入 VAE 的解码器。解码器会使用这个编码生成一个新的图像。我们可以使用不同的潜在编码生成不同的图像,从而生成一系列新的手写数字图像。

 # 使用 VAE 生成图像
    with torch.no_grad():
        z = torch.randn(1, latent_size)
        image = model.decoder(z).view(28, 28)
        image = image.detach().numpy()
        plt.imshow(image, cmap='gray')
        plt.show() 

这是一个 VAE 生成模型的最简单例子。 VAE 可以用于生成各种各样的数据,包括图像、文本、音频和视频。 VAE 的更复杂的例子可能包括更复杂的网络结构、更多的层和更多的参数。

下面是使用 PyTorch 实现 VAE 生成手写数字的完整代码:

# VAE.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# 定义 VAE 编码器
class VAEEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, latent_size * 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        mu, log_var = x.split(latent_size, dim=1)
        return mu, log_var

# 定义 VAE 解码器
class VAEDecoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

# 定义 VAE 模型
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.encoder = VAEEncoder(input_size, hidden_size, latent_size)
        self.decoder = VAEDecoder(latent_size, hidden_size, input_size)

    def forward(self, x):
        mu, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + std * eps
        recon = self.decoder(z)
        return recon, mu, log_var

# 定义 VAE 损失函数
def vae_loss(recon, x, mu, log_var):
    recon_loss = nn.BCELoss(reduction='sum')(recon, x)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss

# 加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist = datasets.MNIST(root='.', download=True, transform=transform)

# 定义训练参数
batch_size = 64
lr = 1e-3
num_epochs = 20

# 定义数据加载器
data_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True) # shuffle=True 打乱数据

# 定义模型、优化器和损失函数
# 定义 VAE 模型
input_size = 28 * 28
hidden_size = 256
latent_size = 64
model = VAE(input_size, hidden_size, latent_size)

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=lr)

if __name__ == '__main__': # 仅在当前文件中运行时才执行以下代码
    # 训练 VAE 模型
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for x, _ in data_loader:
            x = x.view(-1, input_size)
            recon, mu, log_var = model(x)
            loss = vae_loss(recon, x, mu, log_var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f'Epoch {epoch+1} loss: {epoch_loss / len(mnist):.4f}')              

    # 使用 VAE 生成图像
    with torch.no_grad():
        z = torch.randn(1, latent_size)
        image = model.decoder(z).view(28, 28)
        image = image.detach().numpy()
        plt.imshow(image, cmap='gray')
        plt.show() 

    # 保存模型
    torch.save(model.state_dict(), 'vae.pth')

3 调用生成模型生成指定数字

上面我们已经训练好了 VAE 模型,如果想使用该模型生成指定的数字,则不需要再次训练模型。我们可以直接使用训练好的模型,通过指定的 latent variables 生成想要的数字。

要做到这一点,需要按照以下步骤操作:

  1. 选择一个你想要生成的数字的图像作为样本,如:mnist [9][0]=4, [7][0]=3, [0][0]=5
  2. 使用 VAE 的编码器将该图像编码为 latent variables
  3. 将生成的 latent variables 作为输入传递给 VAE 的解码器,生成你想要的数字图像

下面是实现上述操作的示例代码:

在另一个文件 generate.py 中调用上面已经训练好的模型:

# generate.py 
import torch
import matplotlib.pyplot as plt
from VAE import model, input_size, mnist # 从 VAE.py 中导入模型、输入大小和 MNIST 数据集

# 加载已训练好的模型
model.load_state_dict(torch.load('vae.pth'))

# 选择mnist的样本图像 
sample_image = mnist[0][0] # mnist[0][0]是数字5的数据集

# 使用 VAE 的编码器将样本图像编码为 latent variables
mu, log_var = model.encoder(sample_image.view(-1, input_size))

# 将生成的 latent variables 作为输入传递给 VAE 的解码器,生成数字图像
generated_image = model.decoder(mu).view(28, 28)

# 显示原始图像和生成的图像
plt.subplot(1, 2, 1)
plt.title('Original Image')
plt.imshow(sample_image.view(28, 28), cmap='gray')
plt.subplot(1, 2, 2)
plt.title('Generated Image')
plt.imshow(generated_image.detach().numpy(), cmap='gray')
plt.show()

在上面的代码中,使用了 MNIST 数据集的第0个样本图像作为输入,所以模型生成的数字应该是数据集中第一个样本的数字,5。如果我们想生成不同的数字,可以使用不同的样本图像,例如 mnist[1][0],mnist[2][0] 等。

上面首先使用 VAE 的编码器将样本图像编码为 latent variables,然后使用 VAE 的解码器生成数字图像,再使用model.load_state_dict() 加载已保存的模型。最后,使用已加载的模型生成数字图像并显示。效果如下图:
手把手教你训练一个VAE生成模型一生成手写数字
上面模型的生成性能可能不是最好的,如果我们想改变 VAE 模型的表现,例如生成更加细腻、清晰的图像,则可能需要再次训练模型。我们可以通过调整训练参数,例如批次大小、学习率等来实现。

此外,我们还可以尝试改变 VAE 模型的结构,例如增加或减少网络层的数量,或者改变每一层的单元数量来提高模型的表现。这需要对深度学习和神经网络有较深的理解,并且可能需要多次尝试和调整才能找到最优的网络结构。

为了提升生成模型的性能,我们可以尝试以下操作:

  • 增加编码器和解码器的层数,以增加模型的复杂度。
  • 使用更复杂的激活函数,例如 LeakyReLU 或 ELU。
  • 使用更多的训练数据,例如从其他数据集中收集更多的数据。
  • 尝试使用不同的优化器,例如 RMSProp 或 Adamax。
  • 调整学习率,例如适当降低学习率以避免过拟合。
  • 使用数据增强,例如随机旋转、翻转或缩放图像来增加训练数据的多样性。

欢迎关注,感谢支持!文章来源地址https://www.toymoban.com/news/detail-459496.html

到了这里,关于手把手教你训练一个VAE生成模型一生成手写数字的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 手把手教你如何使用SpringBoot3打造一个个性化的代码生成器

    代码基于SpringBoot3、Vue3、highlight实现自定义代码生成功能 SpringBoot3.x、MySQL8、MyBatisPlus3.5.x、velocity2.x、SpringSecurity6.x、Vue3、TypeScript、highlight demo所需要的依赖及其对应版本号 pom 配置文件 这里是最基础的MySQL的配置信息 application 1.1、代码生成器源码目录 这里是代码生成器的源

    2024年01月19日
    浏览(34)
  • YOLOv5入门实践(5)——从零开始,手把手教你训练自己的目标检测模型(包含pyqt5界面)

      通过前几篇文章,相信大家已经学会训练自己的数据集了。本篇是YOLOv5入门实践系列的最后一篇,也是一篇总结,我们再来一起按着 配置环境--标注数据集--划分数据集--训练模型--测试模型--推理模型 的步骤,从零开始,一起实现自己的目标检测模型吧! 前期回顾: YOLO

    2023年04月26日
    浏览(45)
  • 超详细AI二维码制作教程:手把手教你如何用Stable Diffusion 生成一个创意二维码?

    AI已来,未来已来! 来势汹汹的人工智能,如同创世纪的洪水,正在全世界的范围内引发一场史无前例的科技革命。AI正在改变世界!而我们正是这场巨变的见证者。 今天我们要介绍的内容就是:如何利用AI工具Stable Diffusion,生成你的专属创意二维码? (下文包含详细图文教

    2024年02月16日
    浏览(49)
  • 手把手教你使用Segformer训练自己的数据

    使用Transformer进行语义分割的简单高效设计。 将 Transformer 与轻量级多层感知 (MLP) 解码器相结合,表现SOTA!性能优于SETR、Auto-Deeplab和OCRNet等网络 相比于ViT,Swin Transfomer计算复杂度大幅度降低,具有输入图像大小线性计算复杂度。Swin Transformer随着深度加深,逐渐合并图像块来

    2024年01月20日
    浏览(50)
  • YOLOV7训练自己的数据集以及训练结果分析(手把手教你)

    YOLOV7训练自己的数据集整个过程主要包括:环境安装----制作数据集----参数修改----模型测试----模型推理 labelme标注的数据格式是VOC,而YOLOv7能够直接使用的是YOLO格式的数据,因此下面将介绍如何将自己的数据集转换成可以直接让YOLOv7进行使用。 1. 创建数据集 在data目录下新建

    2023年04月20日
    浏览(38)
  • 手把手教你如何使用YOLOV5训练自己的数据集

    YOLOV5是目前最火热的目标检测算法之一。YOLOV5为一阶段检测算法因此它的速度非常之快。可以在复杂场景中达到60祯的实时检测频率。 接下来本文将详细的讲述如何使用YOLOV5去训练自己的数据集 YOLOV5中使用了Tensorboard和Wandb来可视化训练,其中Wandb配置可以看这篇文章: Wand

    2024年02月05日
    浏览(43)
  • 如何运用yolov5训练自己的数据(手把手教你学yolo)

    在这篇博文中,我们对YOLOv5模型进行微调,用于自定义目标检测的训练和推理。 深度学习领域在2012年开始快速发展。在那个时候,这个领域还比较独特,编写深度学习程序和软件的人要么是深度学习实践者,要么是在该领域有丰富经验的研究人员,或者是具备优秀编码技能

    2024年02月07日
    浏览(68)
  • YOLOv5入门实践(4)——手把手教你训练自己的数据集

      在上一篇文章中我们介绍了如何划分数据集,划分好之后我们的前期准备工作就已经全部完成了,下面开始训练自己的数据集吧! 前期回顾: YOLOv5入门实践(1)——手把手带你环境配置搭建 YOLOv5入门实践(2)——手把手教你利用labelimg标注数据集

    2024年04月10日
    浏览(39)
  • 手把手教你用YOLOv5算法训练数据和检测目标(不会你捶我)

    本人从一个小白,一路走来,已能够熟练使用YOLOv5算法来帮助自己解决一些问题,早就想分析一下自己的学习心得,一直没有时间,最近工作暂时告一段落,今天抽空写点东西,一是为自己积累一些学习笔记,二是可以为一些刚接触YOLOv5算法的小白们提供一些参考,希望大家

    2024年02月01日
    浏览(41)
  • 手把手教你搭建一个Minecraft 服务器

    这次,我们教大家如何搭建一个我的世界服务器 首先,我们来到这个网站 MCVersions.net - Minecraft Versions Download List MCVersions.net offers an archive of Minecraft Client and Server jars to download, for both current and old releases! https://mcversions.net/   在这里,我们点击对应的版本,从左到右依次是稳定版

    2024年02月09日
    浏览(37)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包