蒸餾是一個化學上的詞匯,百科上對於蒸餾的解釋為:“蒸餾是一種熱力學的分離工藝,它利用混合液體或液-固體系中各組分沸點不同,使低沸點組分蒸發,再冷凝以分離整個組分的單元操作過程,是蒸發和冷凝兩種單元操作的聯合”。原理如下圖所示:
蒸餾的目的是在不同的溫度下提出瞭特定的成分。說回到知識蒸餾(knowledge distillation),其是模型壓縮的一種常用的方法,最早得到推廣的版本是由Hinton在2015年[1]提出並應用在分類任務上。與蒸餾的目的一致,在知識蒸餾中,希望通過提取性能更好的大模型的監督信息,構建一個小模型,同時使得小模型具有較好的性能和精度,而此處的大小模型成為Teacher,小模型稱為Student。
隨著計算能力的不斷提升,現在的模型也越來越大,網絡也越來越深,結構變得異常復雜,這帶來瞭模型準確率的提升,同時,計算復雜度也隨之提升。知識蒸餾就是一種有效的模型壓縮的方法,同時能夠使得壓縮後的模型的效果並未下降太多。
在知識蒸餾中,首先需要有一個大模型,也稱為Teacher模型,該模型的特點是模型復雜,此時需要對該模型壓縮,得到一個較小的模型,也稱為Student模型,在蒸餾的過程中,將Teacher模型中學習到的“知識”遷移到Student模型中,以使得Student模型具有與Teacher模型一致的效果。具體的過程如下圖所示:
a54535eae22da8b84f26f1259ed6b5af
對於一個完整的知識蒸餾過程,有兩個模型,分別為Teacher模型和Student模型,通過學習將已經訓練好的Teacher模型中的知識遷移到小的Student模型中。其具體過程如下圖所示[2]:
對於Student模型,其目標函數有兩個,分別為蒸餾的loss(distillation loss)和自身的loss(student loss),其最終的損失函數為:
L=alpha L_{soft}+beta L_{hard}
其中, L_{hard} 為student模型自身的損失,對於分類問題來說,可以通過交叉熵計算 L_{hard} :
L_{hard}=-sum_{i=1}^{n}c_ilogleft ( q_i right )
其中, c_i 為樣本的真實標簽,對於分類問題來說即為0或者1, q_i 為Student模型的輸出。通常, q_i 可以通過softmax計算得到:
q_i=frac{expleft ( z_i right )}{sum _jexpleft ( z_j right )}
對於softmax的計算,是在網絡的logits結果 z_i 上,在softmax計算後得到的概率分佈會放大logits,會使得類目之間的差異變大。因此在知識蒸餾中,通常在logits的基礎上加上一個溫度變量 T ,來對logits結果縮放:
q_i=frac{expleft ( z_i/T right )}{sum _jexpleft ( z_j/T right )}
當 T=1 時即為正常的輸出,上述的 L_{hard} 即在 T=1 的情況下計算得到。對於 L_{soft} 的計算,通常有兩種方式,一種是計算softmax輸出結果的差異,另一種是直接比較logits結果的差異。
對於softmax結果的差異,由於softmax的結果是概率分佈,因此可通過交叉熵計算分佈之間的差異:
L_{soft}=-sum_{i=1}^{n}p_ilogleft ( q_i right )
其中, p_i 為Teacher模型的輸出, q_i 為Student模型的輸出。且輸出是在 T=t 的情況下計算得到。
對於logits結果的差異,可以直接比較Teacher網絡和Student網絡輸出logits的平方差,即:
L_{soft}=sum_{i=1}^{n}left ( v_i-z_i right )^2
其中, v_i 為Teacher模型的logits輸出, z_i 為Student模型的logits輸出。
知識蒸餾通過對Teacher模型的壓縮得到效果接近的Student模型,由於網絡模型復雜度的減小,使得壓縮後的Student模型的性能得到較大提升。
[1] Hinton G , Vinyals O , Dean J . Distilling the Knowledge in a Neural Network[J]. Computer Science, 2015, 14(7):38-39.
[2] Knowledge Distillation
[3] 【經典簡讀】知識蒸餾(Knowledge Distillation) 經典之作
[4] 一分鐘帶你認識深度學習中的知識蒸餾
上一篇
聲明 | 本文不含商業合作作者 | 夜風遊戲行業曾創造過不少財富神話。在一些上市公司眼裡,遊戲業務就是能拯救報表的「萬金油...
AnyConnect IOS系统的使用教程: 1、下载客户端软件 打开App Store应用商店搜索“anyconnect”下载并安装 AnyConnect苹果IOS系 ...