Introduction
In previous posts we discussed the generative models VAE and GAN. In recent years a newer family of generative models, diffusion models, has attracted increasing attention and is well worth studying. Diffusion models were first proposed by [2] in 2015, but only gained wide attention after [3] was published in 2020. This post works through the derivations in [3] in detail, to help clarify the underlying mathematics.
Mathematical Model
As illustrated in the figure below (adapted from [1]), $x_0$ is the original data and $q$ is the diffusion (forward) process: each step adds a small amount of noise on top of the previous state. As $t \to \infty$, $x_T$ is completely swamped by noise and becomes an isotropic Gaussian, $x_T \sim \mathcal{N}(0, I)$. $p_\theta$ is the generative (reverse) process, approximated by a network with parameters $\theta$, which recovers the signal from the noise. The whole model satisfies the Markov property.
Objective Function
This post proceeds top-down: we first state the final objective, then dig into each of its pieces. Intuitively, we want the approximate distribution $p_\theta(x_0)$ to be as close as possible to the true data distribution $q(x_0)$, so the objective can be written as a cross-entropy:
$$
\begin{aligned}
\mathcal{L}_{CE} &= -\mathbb{E}_{q(x_0)}\log p_\theta(x_0)\\
&= -\mathbb{E}_{q(x_0)}\log\Big[\int p_\theta(x_{0:T})\,dx_{1:T}\Big]\\
&= -\mathbb{E}_{q(x_0)}\log\Big[\int q(x_{1:T}|x_0)\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\,dx_{1:T}\Big]\\
&= -\mathbb{E}_{q(x_0)}\log\Big[\mathbb{E}_{q(x_{1:T}|x_0)}\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\Big]\\
&\le -\mathbb{E}_{q(x_{0:T})}\log\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)} \quad \text{(Jensen's inequality)}
\end{aligned}
$$
This gives:
$$
\mathbb{E}_{q(x_0)}\log p_\theta(x_0) \ge \mathbb{E}_{q(x_{0:T})}\log\frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}
$$
The right-hand side is a lower bound on the log-likelihood. For convenience we work with its negative, $\mathcal{L}_{LB} = \mathbb{E}_{q(x_{0:T})}\log\frac{q(x_{1:T}|x_0)}{p_\theta(x_{0:T})}$, which upper-bounds the cross-entropy: making $\mathcal{L}_{LB}$ smaller makes the log-likelihood larger and the cross-entropy smaller.
The Markov Assumption
To make $\mathcal{L}_{LB}$ easier to optimize we need some background. As noted above, the model satisfies the Markov property: the current state $x_t$ depends only on the previous state $x_{t-1}$. For a Markov chain $A \to B \to C$, we have:
$$
\begin{aligned}
p(B,C|A) &= \frac{p(A,B,C)}{p(A)} = \frac{p(A,B)}{p(A)}\cdot\frac{p(A,B,C)}{p(A,B)} = p(B|A)\,p(C|A,B)\\
&= p(B|A)\,p(C|B) \tag{1}
\end{aligned}
$$
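As a quick sanity check of equation (1), we can build a tiny discrete Markov chain $A \to B \to C$ (the binary states and transition tables below are made up purely for illustration) and verify numerically that $p(B,C|A)=p(B|A)\,p(C|B)$:

```python
import itertools

# Hypothetical toy chain A -> B -> C over binary states.
p_A = [0.3, 0.7]                     # p(A)
p_B_given_A = [[0.9, 0.1],           # p(B|A): rows indexed by A
               [0.2, 0.8]]
p_C_given_B = [[0.6, 0.4],           # p(C|B): rows indexed by B
               [0.5, 0.5]]

def p_joint(a, b, c):
    """Joint p(A,B,C) under the Markov factorization of the chain."""
    return p_A[a] * p_B_given_A[a][b] * p_C_given_B[b][c]

for a, b, c in itertools.product(range(2), repeat=3):
    # Left-hand side: p(B,C|A) = p(A,B,C) / p(A)
    lhs = p_joint(a, b, c) / p_A[a]
    # Right-hand side of equation (1): p(B|A) * p(C|B)
    rhs = p_B_given_A[a][b] * p_C_given_B[b][c]
    assert abs(lhs - rhs) < 1e-12
print("equation (1) holds on the toy chain")
```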
Using equation (1), we obtain:
$$
q(x_{1:T}|x_0)=\prod^T_{t=1}q(x_t|x_{t-1}), \qquad p_\theta(x_{0:T})=p_\theta(x_T)\prod^T_{t=1}p_\theta(x_{t-1}|x_t) \tag{2}
$$
Substituting equation (2) into $\mathcal{L}_{LB}$, and using the Markov property $q(x_t|x_{t-1}) = q(x_t|x_{t-1},x_0)$ together with Bayes' rule $q(x_t|x_{t-1},x_0) = \frac{q(x_{t-1}|x_t,x_0)\,q(x_t|x_0)}{q(x_{t-1}|x_0)}$ in the fifth line below, we get:
$$
\begin{aligned}
\mathcal{L}_{LB} &= \mathbb{E}_{q(x_{0:T})}\Big[\log\frac{q(x_{1:T}|x_0)}{p_\theta(x_{0:T})}\Big]\\
&= \mathbb{E}_{q}\Big[\log\frac{\prod^T_{t=1}q(x_t|x_{t-1})}{p_\theta(x_T)\prod^T_{t=1}p_\theta(x_{t-1}|x_t)}\Big]\\
&= \mathbb{E}_{q}\Big[-\log p_\theta(x_T)+\sum^T_{t=1}\log\frac{q(x_t|x_{t-1})}{p_\theta(x_{t-1}|x_t)}\Big]\\
&= \mathbb{E}_{q}\Big[-\log p_\theta(x_T)+\sum^T_{t=2}\log\frac{q(x_t|x_{t-1})}{p_\theta(x_{t-1}|x_t)}+\log\frac{q(x_1|x_0)}{p_\theta(x_0|x_1)}\Big]\\
&= \mathbb{E}_{q}\Big[-\log p_\theta(x_T)+\sum^T_{t=2}\log\Big(\frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)} \cdot \frac{q(x_t|x_0)}{q(x_{t-1}|x_0)}\Big)+\log\frac{q(x_1|x_0)}{p_\theta(x_0|x_1)}\Big]\\
&= \mathbb{E}_{q}\Big[-\log p_\theta(x_T)+\sum^T_{t=2}\log\frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}+\sum^T_{t=2}\log\frac{q(x_t|x_0)}{q(x_{t-1}|x_0)}+\log\frac{q(x_1|x_0)}{p_\theta(x_0|x_1)}\Big]\\
&= \mathbb{E}_{q}\Big[-\log p_\theta(x_T)+\sum^T_{t=2}\log\frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}+\log\frac{q(x_T|x_0)}{q(x_1|x_0)}+\log\frac{q(x_1|x_0)}{p_\theta(x_0|x_1)}\Big]\\
&= \mathbb{E}_{q}\Big[\log\frac{q(x_T|x_0)}{p_\theta(x_T)}+\sum^T_{t=2}\log\frac{q(x_{t-1}|x_t,x_0)}{p_\theta(x_{t-1}|x_t)}-\log p_\theta(x_0|x_1)\Big]
\end{aligned} \tag{3}
$$
Writing the last line of equation (3) in KL-divergence form:
$$
\mathbb{E}_{q(x_0)}D_{KL}[q(x_T|x_0)\,\|\,p_\theta(x_T)]
+\sum^T_{t=2}\mathbb{E}_{q(x_0,x_t)}D_{KL}[q(x_{t-1}|x_t,x_0)\,\|\,p_\theta(x_{t-1}|x_t)]
-\mathbb{E}_{q(x_0,x_1)}\log p_\theta(x_0|x_1) \tag{4}
$$
The first term of equation (4) corresponds to the regularization loss $D_{KL}(q_\phi(z|x)\,\|\,p_\theta(z))$ in a VAE, and the third term corresponds to the reconstruction loss $\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$. The second term is a sum of KL divergences, each measuring the distance between the reverse-step distribution of $p_\theta$ and the posterior of $q$ conditioned on $x_0$.
Reparameterization
The reparameterization trick makes the $x_t$ in equation (4) tractable and further simplifies $\mathcal{L}_{LB}$. Given real data $x_0 \sim q(x)$, each step of the diffusion process can be written as:
$$
q(x_t|x_{t-1})=\mathcal{N}(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I) \tag{5}
$$
where $\beta_t$ is a hyperparameter. With the reparameterization trick, $x_t$ at any time step $t$ can be computed directly from $x_0$, with no step-by-step iteration. Let $\alpha_t = 1-\beta_t$, $\overline{\alpha}_t = \prod^t_{i=1}\alpha_i$, and $\epsilon_t \sim \mathcal{N}(0, I)$; then:
$$
\begin{aligned}
x_t &= \sqrt{\alpha_t}\,x_{t-1}+\sqrt{1-\alpha_t}\,\epsilon_{t-1}\\
&= \sqrt{\alpha_t}\big[\sqrt{\alpha_{t-1}}\,x_{t-2}+\sqrt{1-\alpha_{t-1}}\,\epsilon_{t-2}\big]+\sqrt{1-\alpha_t}\,\epsilon_{t-1}\\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2}+\sqrt{\alpha_t-\alpha_t\alpha_{t-1}}\,\epsilon_{t-2}+\sqrt{1-\alpha_t}\,\epsilon_{t-1}\\
&= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\,\overline{\epsilon}_{t-2}\\
&= \cdots\\
&= \sqrt{\overline{\alpha}_t}\,x_0+\sqrt{1-\overline{\alpha}_t}\,\epsilon\\
\therefore\ q(x_t|x_0) &= \mathcal{N}(x_t;\ \sqrt{\overline{\alpha}_t}\,x_0,\ (1-\overline{\alpha}_t)I)
\end{aligned} \tag{6}
$$
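Equation (6) can be checked numerically: starting from a deterministic $x_0$ and applying the one-step mean/variance update of equation (5) $t$ times must reproduce the closed-form mean $\sqrt{\overline{\alpha}_t}\,x_0$ and variance $1-\overline{\alpha}_t$. A minimal sketch (the linear $\beta_t$ schedule below is an assumption, loosely following the setup in [3]):

```python
import math

T = 1000
# Assumed linear beta schedule from 1e-4 to 0.02, as in [3].
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]

x0 = 1.5              # a scalar "data point"
mean, var = x0, 0.0   # x_0 is deterministic: mean x0, variance 0
alpha_bar = 1.0
for t in range(T):
    # One step of q(x_t | x_{t-1}) from equation (5):
    # mean scales by sqrt(alpha_t); variance scales by alpha_t and gains beta_t.
    mean = math.sqrt(alphas[t]) * mean
    var = alphas[t] * var + betas[t]
    alpha_bar *= alphas[t]
    # Closed form from equation (6): q(x_t|x_0) = N(sqrt(ab)*x0, 1-ab)
    assert abs(mean - math.sqrt(alpha_bar) * x0) < 1e-9
    assert abs(var - (1.0 - alpha_bar)) < 1e-9

# After T steps the mean is essentially 0 and the variance essentially 1,
# i.e. x_T ~ N(0, I), matching the claim in the model section.
print(f"after T={T} steps: mean={mean:.2e}, var={var:.6f}")
```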
Here $\sqrt{\alpha_t-\alpha_t\alpha_{t-1}}\,\epsilon_{t-2}+\sqrt{1-\alpha_t}\,\epsilon_{t-1} \sim \mathcal{N}(0,\ (1-\alpha_t\alpha_{t-1})I)$, since the sum of two independent zero-mean Gaussians is Gaussian with the variances added; this merges the two noise terms into a single $\overline{\epsilon}_{t-2}$. Next, we can derive a closed form for $q(x_{t-1}|x_t,x_0)$:
$$
\begin{aligned}
q(x_{t-1}|x_t,x_0) &= q(x_t|x_{t-1},x_0)\,\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)}\\
&\propto \exp\Big[-\frac{1}{2}\Big(\frac{(x_t-\sqrt{\alpha_t}\,x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\overline{\alpha}_{t-1}}\,x_0)^2}{1-\overline{\alpha}_{t-1}}-\frac{(x_t-\sqrt{\overline{\alpha}_t}\,x_0)^2}{1-\overline{\alpha}_t}\Big)\Big]\\
&= \exp\Big[-\frac{1}{2}\Big(\Big(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\overline{\alpha}_{t-1}}\Big)x^2_{t-1}-\Big(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\overline{\alpha}_{t-1}}}{1-\overline{\alpha}_{t-1}}x_0\Big)x_{t-1}+C(x_t,x_0)\Big)\Big]
\end{aligned} \tag{7}
$$
Recall that $ax^2+bx$ can be rewritten as $a(x+\frac{b}{2a})^2+c$, so for a Gaussian whose exponent has the form $-\frac{1}{2}(ax^2+bx)$, the mean and variance are $\mu=-\frac{b}{2a}$ and $\sigma^2=\frac{1}{a}$. Writing $q(x_{t-1}|x_t,x_0)=\mathcal{N}(x_{t-1};\ \tilde{\mu}(x_t,x_0),\ \tilde{\beta}_t I)$ and matching terms in equation (7) gives:
$$
\begin{aligned}
\tilde{\beta}_t &= \frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t}\cdot\beta_t\\
\tilde{\mu}_t(x_t,x_0) &= \frac{\sqrt{\alpha_t}(1-\overline{\alpha}_{t-1})}{1-\overline{\alpha}_t}x_t+\frac{\sqrt{\overline{\alpha}_{t-1}}\,\beta_t}{1-\overline{\alpha}_t}x_0
\end{aligned} \tag{8}
$$
From equation (6), substituting $x_0=\frac{1}{\sqrt{\overline{\alpha}_t}}(x_t-\sqrt{1-\overline{\alpha}_t}\,\epsilon_t)$ into equation (8) simplifies $\tilde{\mu}_t(x_t,x_0)$ further:
$$
\tilde{\mu}_t(x_t,x_0)=\frac{1}{\sqrt{\alpha_t}}\Big(x_t-\frac{1-\alpha_t}{\sqrt{1-\overline{\alpha}_t}}\epsilon_t\Big) \tag{9}
$$
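Equations (8) and (9) are two expressions for the same posterior mean, one in terms of $(x_t, x_0)$ and one in terms of $(x_t, \epsilon_t)$. The substitution is easy to verify numerically (the three-step schedule below is an arbitrary illustration, not from [3]):

```python
import math
import random

random.seed(0)

# Arbitrary small schedule for illustration.
betas = [0.1, 0.2, 0.3]
alphas = [1.0 - b for b in betas]
t = 2                                   # check the last step (index 2)
alpha_bar_t = alphas[0] * alphas[1] * alphas[2]
alpha_bar_prev = alphas[0] * alphas[1]  # \bar{alpha}_{t-1}

x0 = 0.7
eps = random.gauss(0.0, 1.0)
# Forward sample via equation (6): x_t = sqrt(ab_t)*x0 + sqrt(1-ab_t)*eps
xt = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

# Posterior mean via equation (8), in terms of (x_t, x_0):
mu_8 = (math.sqrt(alphas[t]) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * xt
        + math.sqrt(alpha_bar_prev) * betas[t] / (1.0 - alpha_bar_t) * x0)

# Posterior mean via equation (9), in terms of (x_t, eps):
mu_9 = (xt - (1.0 - alphas[t]) / math.sqrt(1.0 - alpha_bar_t) * eps) / math.sqrt(alphas[t])

assert abs(mu_8 - mu_9) < 1e-12
print("equations (8) and (9) agree:", mu_8)
```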
KL Divergence between Gaussians
The reparameterization trick lets us compute $x_t$ and $q(x_{t-1}|x_t,x_0)$ analytically. In addition, when the distributions involved are Gaussian, the KL divergence itself has a closed form. For one-dimensional Gaussians $a \sim \mathcal{N}(\mu_1,\sigma^2_1)$ and $b \sim \mathcal{N}(\mu_2,\sigma^2_2)$, the KL divergence is (the multivariate case is analogous):
$$
\begin{aligned}
D_{KL}(a\,\|\,b) &= \mathbb{E}_a\log\frac{a}{b}\\
&= \mathbb{E}_a\Big[\log\frac{\sigma_2}{\sigma_1}+\frac{(x-\mu_2)^2}{2\sigma^2_2}-\frac{(x-\mu_1)^2}{2\sigma^2_1}\Big]\\
&= \log\frac{\sigma_2}{\sigma_1}+\frac{1}{2}\mathbb{E}_a\Big[\frac{(x-\mu_2)^2}{\sigma^2_2}-\frac{(x-\mu_1)^2}{\sigma^2_1}\Big]\\
&= \log\frac{\sigma_2}{\sigma_1}+\frac{1}{2}\mathbb{E}_a\Big[\Big(\frac{1}{\sigma^2_2}-\frac{1}{\sigma^2_1}\Big)x^2+\Big(\frac{2\mu_1}{\sigma^2_1}-\frac{2\mu_2}{\sigma^2_2}\Big)x+\frac{\mu^2_2}{\sigma^2_2}-\frac{\mu^2_1}{\sigma^2_1}\Big]\\
&= \log\frac{\sigma_2}{\sigma_1}+\frac{\sigma^2_1+(\mu_1-\mu_2)^2}{2\sigma^2_2}-\frac{1}{2}
\end{aligned} \tag{10}
$$
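The closed form in equation (10) is easy to wrap in a small helper and sanity-check against known values (identical Gaussians give $0$; $\mathcal{N}(1,1)$ against $\mathcal{N}(0,1)$ gives $\tfrac{1}{2}$):

```python
import math

def gaussian_kl(mu1, var1, mu2, var2):
    """D_KL( N(mu1, var1) || N(mu2, var2) ) from equation (10).

    Note log(sigma2/sigma1) = 0.5 * log(var2/var1).
    """
    return (0.5 * math.log(var2 / var1)
            + (var1 + (mu1 - mu2) ** 2) / (2.0 * var2)
            - 0.5)

# Identical distributions: divergence is zero.
assert abs(gaussian_kl(0.0, 1.0, 0.0, 1.0)) < 1e-12
# Shifting the mean by 1 against a unit Gaussian: KL = 1/2.
assert abs(gaussian_kl(1.0, 1.0, 0.0, 1.0) - 0.5) < 1e-12
# KL is asymmetric in general.
assert gaussian_kl(0.0, 2.0, 0.0, 1.0) != gaussian_kl(0.0, 1.0, 0.0, 2.0)
print("closed-form Gaussian KL checks pass")
```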
Optimizing the Objective
The first term of equation (4) is a constant (it contains no trainable parameters), and the third term can be viewed as the $t=1$ case of the second, so we focus on the second term. Let $p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\ \mu_\theta(x_t,t),\ \sigma^2_t I)$. By equation (10), the second term simplifies to:
$$
L_{t-1}=\mathbb{E}_q\Big[\frac{1}{2\sigma^2_t}\big\|\tilde{\mu}_t(x_t,x_0)-\mu_\theta(x_t,t)\big\|^2\Big]+C \tag{11}
$$
Substituting equation (9) into equation (11) gives:
$$
L_{t-1}=\mathbb{E}_{x_t,\epsilon}\Big[\frac{1}{2\sigma^2_t}\Big\|\frac{1}{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\overline{\alpha}_t}}\epsilon\Big)-\mu_\theta(x_t,t)\Big\|^2\Big] \tag{12}
$$
The authors of [3] parameterize $\mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\Big(x_t-\frac{\beta_t}{\sqrt{1-\overline{\alpha}_t}}\epsilon_\theta(x_t,t)\Big)$, which means the network fits the noise at each time step (I have not yet fully understood why this particular parameterization is chosen). Substituting this and equation (6), equation (12) simplifies further to:
$$
L_{t-1}=\mathbb{E}_{x_0,\epsilon}\Big[\frac{\beta^2_t}{2\sigma^2_t\alpha_t(1-\overline{\alpha}_t)}\big\|\epsilon-\epsilon_\theta(\sqrt{\overline{\alpha}_t}\,x_0+\sqrt{1-\overline{\alpha}_t}\,\epsilon,\ t)\big\|^2\Big]
$$
The authors of [3] found that dropping the leading coefficient makes training more stable, yielding the final loss:
$$
L_{simple}=\mathbb{E}_{t,x_0,\epsilon}\Big[\big\|\epsilon-\epsilon_\theta(\sqrt{\overline{\alpha}_t}\,x_0+\sqrt{1-\overline{\alpha}_t}\,\epsilon,\ t)\big\|^2\Big]
$$
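Putting the pieces together, one training step of $L_{simple}$ amounts to: sample $t$, sample $\epsilon$, form $x_t$ via equation (6), and regress the network output onto $\epsilon$. A minimal sketch with a stand-in "network" (`zero_net` and `oracle` below are placeholders for illustration, not real models; the linear $\beta_t$ schedule is assumed):

```python
import math
import random

random.seed(0)

T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars = []
ab = 1.0
for b in betas:
    ab *= 1.0 - b
    alpha_bars.append(ab)   # running product \bar{alpha}_t

def simple_loss(x0, eps_model, t, eps):
    """One Monte-Carlo term of L_simple = ||eps - eps_theta(x_t, t)||^2."""
    ab_t = alpha_bars[t]
    # Equation (6): x_t = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps
    x_t = [math.sqrt(ab_t) * a + math.sqrt(1.0 - ab_t) * e
           for a, e in zip(x0, eps)]
    pred = eps_model(x_t, t)
    return sum((e - p) ** 2 for e, p in zip(eps, pred))

x0 = [0.5, -1.0, 2.0]
t = random.randrange(T)
eps = [random.gauss(0.0, 1.0) for _ in x0]

# A placeholder network that always predicts zero noise: positive loss.
zero_net = lambda x_t, t: [0.0] * len(x_t)
# An oracle returning the true noise (the regression target): zero loss.
oracle = lambda x_t, t: eps

assert simple_loss(x0, zero_net, t, eps) > 0.0
assert simple_loss(x0, oracle, t, eps) < 1e-12
print("L_simple with zero predictor:", simple_loss(x0, zero_net, t, eps))
```

Training a real model repeats this step over minibatches with a neural $\epsilon_\theta$, averaging the squared error over random $t$, $x_0$, and $\epsilon$ exactly as the expectation in $L_{simple}$ prescribes.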
References
[1] Understanding Diffusion Models: A Unified Perspective
[2] Deep Unsupervised Learning using Nonequilibrium Thermodynamics
[3] Denoising Diffusion Probabilistic Models
[4] What are Diffusion Models?
[5] Probabilistic Diffusion Model 概率扩散模型理论