AutoDecoder
The auto-decoder (AD) is the approach used in the paper "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation". Unlike the conventional encoder-decoder architecture, an AD has no encoder, only a decoder, which maps a latent vector (feature vector) to an image.
During training, the latent vectors and the network weights are optimized jointly. For a training set of N images, latent vectors of length n, and a network with m weights, there are N*n + m trainable parameters in total.
Once training is complete, feeding any latent vector into the decoder yields an image.
The full DeepSDF pipeline is more involved: it uses an AD to produce a signed distance field, from which 3D shapes are generated. The paper also includes a simple demonstration on the MNIST handwritten-digit dataset, but provides no source code for it, which makes it hard to get started. I found a concrete implementation on GitHub at https://github.com/alexeybokhovkin/DeepSDF, made some fixes and improvements on top of it, and wrote this post as a record. I hope readers find it helpful.
Code
The project consists of four files: dataset.py defines the dataset, evaluate.py evaluates the trained network, network.py defines the network architecture, and train.py trains the network.
1 dataset.py
Loads the MNIST (or FashionMNIST) dataset, downloading it from the internet if it is not present locally.
import torchvision
from torch.utils.data.dataset import Dataset

# A thin wrapper over (Fashion)MNIST that returns flattened images together
# with their indices, so the matching latent vector can be looked up.
class DatasetMNIST(Dataset):
    def __init__(self, root_dir, latent_size, transform=None):
        # latent_size and transform are accepted for API compatibility but unused here.
        # Swap FashionMNIST for MNIST to use the digit dataset instead.
        mnist = torchvision.datasets.FashionMNIST(root=root_dir, train=True, download=True)
        # mnist.data replaces the deprecated mnist.train_data accessor
        self.data = mnist.data.float() / 255.0  # scale pixels to [0, 1]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        return image.flatten(), index
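As a quick sanity check (a sketch assuming the file layout above; the root directory './data' matches train.py), each item comes back as a flattened 784-vector together with its index:

from dataset import DatasetMNIST

dataset = DatasetMNIST(root_dir='./data', latent_size=2)
image, index = dataset[0]
print(len(dataset))   # 60000 training images
print(image.shape)    # torch.Size([784]) -- a flattened 28x28 image
print(index)          # 0 -- used later to look up this image's latent vector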
2 network.py
Defines a fully connected network that maps a latent vector to an image (flattened into a vector). To make the results easy to visualize, the latent vector is 2-dimensional. Note that the latent vectors are themselves trainable parameters, N*n elements in total.
import torch
import torch.nn as nn
import torch.nn.init as init

# Auto-decoder: no encoder; the latent vectors are themselves trainable.
class AD(nn.Module):
    def __init__(self, image_size=784, z_dim=2, data_shape=60000):
        super(AD, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, image_size),
            nn.Tanh())  # Tanh outputs [-1, 1]; the targets lie in [0, 1]
        # One latent vector per training image, registered as a parameter so
        # the optimizer updates it together with the decoder weights.
        self.latent_vectors = nn.Parameter(torch.FloatTensor(data_shape, z_dim))
        init.xavier_normal_(self.latent_vectors)  # in-place variant; init.xavier_normal is deprecated

    def forward(self, ind):
        # Look up the latent vectors for this batch of image indices
        x = self.latent_vectors[ind]
        return self.decoder(x)

    def predict(self, x):
        # Decode an arbitrary latent vector (used for visualization)
        return self.decoder(x)
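Because latent_vectors is registered as an nn.Parameter, it appears in model.parameters() and is updated by the same optimizer as the decoder weights. A quick count (a sketch; the numbers assume the default sizes above) confirms the N*n + m total from the introduction:

from network import AD

model = AD(image_size=784, z_dim=2, data_shape=60000)
total = sum(p.numel() for p in model.parameters())
latent = model.latent_vectors.numel()
print(latent)          # 120000 = N*n latent entries
print(total - latent)  # 567184 decoder weights (m)
print(total)           # 687184 = N*n + m trainable parameters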
3 train.py
Trains the network with Adam, using batch_size=128 and num_epochs=250. Training takes only a few minutes; afterwards the model is saved to model.pth for later use.
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from dataset import DatasetMNIST
from network import AD
from tqdm import tqdm

# Hyper-parameters
image_width = 28
image_size = image_width * image_width
num_epochs = 250
batch_size = 128
learning_rate = 1e-3
latent_size = 2

if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Create a directory for sample images if it does not exist
    sample_dir = 'samples'
    os.makedirs(sample_dir, exist_ok=True)

    dataset = DatasetMNIST(root_dir='./data', latent_size=latent_size)

    # Data loader
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    model = AD(image_size=image_size, z_dim=latent_size, data_shape=len(dataset)).cuda()

    # Reconstruction loss; Adam updates the decoder weights and the latent
    # vectors together, since both are in model.parameters()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in range(num_epochs):
        tq = tqdm(total=len(data_loader))
        tq.set_description('Epoch {}'.format(epoch))
        for i, (x, ind) in enumerate(data_loader):
            # Forward pass: decode the latent vectors selected by the batch indices
            x = x.cuda()
            x_reconst = model(ind)
            loss = criterion(x_reconst, x)
            # Backward pass and update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            tq.update()
            tq.set_postfix(loss='{:.3f}'.format(loss.item()))
        tq.close()

        if epoch % 5 == 0:
            with torch.no_grad():
                # Visualize the 2D latent space: decode a regular grid of
                # latent vectors and tile the outputs into one large image
                steps = 50
                bound = 0.8
                size = image_width
                out = torch.zeros(size=(steps * size, steps * size))
                for i, l1 in enumerate(np.linspace(-bound, bound, steps)):
                    for j, l2 in enumerate(np.linspace(-bound, bound, steps)):
                        vector = torch.tensor([l1, l2], dtype=torch.float32).cuda()
                        out_ = model.predict(vector)
                        out[i * size:(i + 1) * size,
                            j * size:(j + 1) * size] = out_.view(size, size).cpu()
                save_image(out, os.path.join(
                    sample_dir, 'latent_space-{}.png'.format(epoch + 1)))

    # Save the entire module (architecture + weights + latent vectors)
    torch.save(model, 'model.pth')
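Note that torch.save(model, ...) pickles the entire module, which ties the checkpoint to the exact class definition and, on PyTorch 2.6 and later, requires passing weights_only=False to torch.load. A more portable alternative (a sketch, not what the script above does) is to store only the state dict:

# Save weights and latent vectors only
torch.save(model.state_dict(), 'model_state.pth')

# Load: rebuild the architecture first, then restore the parameters
model = AD(image_size=784, z_dim=2, data_shape=60000)
model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
model.eval()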
4 evaluate.py
Loads the trained model, sweeps a grid of latent vectors [l1, l2], generates the corresponding image for each with the decoder, and saves the result as latent_space-eval.png. It then plots the latent vectors [l1, l2] of all training samples in one image, saved as latent_space-distribution.png, and finally concatenates the two images into latent_space-merged.png.
import os
import torch
import numpy as np
from torchvision.utils import save_image
from train import image_width

if __name__ == "__main__":
    # Create a directory for sample images if it does not exist
    sample_dir = 'samples'
    os.makedirs(sample_dir, exist_ok=True)

    # Select GPU or CPU
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load the trained model from disk
    # (on PyTorch >= 2.6, torch.load may additionally need weights_only=False)
    model = torch.load('model.pth', map_location=device)
    model.eval()  # switch the model to evaluation mode
    print(model)

    # Visualize the 2D latent space: decode a regular grid of latent vectors
    steps = 50
    bound = 0.8
    size = image_width
    out_grid = torch.zeros(size=(steps * size, steps * size))
    for i, l1 in enumerate(np.linspace(-bound, bound, steps)):
        for j, l2 in enumerate(np.linspace(-bound, bound, steps)):
            vector = torch.tensor([l1, l2], dtype=torch.float32).to(device)
            out_ = model.predict(vector)
            out_grid[i * size:(i + 1) * size,
                     j * size:(j + 1) * size] = out_.view(size, size).cpu()
    save_image(out_grid, os.path.join(sample_dir, 'latent_space-eval.png'))

    # Plot the distribution of the learned latent vectors on the same grid:
    # map each latent from [-bound, bound] to pixel coordinates and mark it
    # with a 3x3 white block
    out_dist = torch.zeros(size=(steps * size, steps * size))
    latent_vectors_scaled = model.latent_vectors.cpu().detach().numpy()
    latent_vectors_scaled = np.clip(latent_vectors_scaled, -bound + 0.005, bound - 0.005)
    latent_vectors_scaled = (latent_vectors_scaled + bound) / (2.0 * bound) * steps * size
    for i in range(len(latent_vectors_scaled)):
        l1 = int(round(latent_vectors_scaled[i][0]))
        l2 = int(round(latent_vectors_scaled[i][1]))
        out_dist[l1 - 1:l1 + 2, l2 - 1:l2 + 2] = 1.0
    save_image(out_dist, os.path.join(sample_dir, 'latent_space-distribution.png'))

    # Concatenate the two images side by side
    out_merged = torch.cat((out_grid, out_dist), dim=1)
    save_image(out_merged, os.path.join(sample_dir, 'latent_space-merged.png'))

    # Report the range of the learned latents (informs the choice of bound)
    print(model.latent_vectors.max(), model.latent_vectors.min())
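To reproduce the results, run python train.py first (it writes model.pth and periodic latent-grid snapshots into samples/), then run python evaluate.py.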
Results
Decoding the grid of latent vectors before training:
Decoding the grid of latent vectors at epoch 16:
Decoding the grid of latent vectors at epoch 101:
Decoding the grid of latent vectors at epoch 250:
The distribution of the training samples in latent space:
Conclusions
(1) AD is in effect a feature-extraction method: here a 2-dimensional feature was extracted from the dataset, and the original data were reconstructed from points in the 2D plane. In practice the latent vector is high-dimensional, which gives better results.
(2) The latent vectors are distributed approximately normally, but there are clear gaps between different classes.
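The class gaps can be seen directly by coloring each learned latent vector with its label. The sketch below is illustrative: it assumes train.py has already produced model.pth, and it uses matplotlib, which the project does not otherwise require.

import torch
import torchvision
import matplotlib.pyplot as plt
import network  # makes the AD class importable so torch.load can unpickle it

# add weights_only=False on PyTorch >= 2.6
model = torch.load('model.pth', map_location='cpu')
latents = model.latent_vectors.detach().numpy()
# row i of latent_vectors corresponds to training image i, so labels align
labels = torchvision.datasets.FashionMNIST(root='./data', train=True).targets.numpy()

plt.figure(figsize=(6, 6))
plt.scatter(latents[:, 0], latents[:, 1], c=labels, s=1, cmap='tab10')
plt.colorbar(label='class')
plt.savefig('latent_space-by-class.png', dpi=150)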
(3) In essence, AD compresses the image data: information shared across images is captured in the network weights, while per-image information is captured in the latent vectors.
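As a rough check of this compression view, using the numbers from this setup: the training set contains 60000 × 784 = 47,040,000 pixel values, while the trained model stores only 687,184 parameters (120,000 latent entries plus 567,184 decoder weights), roughly a 68-fold reduction, at the price of lossy reconstruction.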