前言
鸽了好久没更了,主要是刚入学学业压力还蛮大,挺忙的,没时间总结啥东西。
接下来就要好好搞科研啦。先来学习一篇diffusion的经典之作Denoising Diffusion Probabilistic Models(DDPM)。(看完这篇可看它的改进版 IDDPM原理和代码剖析)
先不断前向加高斯噪声,这一步骤称为前向过程。然后就是利用模型不断预测加噪前的图片,从而还原出原图像。同时在学习时,deep_thoughts这个up的视频帮了我不少忙, 由衷感谢 54、Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读, 推荐大家去观看他的视频。
没事多学点数学: 生成扩散模型漫谈:构建ODE的一般步骤(上)
DDPM重要公式
由于有些公式推导过程可能较长,我会把它放到【推导】部分,至于【提纲】部分则会列出简洁的重要的公式。
提纲
(1) 前向加噪
q
(
X
t
∣
X
t
−
1
)
=
N
(
X
t
;
1
−
β
t
X
t
−
1
,
β
t
I
)
q(X_t|X_{t-1}) = N(X_t;\sqrt{1-\beta_t}X_{t-1}, \beta_t I)
q(Xt∣Xt−1)=N(Xt;1−βtXt−1,βtI)
β t \beta_t βt 在 DDPM中是0到1的小数,并且满足 β 1 < β 2 < . . . < β T \beta_1 \lt \beta_2 \lt ... \lt \beta_T β1<β2<...<βT
(2)
q
(
X
t
∣
X
0
)
=
N
(
X
t
;
α
‾
t
X
0
,
(
1
−
α
‾
t
)
I
)
q(X_t|X_0) = N(X_t; \sqrt{\overline{\alpha}_t}X_0, (1-\overline{\alpha}_t)I)
q(Xt∣X0)=N(Xt;αtX0,(1−αt)I)
或者写为
X
t
=
α
‾
t
X
0
+
1
−
α
‾
t
ϵ
X_t = \sqrt{\overline{\alpha}_t}X_0+\sqrt{1-\overline{\alpha}_t}~\epsilon
Xt=αtX0+1−αt ϵ
其中,
α
t
\alpha_t
αt 定义为
1
−
β
t
1-\beta_t
1−βt, 不要问
α
t
+
β
t
\alpha_t + \beta_t
αt+βt 为何为1, 因为我们就是定义来的,定义
α
\alpha
α 只是为了让后续公式书写更加简洁。
(3) 后验的均值和方差
即
q
(
X
t
−
1
∣
X
t
,
X
0
)
q(X_{t-1}|X_t, X_0)
q(Xt−1∣Xt,X0) 的均值
μ
~
(
X
t
,
X
0
)
\widetilde{\mu}(X_t, X_0)
μ
(Xt,X0) 以及方差
β
~
t
\widetilde{\beta}_t
β
t 分别为
μ
~
(
X
t
,
X
0
)
=
α
‾
t
−
1
1
−
α
‾
t
X
0
+
α
t
(
1
−
α
‾
t
−
1
)
1
−
α
‾
t
X
t
\widetilde{\mu}(X_t, X_0) = \frac{\sqrt{\overline{\alpha}_{t-1}}}{1-\overline{\alpha}_t} X_0 + \frac{\sqrt{\alpha_t}(1-\overline{\alpha}_{t-1})}{1-\overline{\alpha}_{t}}X_t
μ
(Xt,X0)=1−αtαt−1X0+1−αtαt(1−αt−1)Xt
β
~
t
=
1
−
α
‾
t
−
1
1
−
α
‾
t
β
t
\widetilde{\beta}_t = \frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t} \beta_t
β
t=1−αt1−αt−1βt
另外根据公式 (2) , 利用
X
0
X_0
X0与
X
t
X_t
Xt的关系,有
X
0
=
1
α
‾
t
(
X
t
−
1
−
α
‾
t
ϵ
t
)
X_0=\frac{1}{\sqrt{\overline{\alpha}_t}}(X_t-\sqrt{1-\overline{\alpha}_t}~~\epsilon_t)
X0=αt1(Xt−1−αt ϵt), 带入到上式中,得到
μ
~
(
X
t
,
X
0
)
=
1
α
‾
t
(
X
t
−
β
t
1
−
α
‾
t
ϵ
t
)
\widetilde{\mu}(X_t, X_0) = \frac{1}{\sqrt{\overline{\alpha}_t}}(X_t-\frac{\beta_t}{\sqrt{1-\overline{\alpha}_t}}\epsilon_t)
μ
(Xt,X0)=αt1(Xt−1−αtβtϵt)
推导
主要参考 deep thoughts-关于第54期视频去噪概率扩散模型DDPM的更新版notebook分享
由于
DDPM 训练和采样
原论文总结得很好,直接抄下hh
DDPM S_curve数据集小demo
声明: 暂时找不到原始出处,这是根据一个开源项目来的。
如果有小伙伴知道原始出处在哪,请麻烦留言,我后续会补上。
同时声明该案例只用于学习!!
数据集加载
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch
s_curve,_ = make_s_curve(10**4,noise=0.1)
s_curve = s_curve[:,[0,2]]/10.0
print("shape of s:",np.shape(s_curve))
data = s_curve.T
fig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');
ax.axis('off')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = torch.Tensor(s_curve).float().to(device)
shape of s: (10000, 2)
确定超参数的值(important)
接下来主要是算出一些有用的量。
num_steps = 100
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps).to(device)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5
#计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1-betas
alphas_prod = torch.cumprod(alphas,0)
alphas_prod_p = torch.cat([torch.tensor([1]).float().to(device),alphas_prod[:-1]],0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==\
alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape\
==one_minus_alphas_bar_sqrt.shape
print("all the same shape",betas.shape)
噪声方案
#制定每一步的beta
betas = torch.linspace(-6,6,num_steps).to(device)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5
注意一下很多变量都是列表。
α
t
=
1
−
β
t
\alpha_t = 1-\beta_t
αt=1−βt
alphas = 1-betas
α ‾ t = ∏ i = 0 n α i \overline{\alpha}_t = \prod_{i=0}^n \alpha_i αt=∏i=0nαi
alphas_prod = torch.cumprod(alphas, dim=0)
α t − 1 \alpha_{t-1} αt−1 在代码中为 alphas_prod_p
alphas_prod_p = torch.cat([torch.tensor([1]).float().to(device),alphas_prod[:-1]],0)
α ‾ t \sqrt{\overline{\alpha}_t} αt 在代码中为 alphas_bar_sqrt
alphas_bar_sqrt = torch.sqrt(alphas_prod)
l o g ( 1 − α ‾ t ) log(1-\overline{\alpha}_t) log(1−αt)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt对应 1 − α ‾ t \sqrt{1-\overline{\alpha}_t} 1−αt
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
前向过程采样(important)
X t = α ‾ t X 0 + 1 − α ‾ t ϵ X_t = \sqrt{\overline{\alpha}_t}X_0+\sqrt{1-\overline{\alpha}_t}~\epsilon Xt=αtX0+1−αt ϵ
#计算任意时刻的x采样值,基于x_0和重参数化
def q_x(x_0,t):
"""可以基于x[0]得到任意时刻t的x[t]"""
noise = torch.randn_like(x_0).to(device)
alphas_t = alphas_bar_sqrt[t]
alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基础上添加噪声
演示数据前向100步的结果
num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')
#共有10000个点,每个点包含两个坐标
#生成100步以内每隔5步加噪声后的图像
for i in range(num_shows):
j = i//10
k = i%10
q_i = q_x(dataset, torch.tensor([i*num_steps//num_shows]).to(device))#生成t时刻的采样数据
q_i = q_i.to('cpu')
axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')
axs[j,k].set_axis_off()
axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')
模型
import torch
import torch.nn as nn
class MLPDiffusion(nn.Module):
def __init__(self,n_steps,num_units=128):
super(MLPDiffusion,self).__init__()
self.linears = nn.ModuleList(
[
nn.Linear(2,num_units),
nn.ReLU(),
nn.Linear(num_units,num_units),
nn.ReLU(),
nn.Linear(num_units,num_units),
nn.ReLU(),
nn.Linear(num_units,2),
]
)
self.step_embeddings = nn.ModuleList(
[
nn.Embedding(n_steps,num_units),
nn.Embedding(n_steps,num_units),
nn.Embedding(n_steps,num_units),
]
)
def forward(self,x,t):
# x = x_0
for idx,embedding_layer in enumerate(self.step_embeddings):
t_embedding = embedding_layer(t)
x = self.linears[2*idx](x)
x += t_embedding
x = self.linears[2*idx+1](x)
x = self.linears[-1](x)
return x
损失函数(important)
由于DDPM中方差被设置为定值,因此这里只需要比较均值的loss, 又因为DDPM是预测噪声,因为只要比较后验的噪声和模型预测的噪声的MSE loss就可以指导模型进行训练了。
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
"""对任意时刻t进行采样计算loss"""
batch_size = x_0.shape[0]
#对一个batchsize样本生成随机的时刻t
t = torch.randint(0,n_steps,size=(batch_size//2,)).to(device)
t = torch.cat([t,n_steps-1-t],dim=0)
t = t.unsqueeze(-1)
#x0的系数
a = alphas_bar_sqrt[t]
#eps的系数
aml = one_minus_alphas_bar_sqrt[t]
#生成随机噪音eps
e = torch.randn_like(x_0).to(device)
#构造模型的输入
x = x_0 * a + e * aml
#送入模型,得到t时刻的随机噪声预测值
output = model(x,t.squeeze(-1))
#与真实噪声一起计算误差,求平均值
return (e - output).square().mean()
其中 X t = α ‾ t X 0 + 1 − α ‾ t ϵ X_t = \sqrt{\overline{\alpha}_t}X_0+\sqrt{1-\overline{\alpha}_t}~\epsilon Xt=αtX0+1−αt ϵ, 即为代码中的
x = x_0 * a + e * aml
逆过程采样(important)
p_sample_loop负责迭代式的调用p_sample, 是不断恢复图像的过程。
def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""
cur_x = torch.randn(shape).to(device)
x_seq = [cur_x]
for i in reversed(range(n_steps)):
cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)
x_seq.append(cur_x)
return x_seq
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
"""从x[T]采样t时刻的重构值"""
t = torch.tensor([t]).to(device)
coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
eps_theta = model(x,t)
mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
z = torch.randn_like(x).to(device)
sigma_t = betas[t].sqrt()
sample = mean + sigma_t * z
return (sample)
这里
X
t
−
1
=
μ
~
+
β
t
z
X_{t-1} = \widetilde{\mu} + \sqrt{\beta_t} z
Xt−1=μ
+βtz (DDPM不学习方差,方差直接设置为
β
\beta
β, 所以标准差就是
β
t
\sqrt{\beta_t}
βt, 其中 z是随机生成的噪声)
μ
~
(
X
t
,
X
0
)
=
1
α
‾
t
(
X
t
−
β
t
1
−
α
‾
t
ϵ
t
)
\widetilde{\mu}(X_t, X_0) = \frac{1}{\sqrt{\overline{\alpha}_t}}(X_t-\frac{\beta_t}{\sqrt{1-\overline{\alpha}_t}}\epsilon_t)
μ
(Xt,X0)=αt1(Xt−1−αtβtϵt) 对应下面的代码
coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
eps_theta = model(x,t)
mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
one_minus_alphas_bar_sqrt对应 1 − α ‾ t \sqrt{1-\overline{\alpha}_t} 1−αt
模型训练
seed = 1234
print('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
model = MLPDiffusion(num_steps)#输出维度是2,输入是x和step
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
for t in range(num_epoch):
for idx,batch_x in enumerate(dataloader):
loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(),1.)
optimizer.step()
if(t%100==0):
print(loss)
x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)
x_seq = [item.to('cpu') for item in x_seq]
fig,axs = plt.subplots(1,10,figsize=(28,3))
for i in range(1,11):
cur_x = x_seq[i*10].detach()
axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
axs[i-1].set_axis_off();
axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')
Training model…
tensor(0.5281, device=‘cuda:0’, grad_fn=)
tensor(0.6795, device=‘cuda:0’, grad_fn=)
tensor(0.3125, device=‘cuda:0’, grad_fn=)
tensor(0.3071, device=‘cuda:0’, grad_fn=)
tensor(0.2241, device=‘cuda:0’, grad_fn=)
tensor(0.3483, device=‘cuda:0’, grad_fn=)
tensor(0.4395, device=‘cuda:0’, grad_fn=)
tensor(0.3733, device=‘cuda:0’, grad_fn=)
tensor(0.6234, device=‘cuda:0’, grad_fn=)
tensor(0.2991, device=‘cuda:0’, grad_fn=)
tensor(0.3027, device=‘cuda:0’, grad_fn=)
tensor(0.3399, device=‘cuda:0’, grad_fn=)
tensor(0.2055, device=‘cuda:0’, grad_fn=)
tensor(0.4996, device=‘cuda:0’, grad_fn=)
tensor(0.4738, device=‘cuda:0’, grad_fn=)
tensor(0.1580, device=‘cuda:0’, grad_fn=)
…
好多个epoch之后be like:
文章来源:https://www.toymoban.com/news/detail-443220.html
动画演示
import io
from PIL import Image
imgs = []
for i in range(100):
plt.clf()
q_i = q_x(dataset,torch.tensor([i]))
plt.scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white',s=5);
plt.axis('off');
img_buf = io.BytesIO()
plt.savefig(img_buf,format='png')
img = Image.open(img_buf)
imgs.append(img)
reverse = []
for i in range(100):
plt.clf()
cur_x = x_seq[i].detach()
plt.scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white',s=5);
plt.axis('off')
img_buf = io.BytesIO()
plt.savefig(img_buf,format='png')
img = Image.open(img_buf)
reverse.append(img)
保存为gif文章来源地址https://www.toymoban.com/news/detail-443220.html
imgs = imgs + reverse
imgs[0].save("diffusion.gif",format='GIF',append_images=imgs,save_all=True,duration=100,loop=0)
到了这里,关于DDPM原理与代码剖析的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!