考虑一个广义的Softmax函数,q的logits为zi其中T是温度,这是从统计力学中的玻尔兹曼分布中借用的概念。容易证明,当温度T趋向0时,softmax输出将收敛为one-hot向量,温度T趋向无穷时,输出更「软」。
qi=exp(zi/T)∑jexp(zj/T)因此,在知识蒸馏中,训练新模型的时候,可以使用较高的T,使得softmax产生的分布足够软,这时让新模型的softmax输出近似原模型;在训练结束以后再使用正常的温度T=1来预测。记新模型产生的分布为q,原模型产生的分布为p,p的logits为vi(下面的推导只需要把T设为1,p设为one-hot向量,就是平时用数据集从头训练模型时的softmax+交叉熵得到的损失函数)
需要最小化的损失函数为C:
C=−p⊤logq下面求C关于z的偏导数,由链式法则得:
∂C∂z=∂q∂z∂C∂qp是原模型产生的softmax输出,与q无关。
∂C∂qi=−piqi∂C∂q是一个n维向量:
∂C∂q=[−p1q1−p2q2⋮−pnqn]∂q∂z是一个n×n的方阵,记Z=∑kexp(zk/T),可以求得qi关于zj的偏导为:
∂qi∂zj=1Z2(Z∂exp(zi/T)∂zj−exp(zi/T)[∂Z∂zj])右侧方框部分可以展开为
∂Z∂zj=1Texp(zj/T)代入上式将括号展开,可以得到:
∂qi∂zj=1Z∂exp(zi/T)∂zj−1TZ2exp(zi/T)exp(zj/T)=1Z∂exp(zi/T)∂zj−1Texp(zi/T)Zexp(zj/T)Z=1Z[∂exp(zi/T)∂zj]−1Tqiqj左侧方框分类讨论得:
∂exp(zi/T)∂zj={1Texp(zi/T), if i=j0, if i≠j代入上式得:
∂qi∂zj={1T(exp(zi/T)Z−qiqj), if i=j−1Tqiqj, if i≠j1T(qi−qiqj), if i=j−1Tqiqj, if i≠j所以∂q/∂z等于:
∂q∂z=1T[q1−q21−q1q2⋯−q1qn−q2q1q2−q22⋯−q2qn⋮⋮⋱⋮−qnq1−qnq2⋯qn−q2n]这里就是为什么softmax函数对其输入的偏导是下列形式的原因,g(·)函数为softmax函数,x为输入向量,维度为d。
∂g(x)∂x=diag(ˆy)−ˆyˆy⊤∈Rd×d∂g(x)∂x=[ˆy10⋯00ˆy2⋯0⋮⋮⋱⋮00⋯ˆyd]−[ˆy21ˆy1ˆy2⋯ˆy1ˆydˆy2ˆy1ˆy22⋯ˆy2ˆyd⋮⋮⋱⋮ˆydˆy1ˆydˆy2⋯ˆy2d]回到我们的问题,继续推导,可以得到:
∂C∂z=1T[q1−q21−q1q2⋯−q1qn−q2q1q2−q22⋯−q2qn⋮⋮⋱⋮−qnq1−qnq2⋯qn−q2n][−p1q1−p2q2⋮−pnqn]=1T[−p1+∑kpkq1−p2+∑kpkq2⋮−pn+∑kpkqn]=1T[−p1+q1−p2+q2⋮−pn+qn]=1T(q−p)所以:
∂C∂zi=1T(qi−pi)
v1.5.2