简介:GAN生成对抗网络本质上是一种思想,其依靠神经网络能够拟合任意函数的能力,设计了一种架构来实现数据的生成。
原理:GAN的原理就是最小化生成器Generator的损失,但是在最小化损失的过程中加入了一个约束,这个约束就是使Generator生成的数据满足我们指定数据的分布,GAN的巧妙之处在于使用一个神经网络(鉴别器Discriminator)来自动判断生成的数据是否符合我们所需要的分布。
实现细节:
一:
准备好我们想要让生成器生成的数据类型,比如MINIST手写数字集,包含1-10十个数字,一共60000张图片。生成器的目的就是学习这个数据集的分布。
二,
定义一个生成器,用于判别一张图片是实际的还是生成器生成的,当生成器完美学习得到数据分布之后,鉴别器可能就分不清图片是生成器的还是实际的,这样的话生成器就能生成我们想要的图片了。
生成器的训练过程为:实际数据输出结果1,生成数据输出结果为0,目的是学会区分真假数据,相当于提供一个约束,使生成数据符合指定分布。当鉴别生成器的数据分布时,只需要更新鉴别器的参数权重,不能够通过计算图将生成器的参数进行更新。
三,
定义一个生成器,给定一个输入,他就能生成1-10里面的一个数字的图片。生成器的反向更新是根据鉴别器的损失来确定(被约束进行反向更新)。生成器的网络权重参数是单独的,反向更新时,只需要更新计算图当中属于生成器部分的参数。
下面给出生成1-0-1-0数据格式的代码:文章来源:https://www.toymoban.com/news/detail-667857.html
# %% import torch import numpy import torch.nn as nn import matplotlib.pyplot as plt # %% def gennerate1010(): return torch.FloatTensor([numpy.random.uniform(0.9,1.1), numpy.random.uniform(0.,.1), numpy.random.uniform(0.9,1.1), numpy.random.uniform(0.0,.1)]) # %% def genneratexxxx(): return torch.rand(4) # %% class Discrimer(nn.Module): def __init__(self) -> None: father_obj = super(Discrimer,self) father_obj.__init__() self.create_model() self.counter = 0 self.progress = [] def create_model(self): self.model = nn.Sequential( nn.Linear(4,3), nn.Sigmoid(), nn.Linear(3,1), nn.Sigmoid(), ) self.loss_functon = nn.MSELoss() self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01) def forward(self,x): return self.model(x) def train(self,x,targets): outputs = self.forward(x) loss = self.loss_functon(outputs,targets) self.counter += 1 if self.counter%10 == 0: self.progress.append(loss.item()) if self.counter%10000 == 0: print(self.counter) self.optimiser.zero_grad() loss.backward() self.optimiser.step() def plotprogress(self): plt.plot(self.progress,marker='*') plt.show() # %% class Gennerater(nn.Module): def __init__(self) -> None: father_obj = super(Gennerater,self) father_obj.__init__() self.create_model() self.counter = 0 self.progress = [] def create_model(self): self.model = nn.Sequential( nn.Linear(1,3), nn.Sigmoid(), nn.Linear(3,4), nn.Sigmoid(), ) # 这个优化器只能优化生成器部分的参数 self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01) def forward(self,x): return self.model(x) def train(self,D,x,targets): g_outputs = self.forward(x) d_outputs = D.forward(g_outputs) # 使用鉴别器的loss函数,但是只更新生成器的参数,生成器的参数需要根据鉴别器的约束进行更新 loss = D.loss_functon(d_outputs,targets) self.counter += 1 if self.counter%10 == 0: self.progress.append(loss.item()) if self.counter%10000 == 0: print(self.counter) self.optimiser.zero_grad() loss.backward() self.optimiser.step() def plotprogress(self): plt.plot(self.progress,marker='*') plt.show() # %% D = Discrimer() # %% G = Gennerater() # %% for id in range(15000): # 喂入实际数据给鉴别器 D.train(gennerate1010(),torch.FloatTensor([1.])) # 喂入生成的数据,使用detach从计算图脱离,用于更新鉴别器,而生成器得不到更新 D.train(G.forward(torch.FloatTensor([0.5]).detach()),torch.FloatTensor([0.0])) G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.])) # %% D.plotprogress() # %% G.plotprogress() # %% G.forward(torch.FloatTensor([0.5]))
参考:PyTorch生成对抗网络编程文章来源地址https://www.toymoban.com/news/detail-667857.html
到了这里,关于GAN(生成对抗网络)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!