Loading [MathJax]/jax/output/HTML-CSS/jax.js

Softmax + 交叉熵

考虑一个广义的Softmax函数,q的logits为zi其中T是温度,这是从统计力学中的玻尔兹曼分布中借用的概念。容易证明,当温度T趋向0时,softmax输出将收敛为one-hot向量,温度T趋向无穷时,输出更「软」。

qi=exp(zi/T)jexp(zj/T)

因此,在知识蒸馏中,训练新模型的时候,可以使用较高的T,使得softmax产生的分布足够软,这时让新模型的softmax输出近似原模型;在训练结束以后再使用正常的温度T=1来预测。记新模型产生的分布为q,原模型产生的分布为p,p的logits为vi(下面的推导只需要把T设为1,p设为one-hot向量,就是平时用数据集从头训练模型时的softmax+交叉熵得到的损失函数)

需要最小化的损失函数为C:

C=plogq

下面求C关于z的偏导数,由链式法则得:

Cz=qzCq

p是原模型产生的softmax输出,与q无关。

Cqi=piqi

Cq是一个n维向量:

Cq=[p1q1p2q2pnqn]

qz是一个n×n的方阵,记Z=kexp(zk/T),可以求得qi关于zj的偏导为:

qizj=1Z2(Zexp(zi/T)zjexp(zi/T)[Zzj])

右侧方框部分可以展开为

Zzj=1Texp(zj/T)

代入上式将括号展开,可以得到:

qizj=1Zexp(zi/T)zj1TZ2exp(zi/T)exp(zj/T)=1Zexp(zi/T)zj1Texp(zi/T)Zexp(zj/T)Z=1Z[exp(zi/T)zj]1Tqiqj

左侧方框分类讨论得:

exp(zi/T)zj={1Texp(zi/T), if i=j0, if ij

代入上式得:

qizj={1T(exp(zi/T)Zqiqj), if i=j1Tqiqj, if ij1T(qiqiqj), if i=j1Tqiqj, if ij

所以q/z等于:

qz=1T[q1q21q1q2q1qnq2q1q2q22q2qnqnq1qnq2qnq2n]

这里就是为什么softmax函数对其输入的偏导是下列形式的原因,g(·)函数为softmax函数,x为输入向量,维度为d

g(x)x=diag(ˆy)ˆyˆyRd×dg(x)x=[ˆy1000ˆy2000ˆyd][ˆy21ˆy1ˆy2ˆy1ˆydˆy2ˆy1ˆy22ˆy2ˆydˆydˆy1ˆydˆy2ˆy2d]

回到我们的问题,继续推导,可以得到:

Cz=1T[q1q21q1q2q1qnq2q1q2q22q2qnqnq1qnq2qnq2n][p1q1p2q2pnqn]=1T[p1+kpkq1p2+kpkq2pn+kpkqn]=1T[p1+q1p2+q2pn+qn]=1T(qp)

所以:

Czi=1T(qipi)

参考链接:https://zhuanlan.zhihu.com/p/90049906

11FjxS.jpg

Powered By Valine
v1.5.2