模型架构
代码
数据准备
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch
# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([
transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)
transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',
train=True,
transform=transform,
download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)
定义生成器
'''
输入:正态分布随机数噪声(长度为100)
输出:生成的图片,(1,28,28)
中间过程:
linear1: 100 -> 256
linear2: 256 -> 512
linear3: 512 -> 28*28
reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数
self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),
nn.Linear(256,512),nn.ReLU(),
# 最后一层用tanh激活,将数据压缩到-1到1
nn.Linear(512,28*28),nn.Tanh())
def forward(self,x):
img = self.model(x)
img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)
return img
定义判别器
'''
判别器
输入:(1,28,28)的图片
输出:二分类的概率值 用sigmoid压缩到0-1之间
内容:
判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.model = nn.Sequential(
nn.Linear(28*28,512),nn.LeakyReLU(),
nn.Linear(512,256),nn.LeakyReLU(),
nn.Linear(256,1),nn.Sigmoid(),
)
def forward(self,x):
x = x.view(-1,28*28)
x = self.model(x)
return x
初始化模型,优化器及损失计算函数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()
画生成器生成的图的绘图函数
def gen_img_plot(model,epoch,test_input):
prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpy
prediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
fig = plt.figure(figsize=(4,4))
for i in range(16): # 迭代这n张图片
plt.subplot(4,4,i+1)
plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
plt.axis('off')
plt.show()
显示图片的函数
def img_plot(img):
img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
fig = plt.figure(figsize=(4,4))
for i in range(16): # 迭代这n张图片
plt.subplot(4,4,i+1)
plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
plt.axis('off')
plt.show()
定义训练函数
def train(num_epoch,test_input):
D_loss = []
G_loss = []
# 训练循环
for epoch in range(num_epoch):
d_epoch_loss = 0
g_epoch_loss = 0
count = len(dataloader) # 返回批次数
for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)
# print(f'step={step},img.shape={img.shape}')
# img_plot(img)
img = img.to(device)
size = img.size(0) # 得到一个批次的图片
random_noise = torch.randn(size,100,device=device) # 生成器的输入
'''一. 训练判别器'''
'''用真实图片训练判别器'''
dis_optim.zero_grad()
real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果
# 判别器在真实图像上的损失
d_real_loss = bce_loss(real_output,
# torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。
torch.ones_like(real_output))
d_real_loss.backward()
'''用生成的图片训练判别器'''
gen_img = gen(random_noise)
# 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensor
fake_output = dis(gen_img.detach())
d_fake_loss = bce_loss(fake_output,
torch.zeros_like(fake_output))
d_fake_loss.backward()
d_loss = d_real_loss+d_fake_loss
dis_optim.step() # 对参数进行优化
'''二.训练生成器'''
gen_optim.zero_grad()
# 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别
fake_output = dis(gen_img)
# 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。
# 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。
g_loss = bce_loss(fake_output,
torch.ones_like(fake_output))
g_loss.backward()
gen_optim.step()
# 计算一个epoch的损失
with torch.no_grad(): # 禁止梯度计算和参数更新
d_epoch_loss +=d_loss
g_epoch_loss +=g_loss
# 计算整体loss每个epoch的平均Loss
with torch.no_grad(): # 禁止梯度计算和参数更新
d_epoch_loss /= count
g_epoch_loss /= count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('Epoch:', epoch+1)
print(f'd_epoch_loss={d_epoch_loss}')
print(f'g_epoch_loss={g_epoch_loss}')
# 将16个长度为100的噪音输入到生成器并画图
gen_img_plot(gen,test_input)
开始训练
'''开始计时'''
start_time = time.time()
'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')
'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
print(f'{round(run_time,2)}s')
else:
print(f'{round(run_time/60,2)}minutes')
结果可视化
文章来源:https://www.toymoban.com/news/detail-671195.html
加载训练好的参数
gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))
用训练好的生成器生成图片并画图
test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)
GAN的生成是随机的,不同的噪声,生成不同的数字文章来源地址https://www.toymoban.com/news/detail-671195.html
到了这里,关于GAN原理 & 代码解读的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!