PyTorch中计算KL散度详解

这篇具有很好参考价值的文章主要介绍了PyTorch中计算KL散度详解。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

PyTorch计算KL散度详解

最近在进行方法设计时,需要度量分布之间的差异,由于样本间分布具有相似性,首先想到了便于实现的KL-Divergence,使用PyTorch中的内置方法时,踩了不少坑,在这里详细记录一下。

简介

首先简单介绍一下KL散度(具体的可以在各种技术博客看到讲解,我这里不做重点讨论)。
从名称可以看出来,它并不是严格意义上的距离(所以才叫做散度~),原因是它并不满足距离的对称性,为了弥补这种缺陷,出现了JS散度(这就是另一个故事了…)
我们先来看一下KL散度的形式:
D K L ( P ∣ ∣ Q ) = ∑ i = 1 N p i log ⁡ p i q i = ∑ i = 1 N p i ∗ ( log ⁡ p i − log ⁡ q i ) DKL(P||Q) = \sum_{i=1}^{N} {p_i\log{\frac{p_i}{q_i}}} = \sum_{i=1}^{N} { p_i*(\log{p_i}-\log{q_i})} DKL(PQ)=i=1Npilogqipi=i=1Npi(logpilogqi)

手动代码实现

可以看到,KL散度形式上还是比较直观的,我们先手撸一个试试:
这里我们随机设定两个随机变量P和Q

import torch
P = torch.tensor([0.4, 0.6])
Q = torch.tensor([0.3, 0.7])

快速算一下答案:
D K L ( P ∣ ∣ Q ) = 0.4 ∗ ( log ⁡ 0.4 − log ⁡ 0.3 ) + 0.6 ∗ ( log ⁡ 0.6 − log ⁡ 0.7 ) ≈ 0.0226 \begin{aligned} DKL(P||Q) &= 0.4* (\log{0.4} - \log{0.3}) + 0.6 * (\log{0.6} - \log{0.7}) \\ & \approx 0.0226 \end{aligned} DKL(PQ)=0.4(log0.4log0.3)+0.6(log0.6log0.7)0.0226

数值计算实现版:

def DKL(_p, _q):
		"""calculate the KL divergence between _p and _q
		"""
    return  torch.sum(_p * (_p.log() - _q.log()), dim=-1)

divergence = DKL(P, Q)
print(divergence)
# tensor(0.0226)

上面的代码中,之所以求和时dim=-1是因为我在使用的过程中,考虑到有时是对batch中feature进行计算,所以这里只对特征维度进行求和。
接下来,就到了今天介绍的主角~

torch代码实现

torch中提供有两种不同的api用于计算KL散度,分别是torch.nn.functional.kl_div()torch.nn.KLDivLoss(),两者计算效果类似,区别无非是直接计算和作为损失函数类。

先介绍一下torch.nn.functional.kl_div()

注意,该方法的inputtarget K L ( P ∣ ∣ Q ) KL(P||Q) KL(PQ) P P P Q Q Q的位置正好相反,从参数名称就可以看出来(target为目标分布 P P Pinput为待度量分布 Q Q Q)。为了防止指代混乱,我后面统一用 P P P Q Q Q指代targetinput
PyTorch中计算KL散度详解
这里重点关注几个对计算结果有影响的参数:

reduction:该参数是结果应该以什么规约形式进行呈现,sum即为我们定义式中的效果,batchmean:按照batch大小求平均,mean:按照元素个数进行求平均

再看看log_target的效果:

if not log_target: # default
    loss_pointwise = target * (target.log() - input)
else:
    loss_pointwise = target.exp() * (target - input)

也就是说,如果log_target=False,此时计算方式为
r e s = P ∗ ( log ⁡ P − Q ) res = P * ( \log{P}-Q) res=P(logPQ)
这和我们熟悉的定义式的计算方式是不同的,如果想要和定义式的效果一致,需要对input取对数操作(在官方文档中也有提及,建议将input映射到对数空间,防止数值下溢):

import torch.nn.Functional as F

print(F.kl_div(Q.log(), P, reduction='sum'))
#tensor(0.0226)

而当log_target=True时,此时的计算方式变为
r e s = e P ∗ ( P − Q ) res=e^{P}*(P-Q) res=eP(PQ)
也就是说,此时我们对 P P P取对数操作即可得到定义式的效果:

print(F.kl_div(Q.log(), P.log(), 
	  log_target=True, reduction='sum'))
#tensor(0.0226)

这样设计的目的也是为了防止数值下溢。

torch.nn.KLDivLoss()的参数列表与torch.nn.functional.kl_div()类似,这里就不过多赘述。

总结

总的来说,当需要计算KL散度时,默认情况下需要对input取对数,并设置reduction='sum'方能得到与定义式相同的结果:

divergence = F.kl_div(Q.log(), P, reduction='sum')

由于我们度量的是两个分布的差异,因此通常需要对输入进行softmax归一化(如果已经归一化则无需此操作):文章来源地址https://www.toymoban.com/news/detail-420867.html

divergence = F.kl_div(Q.softmax(-1).log(), P.softmax(-1), reduction='sum')

到了这里,关于PyTorch中计算KL散度详解的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • KL散度和交叉熵的对比介绍

    KL散度(Kullback-Leibler Divergence)和交叉熵(Cross Entropy)是在机器学习中广泛使用的概念。这两者都用于比较两个概率分布之间的相似性,但在一些方面,它们也有所不同。本文将对KL散度和交叉熵的详细解释和比较。 KL散度,也称为相对熵(Relative Entropy),是用来衡量两个概

    2023年04月23日
    浏览(89)
  • 信息论之从熵、惊奇到交叉熵、KL散度和互信息

    考虑将A地观测的一个随机变量x,编码后传输到B地。 这个随机变量有8种可能的状态,每个状态都是等可能的。为了把x的值传给接收者,需要传输一个3-bits的消息。注意,这个变量的熵由下式给出: ⾮均匀分布⽐均匀分布的熵要⼩。 如果概率分布非均匀,同样使用等长编码,

    2023年04月15日
    浏览(78)
  • 【扩散模型Diffusion Model系列】0-从VAE开始(隐变量模型、KL散度、最大化似然与AIGC的关系)

    VAE(Variational AutoEncoder),变分自编码器,是一种无监督学习算法,被用于压缩、特征提取和生成式任务。相比于GAN(Generative Adversarial Network),VAE在数学上有着更加良好的性质,有利于理论的分析和实现。 生成式模型(Generative Model)的目标是学习一个模型,从 一个简单的分布 p (

    2024年02月03日
    浏览(49)
  • 整层水汽通量和整层水汽通量散度计算及python绘图

    整层水汽通量和整层水汽通量散度计算及python绘图 一、公式推导 1、整层水汽通量: (1)单层水汽通量: 在P坐标下, 单层水汽通量 = q·v/g q的单位为kg/kg,v的单位为m/s。对于重力加速度g的单位要进行换算: 也就是说,重力加速度g的单位是10**-2·hPa·m**2/kg。 最终,单层水汽

    2024年02月02日
    浏览(38)
  • KL15和KL30的区别

    相信刚接触汽车电子的伙伴都会有一个疑惑,什么是KL15?什么是KL30? KL是德语Klemme的缩写,指的是ECU的管脚,可以理解为Pin的意思。 KL30 电源(也称“常电”),即蓄电池,提供 ECU 的工作电压,一般是 11V 到 15V,一般在发动机未点火的时候(对应汽车钥匙孔的 OFF档),车上少部分

    2024年02月11日
    浏览(33)
  • [学习笔记-扫盲]KL15,KL30

    KL:德语Klemme,ECU的引脚,同Pin 15,30:引脚编号: KL15 表示发动机的点火信号和 启动车辆 的信号,汽车在Run模式 KL30 表示蓄电池的正极(31为负极),为各ECU进行低压供电,通常为11V~15V,即 接通蓄电池电源 其他状态: KLR:汽车在ACC模式 KL50:汽车在crank模式 钥匙初始位置

    2024年02月11日
    浏览(41)
  • 概率分布之间的散度(Divergence)

    Divergence between distributions.

    2024年02月08日
    浏览(40)
  • 哈密顿算符梯度 散度 旋度的补充

    做一些哈密顿算符的补充 后面的是一个向量,但是单独的看这个向量没有意义 需要把这个函数和其他函数放在一起做运算的时候才有意义 梯度 散度 和 旋度 蓝色部分是一个标量的梯度 我们用 算符 乘以这个 标量 (后面的就是向量乘以标量) 我们把 f 乘进去,得到了f对于

    2024年02月06日
    浏览(33)
  • Android input输入设备与kl文件的匹配

    input设备的事件上报和系统中keyCode的对应是通过kl(keyLayout)文件来进行转换的。Android系统中预置了很多的kl文件,如果要定制input行为,我们也会添加或者修改kl文件。 一个Android设备会存在多个input设备,本文主要分析是如何为不同的input设备寻找匹配对应的kl(keyLayout)文件的。

    2024年01月22日
    浏览(43)
  • 【机器学习】Kullback-Leibler散度实现数据监控

    https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence         本篇叙述了KL 散度的数学、直觉和如何实际使用;以及它如何最好地用于过程监测。Kullback-Leibler 散度度量(相对熵)是信息论中的一种统计测量方法,通常用于量化一个概率分布与参考概率分布之间的差异。   

    2024年02月09日
    浏览(35)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包