VAE损失函数的推导及实现

这篇具有很好参考价值的文章主要介绍了VAE损失函数的推导及实现。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

VAE损失函数的推导

VAE最原始的优化目标

我们从解码器的角度来引出VAE的优化目标,即传入一个变量z,我们期待解码器能生成我们所期望生成的数据。

我们举个简单的例子来说明一下:假设在我们当前的任务下解码器的目标是根据输入的z来生成一张手写数字图片。当我们传入z之后,解码器的输出可能是各种各样的,但我们希望解码器能生成手写数字图片,而不是生成一个汉字或者是其他奇奇怪怪的符号,而这就是VAE的最原始的优化目标。

我们使用p代表解码器,p(x|z)代表给定z时解码器产生x的概率,其中x并非一个具体的值,而可以看作是一类数据,比如在我们上述的例子中,x可以代表某种风格的手写体数字,p(x|z)就是生成这些数字的概率,这里的概率也并非一个具体的值,而是某一风格的每个数字对应了一个概率,其输出的是一个概率分布。

当我们明白了这些时,我们就可以写出来VAE的优化目标,即最大化解码器输出x的概率,即最大化p(x)。

损失函数推导前的准备

我们可以将p(x)其改写为包含了传入参数的形式,即
vae损失函数,概率论,深度学习,python
当我们将z从离散分布变为连续分布时,该式就变成了
vae损失函数,概率论,深度学习,python
这里的p(z)可以是任意分布,在VAE中我们常常假设p(z)服从标准正态分布。

我们同时也需要知道KL散度的一些相关知识:KL散度用于衡量两个分布之间的差异,其值越大则两个分布的差异越大,同时两个分布的KL散度非负。计算a、b两个分布的KL散度的公式如下
vae损失函数,概率论,深度学习,python

损失函数的推导其一

为了最大化p(x),我们可以采用极大似然估计的方法来进行,即最大化vae损失函数,概率论,深度学习,python
对应于我们之前给的例子,这里的每个x可以代表了某一个风格的手写体,我们的目标是生成手写体数字,因此我们并不会局限其风格,只要生成的正确就要最大化其概率。

由于最大化L即相当于最大化log p(x),因此后续目标调整为最大化log p(x)。我们假设q代表了编码器,q(z|x)就代表了给定x时编码器产生z的概率。由于
vae损失函数,概率论,深度学习,python
即不管给定何种x,其产生不同z的概率之和恒为1。又因为p(x)与z无关,因此我们可以将log p(x)改写为如下的形式。
vae损失函数,概率论,深度学习,python
由于p(x) = p(x, z) / p(z|x) = (p(x, z) / q(z|x)) * (q(z|x) / p(z|x))

其中第一次变化使用了概率论的定理,第二次变化仅仅加入了一个中间项,可以直接约分掉,并不影响结果。

此时我们可以将log p(x)写为如下形式。
vae损失函数,概率论,深度学习,python
我们将log里的乘积拆开,变为两项之和,即
vae损失函数,概率论,深度学习,python
结合之前提到过的KL散度相关的知识,我们可以看出第二项其实就是KL(q(z|x) || p(z|x))。因为该值为非负项,所以log p(x)不可能小于第一项,我们使用Lb来指代第一项,从而便于书写。

结合我们在准备阶段所提到的
vae损失函数,概率论,深度学习,python
我们可以知道,当p(x|z)不变时,p(x)也不变,从而log p(x)也不变,那么Lb+KL(q(z|x) || p(z|x))的值就不会变。这时如果我们利用q(z|x)来最大化Lb,那么Lb就会增大,而KL(q(z|x) || p(z|x))的值就会减小。

那么如果q(z|x)不变呢?此时当我们增大p(x|z)时,Lb会增大且p(x)会增大,即log p(x)也会增大。

由此我们可以得出结论,只要我们最大化Lb就能使log p(x)最大化。

损失函数的推导其二

此时我们的目标变为了最大化Lb。
由于p(x,z)=p(z)*p(x|z),我们将Lb中的p(x,z)替换为p(z)*p(x|z),并将其从log里的拆开,可以得到如下结果
vae损失函数,概率论,深度学习,python
我们可以看出Lb的第一项为-KL(q(z|x) || p(z)),即q(z|x)与p(z)两个分布之间的Kl散度的相反数。Lb的第二项可以看作是在q(z|x)这个分布下log p(x|z)的期望,即vae损失函数,概率论,深度学习,python
此时VAE的最终目标就一目了然了,VAE的训练目标有两个:
第一,最小化KL(q(z|x) || p(z)),使q(z|x)的分布尽量向p(z)靠近。
第二,最大化在q(z|x)这个分布下log p(x|z)的期望,其中q(z|x)为编码器输入x时产生z的概率。假设解码器利用z生成出了x’,我们就需要使x’尽可能向x靠近,以最大化log p(x|z)。

实际使用时所用到的损失函数

根据上述的两个训练目标,VAE的损失函数也被设计为两个:

  1. L1用于最小化KL(q(z|x) || p(z)),VAE假设q(z|x)的分布为正态分布,而p(z)为标准正态分布。计算两个正态分布之间的KL散度的公式如下:
    vae损失函数,概率论,深度学习,python
    由于此处p(z)为标准正态分布,因此其μ为0,σ为1,那么我们带入后可得
    vae损失函数,概率论,深度学习,python
    其中σ为q(z|x)的标准差,μ为q(z|x)的均值。

实际实现时,当编码器接收到x时,我们并不会让编码器直接输出对应的z,而是会使编码器输出z的分布的均值和标准差,此时我们就可以使用上述的式子作为损失函数,从而更新编码器参数。

此时我们得到了第一个损失函数。
vae损失函数,概率论,深度学习,python
在训练解码器时,我们会从标准正态分布中随机取样,使其乘上上述得到的方差,之后使其加上上述的均值,以此来构建解码器的输入,这样做相当于是给输入加上了噪音,使得解码器的稳定性更好。vae损失函数,概率论,深度学习,python
2. L2使解码器输出的x’尽可能向x靠近,要做到这个,我们只需要最小化x’和x之间的均方误差即可,即
vae损失函数,概率论,深度学习,python文章来源地址https://www.toymoban.com/news/detail-792360.html

损失函数的代码实现

def loss_function(recon, x, mu, std) -> torch.Tensor:
    """
    :param recon: output of the decoder
    :param x: encoder input
    :param mu: mean
    :param std: standard deviation
    :return:
    """
    recon_loss = torch.nn.functional.mse_loss(recon, x, reduction="sum")
    kl_loss = -0.5 * (1 + 2 * torch.log(std) - mu.pow(2) - std.pow(2))
    kl_loss = torch.sum(kl_loss)
    loss = recon_loss + kl_loss
    return loss

到了这里,关于VAE损失函数的推导及实现的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包