考虑一个广义的Softmax函数,q的logits为其中是温度,这是从统计力学中的玻尔兹曼分布中借用的概念。容易证明,当温度趋向0时,softmax输出将收敛为one-hot向量,温度趋向无穷时,输出更「软」。
因此,在知识蒸馏中,训练新模型的时候,可以使用较高的T,使得softmax产生的分布足够软,这时让新模型的softmax输出近似原模型;在训练结束以后再使用正常的温度来预测。记新模型产生的分布为q,原模型产生的分布为p,p的logits为(下面的推导只需要把T设为1,p设为one-hot向量,就是平时用数据集从头训练模型时的softmax+交叉熵得到的损失函数)
需要最小化的损失函数为C:
下面求C关于z的偏导数,由链式法则得:
p是原模型产生的softmax输出,与q无关。
$\frac{\partial C}{\partial q}$是一个n维向量:
$\frac{\partial q}{\partial z}$是一个$n \times n$的方阵,记$Z=\sum_{k} \exp \left(z_{k} / T\right)$,可以求得$q_{i}$关于$z_{j}$的偏导为:
右侧方框部分可以展开为
代入上式将括号展开,可以得到:
左侧方框分类讨论得:
代入上式得:
所以$
\partial q / \partial z
$等于:
这里就是为什么softmax函数对其输入的偏导是下列形式的原因,$g(·)$函数为softmax函数,$x$为输入向量,维度为$d$。
回到我们的问题,继续推导,可以得到:
所以: