论文:[1503.02531] Distilling the Knowledge in a Neural Network (arxiv.org)
知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方式,由于其简单,有效,并且已经在工业界被广泛应用。
知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:
①原始模型训练: 训练"Teacher模型", 简称为Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对"Teacher模型"不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。
②精简模型训练: 训练"Student模型", 简称为Net-S,它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。在本论文中,作者将问题限定在分类问题下,或者其他本质上属于分类问题的问题,该类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。
现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解。
而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。一个很直白且高效的迁移泛化能力的方法就是使用softmax层输出的类别的概率来作为“soft target”。
①传统training过程(hard targets): 对ground truth求极大似然
②KD的training过程(soft targets): 用large model的class probabilities作为soft targets
例子:
在MNIST手写数字识别任务中
假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率为0.1,而其他负标签对应的值都很小,而另一个"2"更加形似"7","7"对应的概率为0.1。这两个"2"对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。
两个”2“的hard target相同而soft target不同。
这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。
温度T
把其他类别的可能性放大,把他们的相对大小充分暴露出来,让学生网络更加强烈地知道这些非类别的信息。当T=1时,与之前没有变化;当T越大,曲线的波峰就会越来越平滑。
知识蒸馏的过程:
第一步:有一个已经训练好的Teacher model,把很多数据喂给Teacher model,再把数据喂给(未训练/半成品)Student model,两个都是在T=t时经过Softmax,然后计算这两个的损失函数值,让它们两个越接近越好,学生在模拟老师的预测结果。
第二步:Student model在T=1情况下经过softmax操作,把预测结果hard prediction和真实数据的结果hard label进行求损失值,希望它们两个越接近越好。
总结:Student model(T=t)与Teacher model(T=t)的预测结果越来越接近;Student model(T=1)的预测结果与数据结果(标准答案)越来越接近。
Loss = k1*distillation Loss+k2*student Loss。(加权求和)
在使用Student model时只需要输入数据就行,不需要T,因为模型的参数已经训练完成了,最后只需要经过基础softmax操作得到最终结果。
实验结果:
使用MNIST数据集训练Teacher model,把MNIST数据集中去除”3“相关的所有数据集来训练Student model,实验结果证明,经过知识蒸馏后,没有学习过”3“的Student model可以识别出”3“。
Soft targets可以仅仅使用3%的训练集来训练并达到近似Teacher model的效果。
知识蒸馏的应用场景:
①模型压缩
②优化训练,防止过拟合
③无限大、无监督数据集的数据挖掘文章来源:https://www.toymoban.com/news/detail-740747.html
④少样本、零样本学习文章来源地址https://www.toymoban.com/news/detail-740747.html
到了这里,关于知识蒸馏(Knowledge Distillation)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!