DARTS是第一個提出基於松弛連續化的,使用梯度下降進行搜索的神經網絡架構搜索(neural architecture search, NAS)算法,將最早礦佬們(說的就是你Google)花成千上萬個GPU-hour(即用一塊卡跑一小時)的搜索算法降低到瞭一塊卡四天就能跑完。這使得我們這種窮苦的實驗室也有瞭研究NAS這個酷炫方法的可能,不愧是年輕人的第一個NAS模型哈哈哈。
DARTS最大的貢獻在於使用瞭Softmax對本來離散的搜索空間進行瞭連續化,並用類似於元學習中MAMAL的梯度近似,使得隻在一個超網絡上就可以完成整個模型的搜索,無需反復訓練多個模型。(當然,之後基於演化算法和強化學習的NAS方法也迅速地借鑒瞭超網絡這一特性)。
而實際上,DART還有個更大的貢獻就是開源瞭。不過源代碼是基於Pytorch 0.3寫的,和主流的1.x版本差別很大,所以需要一段時間進行重新復現。我現在寫瞭1.4版本的分佈式並行搜索代碼,等後續可視化和保存模型都摸熟瞭會發實現的教程(在做瞭在做瞭)。
除此之外,DARTS雖然想法十分優雅,但我個人覺得其在假設上有不少值得推敲的地方。在略讀瞭CVPR2020裡有關NAS的20篇論文後,的確很多工作都是針對我覺得很“有趣”的假設做的233。之後,我也會持續更新CVPR2020中NAS有關的論文,尤其是基於梯度下降方法的。
DARTS通過以可微分的方式描述任務來解決架構搜索的可擴展性挑戰。
與傳統的在離散的、不可微的搜索空間上應用進化或強化學習的方法不同(這些方法需要再一堆離散的候選網絡中間搜索),我們的方法基於連續松弛的結構表示,允許在驗證集上使用梯度下降對結構進行高效搜索。
DARTS在大搜索空間中搜索有復雜拓撲結構的高效架構cell,而過去使用梯度下降更多是找濾波器形狀,分支方式之類的低維超參數。
b65adf830d3e1381df1e62fdbabece1f
上圖為DARTS的算法示意圖。首先,這裡的搜索空間(對CNN而言)是一個組成模型的Cell,其可以被描述成一個有 N 個節點的有向無環圖。 在這個有向無環圖中,有節點 x^{(i)} 和邊 o^{(i, j)} ,其中
而每個中間節點(特征圖)都是由有向無環圖中所有的前繼節點計算得來的,即:
x^{(j)}=sum_{i<j} o^{(i, j)}left(x^{(i)}right)
此外,所有的邊(操作)都是在一個候選操作集 mathcal{O} 中選取出來的。
在傳統的方法中,為瞭在候選操作集 mathcal{O} 中尋找最好的操作,都是使用強化學習或者演化算法等啟發式算法取某個操作,也就是說這種選擇是非此即彼的離散操作。
離散的操作不好求導,所以需要引入連續松弛化這個概念。 具體地,實際上在搜索過程中,操作集的每個操作都會處理每個節點的特征圖。之後,再對所有所有操作得到的結果加權求和,即
bar{o}^{(i, j)}(x)=sum_{o in mathcal{O}} frac{exp left(alpha_{o}^{(i, j)}right)}{sum_{o^{prime} in mathcal{O}} exp left(alpha_{o^{prime}}^{(i, j)}right)} o(x)
可以看到這裡引入瞭新的符號 alpha_o^{(i,j)} ,其含義為:第 i 個特征圖到第 j 個特征圖之間的操作 o^{(i, j)} 的權重。這也是我們之後需要搜索的架構參數。
舉個例子,如果這個操作的權重 alpha_o^{(i,j)}=0 ,那麼就可以認為我們完全不需要這個操作。
而為瞭保證所有節點的輸出大致穩定,我們要對每兩個節點之間的架構參數(即操作的權重)進行Softmax操作,即 frac{exp left(alpha_{o}^{(i, j)}right)}{sum_{o^{prime} in mathcal{O}} exp left(alpha_{o^{prime}}^{(i, j)}right)} 。
可以看到,如果每個操作的權重確定,那麼最終的網絡架構也隨之確定,因此我們後續可以稱 alpha 為網絡架構(的編碼)本身。
和之前基於強化學習方法的獎勵函數,或者基於演化算法的種群適應性一樣。DARTS的優化目標也是在驗證集上的損失函數(隻不過DARTS直接用梯度下降優化)。
這裡令訓練損失和驗證損失分別為 mathcal{L}_{train} 和 mathcal{L}_{val} 。網絡中操作的的權重為 w ,有 * 上標則說明其為最優的。 因此,我們其實希望找到的是一個能在訓練集訓練好之後(最優權重 w^* ),在驗證集上損失最小的架構( alpha ^ * )。
這裡就有個問題,每次我們判斷架構好不好的之前,首先他要先在訓練集上收斂,即 w^{*}=operatorname{argmin}_{w} mathcal{L}_{mathrm{train}}left(w, alpha^{*}right) 。而最優的權重本身必然是和架構對應的,架構變化,對應的權重也會跟著變化。
把上面這個過程用數學語言描述,就是以架構 alpha 為上級變量,權重 w 為下級變量的兩級最優化問題:
begin{array}{cl} min _{alpha} & mathcal{L}_{v a l}left(w^{*}(alpha), alpharight) \ text { s.t. } & w^{*}(alpha)=operatorname{argmin}_{w} mathcal{L}_{t r a i n}(w, alpha) end{array}
但是,實際上這種問題看起來復雜,卻在元學習領域十分常見。尤其是基於梯度下降的超參數優化問題(比如著名的MAMAL)。所以,你也可以把這個問題看成元學習問題,而架構參數本身也是超參數,隻不過這個超參數維度高的有點點離譜233...
之前也提到過,DARTS對架構參數的更新方法實際上是在驗證集上對架構參數做梯度下降。
但是這又有個問題,那就是由二級最優化的定義,每次更新架構參數都理應重新訓練模型的權重,但這顯然是不可接受的(因為太慢瞭……)。DARTS算法實際上在驗證集上的搜索過程中,權重是不會變的,這就需要某種梯度近似的方法。這裡先給出作者提出的(二階)梯度近似(其中 xi 是權重的學習率)
begin{aligned} & nabla_{alpha} mathcal{L}_{v a l}left(w^{}(alpha), alpharight) approx nabla_{alpha} mathcal{L}_{v a l}left(w-xi nabla{w} mathcal{L}_{t r a i n}(w, alpha), alpharight) end{aligned}
這種近似在架構於訓練集上達到局部極值點( nabla_{omega} mathcal{L}_{train }(omega, alpha)=0 )時, omega=omega^{*}(alpha) 。
也就是說,這種近似實際上是用 w-xi nabla_{w} mathcal{L}_{t r a i n}(w, alpha) (訓練集上對權重執行一次梯度下降)來近似最優權重 w^{*}(alpha) 。
用人話來講,實際上就是交替進行以下兩步:
類似的方法在元學習(MAMAL)、基於梯度的超參數調整與避免GAN崩潰(unrolled GAN)中都能看到。
那麼這個近似梯度究竟需要如何求解呢?
PS:以下流程很大程度上參考瞭浙大李斌大佬的專欄(羨慕一波數學功底),也推薦配合食用。
首先,上面這個近似的梯度涉及瞭二元復合函數,因此對其求導需要用到鏈式法則。 為瞭簡單可以先將 nabla_{alpha} mathcal{L}_{v a l}left(w-xi nabla_{w} mathcal{L}_{t r a i n}(w, alpha), alpharight) 記為 nabla_{alpha} fleft(g_{1}(alpha), g_{2}(alpha)right) ,其中:
現在對這個復合函數求導
begin{aligned} & nabla_{alpha} fleft(g_{1}(alpha), g_{2}(alpha)right) \ =& nabla_{alpha} g_{1}(alpha) cdot D_{1} fleft(g_{1}(alpha), g_{2}(alpha)right)+nabla_{alpha} g_{2}(alpha) cdot D_{2} fleft(g_{1}(alpha), g_{2}(alpha)right) end{aligned}
其中( w^{prime}=w-xi nabla_{w} mathcal{L}_{t r a i n}(w, alpha) ):
帶入並整理,我們可得到具體可求的梯度:
begin{aligned} & nabla_{alpha} mathcal{L}_{v a l}left(omega-xi nabla_{omega} mathcal{L}_{t r a i n}(omega, alpha), alpharight) \ =& nabla_{alpha} mathcal{L}_{v a l}left(omega^{prime}, alpharight)-xi nabla_{alpha, omega}^{2} mathcal{L}_{t r a i n}(omega, alpha) cdot nabla_{omega^{prime}} mathcal{L}_{v a l}left(omega^{prime}, alpharight) end{aligned}
為啥說是具體可求呢,因為在上式第二行的結果中 w' 變成瞭一個常數,而不是之前一個變量為架構參數 alpha 的復合函數!
但是,這個梯度的第二項依然十分麻煩,因為對兩個變量(權重和架構參數)的二階梯度以及權重的梯度求解涉及到很麻煩的向量-矩陣乘積。
因此作者提出使用有限差分近似來求解,具體地,設有一小標量 epsilon (經驗中取 epsilon=0.01 /left|nabla_{w^{prime}} mathcal{L}_{v a l}left(w^{prime}, alpharight)right|_{2} )。
nabla_{alpha, omega}^{2} mathcal{L}_{text {train}}(omega, alpha) cdot nabla_{omega^{prime}} mathcal{L}_{text {val}}left(omega^{prime}, alpharight) approx frac{nabla_{alpha} mathcal{L}_{text {train}}left(omega^{+}, alpharight)-nabla_{alpha} mathcal{L}_{text {train}}left(omega^{-}, alpharight)}{2 epsilon}
其中 omega^{pm}=omega pm epsilon nabla_{omega^{prime}} mathcal{L}_{v a l}left(omega^{prime}, alpharight)
這個具體是怎麼做到的呢,答案是——泰勒展開(說到底都是本科知識,但就是想不起來)
fleft(x_{0}+hright)=fleft(x_{0}right)+frac{f^{prime}left(x_{0}right)}{1 !} h+ldots
現在我們用 hA 來替代 h ,則有
begin{array}{l} fleft(x_{0}+h Aright)=fleft(x_{0}right)+frac{f^{prime}left(x_{0}right)}{1 !} h A+ldots \ fleft(x_{0}-h Aright)=fleft(x_{0}right)-frac{f^{prime}left(x_{0}right)}{1 !} h A+ldots end{array}
將上面兩個式子相減,可以得到
f^{prime}left(x_{0}right) cdot A approx frac{fleft(x_{0}+h Aright)-fleft(x_{0}-h Aright)}{2 h}
之後,用 epsilon 替代 h , 再把 A 換成 nabla_{omega^{prime}} mathcal{L}_{v a l}left(omega^{prime}, alpharight) ,還有把 x_{0} 換成 w ,最後把 f 換成 nabla_{alpha} mathcal{L}_{text {train}}(cdot, cdot) ,就是有限差分近似的結果啦~
實際上這種有限差分近似隻需要對梯度進行兩次前向傳播,以及對架構進行兩次反向傳播。其計算復雜度也會從 O(|alpha||w|) 降至 O(|alpha|+|w|)
最後,如果你覺得這個有限差分近似依然很煩,你w可以直接把第二項扔掉,即隻保留 nabla_{alpha} mathcal{L}_{v a l}left(omega, alpharight) 。這種操作等價於假設當前的權重 w 就是最優權重 w^* ,此時梯度將退化為一階近似。
一階近似的速度更快,但最後的效果沒有二階近似好。
假設通過之前說的這些流程,架構參數已經訓練的挺不錯瞭。那麼,接下來就要提取真正的模型瞭,因為直至目前,架構依然是計算瞭所有的操作,而所有操作依然是連續組合而不是離散的。 但是,和分類問題一樣,我們可以取出每條邊上權重最大的 $$k$$ 個操作(在CNN中DARTS取2個最大的操作,並忽略0操作)。
實際上,如果從本文作者的角度看DARTS,我能發現有以下假設是值得註意的,從CVPR2020的paper看,這些也是被後來者當靶子的點。 (當然,我要是真的能做出這個檔次的工作,我做夢都會笑醒……)
假設2和6 Follow瞭谷歌大腦18年CVPR的論文Learning Transferable Architectures for Scalable Image Recognition。不過這篇論文中取的兩個輸入具體是之前哪兩個cell實際上是不固定的,而Darts則固定隻取前兩個(存疑)