BAGAN: Data Augmentation with Balancing GAN

摘要

圖像分類數據集往往是不平衡的,這一特點對深度學習分類器的精度產生瞭負面影響。本文提出平衡GAN (BAGAN)作為一種增強工具,以恢復不平衡數據集的平衡。這是具有挑戰性的,因為少數少數類圖像可能不足以訓練GAN。我們通過在對抗性訓練中包括大多數類和少數類的所有可用圖像來克服這個問題。生成模型從多數類中學習有用的特征,並使用這些特征為少數類生成圖像。在潛空間中應用類條件作用,以驅動目標類的生成過程。GAN中的生成器用自動編碼器的編碼器模塊初始化,使我們能夠在潛空間中學習準確的類條件。將所提出的方法與最先進的GANs進行瞭比較,證明瞭BAGAN在用不平衡數據集訓練時生成瞭高質量的圖像。

引言

當訓練數據集是均衡的,即可用數據在不同類別之間不均勻分佈時,圖像分類技術的準確性會顯著下降。不平衡的數據集很常見,緩解這個問題的傳統方法是通過引入額外的少數類圖像來增加數據集,這些圖像是通過對原始圖像進行簡單的幾何變換得到的,例如旋轉或鏡像。當方向相關時,這種增強方法可能會破壞與方向相關的特征。本文提出一種平衡的生成式對抗網絡(BAGAN)作為增強工具,通過生成新的少數類圖像來恢復數據集平衡。由於這些圖像在初始數據集中非常稀少,因此訓練GAN生成新的圖像是一項挑戰。為瞭克服這個問題,所提出的方法一次性將來自少數類和多數類的所有數據包含在對抗性訓練中。這使BAGAN能夠從所有圖像開始學習特定分類問題的底層特征,然後將這些特征應用於生成新的少數類圖像。例如,讓我們考慮道路交通標志的分類[1]。所有的警告標志都具有相同的外部三角形。一旦BAGAN學會瞭用其中一個符號來畫這個形狀,我們就可以用它來畫任何其他形狀。由於BAGAN從所有類別開始學習特征,而目標是為少數類別生成圖像,因此需要一種機制來推動生成過程向所需類別方向發展。為此,本文將類條件應用於潛空間[2,3]。我們使用自編碼器初始化GAN中的鑒別器和生成器。然後,利用這個自編碼器來學習潛空間中的類條件,即學習生成模型的輸入對於不同的類應該是什麼樣子。此外,這種初始化使我們能夠從更穩定的點開始對抗性訓練,並有助於緩解傳統GANs產生的收斂問題[4,5,6,7]。這項工作的主要貢獻是:

一種用不平衡數據集訓練GANs的整體方法,同時特別旨在生成少數類圖像。基於自動編碼器的初始化策略,使我們能夠a)從一個好的初始解決方案開始訓練GAN, b)學習如何在生成器的潛空間中編碼不同的類。根據最新技術對所提出的BAGAN方法進行實證評估。

實驗結果表明,當訓練數據集不平衡時,所提出的BAGAN方法在生成圖像的多樣性和質量方面優於最先進的GAN方法。反過來,這導致在增強後的數據集上訓練的最終分類器具有更高的精度

背景

近年來,生成對抗神經網絡(GANs)[8,9,7]被提出作為一種人工生成逼真圖像的工具。其基本思想是在對抗模式下訓練一個生成網絡,對抗鑒別器網絡。生成對抗模型的一個眾所周知的問題是,當它們學習愚弄鑒別器時,它們最終可能會畫出一個或幾個愚蠢的例子。這個問題被稱為模式崩潰[8,4,5,6]。本文的目標是增強一個不平衡的圖像分類數據集,以恢復其平衡。最重要的是,增強後的數據集足夠可變,並且不包括連續重復的示例,因此我們需要避免模式崩潰。為此目的,提出瞭不同的方法。可能的解決方案是:明確促進生成器損失[10,4]中的圖像多樣性,讓生成器預測鑒別器的未來變化並對這些[11]進行自適應,讓鑒別器區分不同的類別[2,12],應用特定的正則化技術[5,6],以及將GANs與自編碼器[4,13,14,15]耦合。在這項工作中,我們應用瞭後一種方法,並結合瞭GAN和自編碼技術。引用的方法包括GAN中的其他模塊,以在整個訓練過程中嵌入自編碼器。在提出的BAGAN方法中,我們采用瞭一種更實用的方法,並使用自編碼器來初始化GAN模塊,使其接近於一個良好的解決方案,而遠離模式崩潰。由於我們的目標是專門為少數類生成圖像,因此訓練瞭一個生成器,該生成器就其繪制的圖像類而言是可控的,類似於最先進的ACGAN方法[2]。然而,ACGAN並不是專門針對不平衡數據集的,在針對少數類圖像的生成時往往是有缺陷的。

實例

最先進的gan不適合處理不平衡的數據集[16],據我們所知,所提出的BAGAN方法是第一個專門解決這個主題的方法。在介紹所提出方法的細節之前,讓我們用一個簡單的例子來演示為什麼很難應用現有的GAN技術來解決手頭的問題。讓我們考慮手寫數字的分類,從MNIST數據集[17]的不平衡版本開始,我們從訓練集中刪除瞭97.5%的可用0

一個簡單的想法是使用傳統的GAN[8, 9, 18],通過使用所有可用的數據對其進行訓練,生成許多隨機樣本,找到0實例,並使用這些實例來增強數據集。這種方法不能普遍應用:如果GAN中的生成器G被訓練成通過生成真實的圖像來欺騙鑒別器D,它將更好地專註於多數類的生成以優化其損失函數,同時坍縮與少數類相關的模式。另一方面,僅使用少數類圖像來訓練GAN實際上不是一個選項,因為少數類圖像非常稀少。在這個例子中,在去除97.5%的零之後,我們剩下大約150張少數類圖像。一般來說,很難從非常少的數據集開始訓練GAN, GAN必須有許多示例可以從[19]學習

另一種方法是用多數類和少數類聯合訓練GAN,並讓GAN明確區分不同的類。在訓練過程中,明確要求生成器繪制每個類別的圖像,並讓鑒別器相信生成的圖像是所需類別的真實圖像。在這樣做的時候,生成器會因為繪制每個類別(包括少數類別)的真實圖像而得到明確的獎勵。據我們所知,到目前為止,實現這種方法的唯一方法是ACGAN[2],其中生成器輸入可以條件化以繪制目標類。在ACGAN中,鑒別器有兩個輸出,一個用於區分真假圖像X,另一個根據其類別c對X進行分類,圖1(a)。在訓練過程中,明確要求生成器為每個類c繪制圖像Xc。調整生成器參數以最大化兩個分量的疊加。第一個分量是生成鑒別器認為真實的圖像Xc的對數似然。第二個分量是生成鑒別器與類別c相關聯的圖像Xc的對數可能性。我們觀察到,當數據集不平衡時,這兩個分量對少數類來說是矛盾的。這可以解釋如下。讓我們假設在某一時刻,生成器收斂到一個解決方案,在該解決方案中,它生成具有真實質量的少數類圖像。鑒別器將無法區分這些圖像與訓練數據集中的圖像。由於在訓練數據集中,少數類圖像很少,所以當少數類圖像在訓練期間傳遞給鑒別器時,它很可能是假圖像。為瞭優化其損失函數,鑒別器必須將假標簽與所有少數類圖像關聯起來。在這一點上,兩個生成器目標是矛盾的,生成器可以繪制看起來真實的圖像,也可以繪制代表少數類但不能同時實現這兩個目標的圖像。反過來,生成器可以因繪制看起來真實且不代表目標少數類的圖像而得到獎勵。這一事實惡化瞭生成圖像的質量。ACGAN為0位數字生成的不平衡MNIST示例圖像如圖2(a)所示。本文提出BAGAN,將類條件應用於ACGAN,但在以下幾點上有所不同。

首先,BAGAN判別器有一個輸出,返回特定於問題的類標簽c或標簽fake,如圖1(b)所示。鑒別器D被訓練為將標簽fake與G生成的圖像相關聯,並將標簽Xc與真實圖像相關聯。生成器被訓練為避免假標簽並匹配所需的類標簽。由於這現在被定義為一個單一目標而不是兩個目標的疊加,通過構建,它不能自相矛盾,並且如果鑒別器沒有將看起來真實的圖像Xc與所需的類標簽c匹配,生成器永遠不會對生成的圖像Xc獲得獎勵。其次,BAGAN將GAN和自編碼技術結合起來,以提供類條件的精確選擇,並更好地避免模式崩潰。BAGAN生成的不平衡MNIST示例的圖像具有優越的質量,如圖2(b)所示。

BAGAN

所提出的BAGAN方法旨在為不平衡數據集生成真實的少數類圖像。它利用特定分類問題的所有可用信息,將多數類和少數類聯合納入BAGAN訓練中。GAN和自編碼技術相結合,以利用兩種方法的優勢。GANs生成高質量的圖像,而自動編碼器很容易收斂到良好的解決方案[7]。一些作者建議將GANs和自動編碼器耦合起來[4,13]。盡管如此,這些工作並不直接意味著將GAN生成過程推向特定的類別。將它們泛化以使GAN能夠區分不同的類並不容易。如激勵示例所述,在這項工作中,我們應用Odena等人建議的類條件反射。[2]在BAGAN中嵌入類知識

我們實用地使用自動編碼器來初始化GAN,使其接近於一個良好的解決方案,而遠離模式崩潰。此外,應用自編碼器的編碼器部分來推斷潛空間中不同類別的分佈。基於自動編碼器的GAN初始化是通過在自動編碼器和GAN模塊中使用相同的網絡拓撲來實現的,圖3(a)和3(b)。自動編碼器的解碼階段∆與生成器g的拓撲結構匹配。自動編碼器的編碼階段E與鑒別器De的第一層拓撲結構匹配。在BAGAN中,通過相應地初始化參數權重,自動編碼器中的知識被轉移到GAN模塊中,如圖3(b)所示。為瞭完成鑒別器,具有softmax激活函數的最後一個密集層Dd將潛在特征轉換為圖像是假的或它屬於某個問題類別c1-cn的概率。當GAN模塊初始化時,通過學習圖像在潛空間中不同類別的概率分佈,建立一個類條件潛向量生成器。然後,通過進行傳統的對抗性訓練來微調生成器和鑒別器中的所有權重,圖3(c)。總的來說,BAGAN訓練方法被組織為圖3所示的三個步驟:a)自動編碼器訓練,b) GAN初始化,和c)對對抗訓練。

Autoencoder培訓。通過使用訓練數據集中的所有圖像來訓練自動編碼器。自動編碼器沒有明確的類別知識,它無條件地處理來自多數類和少數類的所有圖像。本文將l2損失最小化應用於自編碼器訓練

GAN初始化。與自編碼器不同,生成器G和鑒別器D具有明確的類知識。在對抗性訓練期間,G被要求為不同的類生成圖像,D被要求將圖像標記為假的或特定問題的類標簽c。在GAN初始化時,通過使用解碼器∆中的權重初始化G,以及使用編碼器E的權重初始化判別器De的第一層,將自編碼器知識轉移到GAN模塊中,如圖3(b)所示。判別器Dd的最後一層是一個具有softmax激活函數的密集層,並生成最終的判別器輸出。最後一層的權重隨機初始化,並在對抗性訓練期間學習鑒別器初始化隻是用於在D中包含有助於圖像分類的有意義的特征。生成器的初始化有一個深層次的原因。當對抗性訓練開始時,生成器G等同於解碼器∆。因此,輸入到生成器G的潛向量Z等效於自動編碼器的潛空間中的一個點,即Z可以被視為E的輸出或∆的輸入。因此,編碼器E將真實圖像映射到g使用的潛空間中。我們利用這一事實在開始對抗性訓練之前學習一個良好的類條件,即我們定義類c圖像的潛向量Zc應該是什麼樣子的。我們在潛空間中使用多元正態分佈Nc = N (μ c, Σc)和平均向量μ c和協方差矩陣Σc對類進行建模。對於每個類別c,考慮到訓練數據集中c類的所有真實圖像Xc,我們計算μ c和Σc以匹配Zc = E(Xc)的分佈。我們用這些概率分佈初始化類條件潛向量生成器,這是一個隨機過程,將類標簽c作為輸入,並將從Nc隨機抽取的潛向量Zc作為輸出。在對抗性訓練中,概率分佈Nc被認為是不變的,迫使生成器不會偏離潛空間中的初始類編碼。

對抗訓練。在對抗性訓練期間,數據批量流過生成器G和鑒別器D,它們的權重被微調以優化它們的損失函數。鑒別器將輸入圖像分類為屬於n個問題特定類中的一個或為假圖像。對於我們提供的每個批次,總圖像中有1/(n + 1)是假的,即我們為假類別提供瞭最佳的平衡。假數據作為G的輸出生成,G將從類條件潛向量生成器中提取的潛向量Zc作為輸入。反過來,類條件潛向量生成器將均勻分佈的類標簽c作為輸入,即假圖像均勻分佈在特定問題的類之間。當訓練鑒別器D時,我們優化稀疏分類交叉熵損失函數,以匹配真實圖像的類標簽和生成圖像的假標簽。對於鑒別器學習的每個批次,生成器g學習相同大小的批次。為此,通過在標簽c上應用均勻分佈隨機抽取一批條件潛向量Zc。這些向量由生成器處理,輸出圖像被輸入鑒別器。G中的參數被優化,以匹配鑒別器選擇的標簽與用於生成圖像的標簽c。

結果

在四個數據集上驗證瞭所提出的方法。考慮:MNIST[17]、CIFAR-10[20]、Flowers[21]和GTSRB[1]。前兩個數據集眾所周知,Flowers是一個小數據集,包含五類鮮花的真實照片,我們將其重塑為224x224的分辨率,GTSRB是一個交通標志識別數據集。這些數據集的詳細信息如表1所示。前3個數據集是平衡的,GTSRB是不平衡的。我們通過選擇一個類並從訓練集中刪除其大量實例來強制前三個數據集的不平衡。我們對每個類別重復這個過程,並為每個產生的不平衡數據集訓練不同的生成模型。當將每個類別作為少數類進行訓練時,總是會得到以下結果,我們將訓練集中遺漏的圖像稱為丟棄的圖像。由於GTSRB已經不平衡,我們不再進一步不平衡它

將所提出的BAGAN模型與最先進的ACGAN模型[2]進行瞭比較。據我們所知,ACGAN是迄今為止文獻中提出的唯一一種考慮類條件的方法,以從包括多個類的數據集(第3節)開始繪制目標類的圖像。BAGAN和ACGAN都是通過聯合使用多數類和少數類在目標數據集上進行訓練的。我們還考慮一種簡單的GAN方法,通過僅在少數類上進行訓練來學習繪制少數類圖像。為瞭公平比較,我們限制瞭所考慮的方法(BAGAN, ACGAN和GAN)之間的架構更改。本文描述瞭BAGAN和ACGAN之間的區別(即鑒別器輸出拓撲和基於自編碼器的初始化)。對於簡單的GAN,我們調整參考ACGAN鑒別器輸出以僅區分真假圖像,並刪除生成器輸入的類條件(此GAN僅在來自少數類的圖像上進行訓練)。圖4、5和6顯示瞭為CIFAR-10和GTSRB中代表最多和最少的三個類生成的代表性圖像的定性分析。對於CIFAR-10,我們隻顯示少數類圖像的結果。對於每個類,40%的該類圖像被刪除,訓練瞭生成模型,並顯示瞭隨機生成的圖像,圖4。對於CIFAR-10,簡單的GAN坍縮到每個類生成一個圖像示例。為瞭訓練這個GAN,我們隻使用瞭3000張少數類圖像(40%的少數類圖像被刪除,大多數類不包含在訓練中)。對抗性網絡需要許多示例來學習繪制新圖像[19],在這種情況下,簡單的GAN崩潰。對於ACGAN和BAGAN來說,這個問題不太相關,因為它們可以從少數類和多數類中共同學習特征。為瞭更好地理解ACGAN和BAGAN的不同行為,讓我們關註GTSRB數據集圖5和6。這個數據集最初是不平衡的,我們訓練生成模型而沒有修改它。對於大多數類別,ACGAN和BAGAN都返回高質量的結果,圖5(c)和5(b)。盡管如此,ACGAN在為少數類繪制圖像時失敗,並在為每個類生成單個示例時崩潰。在某些情況下,ACGAN生成的圖像不能代表所需的類,例如,圖6(c)中的第二行應該是一個警告標志,而繪制瞭速度限制。如果BAGAN繪制的圖像不能代表理想的類別,那麼他永遠不會得到獎勵。因此,BAGAN沒有表現出這種行為

生成圖像的定量評估

由於我們的目標是利用生成模型通過生成額外的少數類圖像來增強不平衡的數據集,因此我們的目標如下a)生成的圖像必須表示所需的類。b)生成的圖像不得重復。c)生成的圖像必須不同於訓練集中已經存在的真實圖像

Missing to meet a)意味著生成模型無法生成準確表示目標類別的圖像,它們看起來要麼是其他類別的真實示例,要麼看起來不真實。Missing to meet b)意味著生成模型坍縮到生成單個或少數模式。缺失滿足c)意味著我們隻是學會瞭重繪可用的訓練圖像。在這三個目標的基礎上評估瞭生成圖像的質量

生成圖像的準確性。為瞭驗證由考慮的方法生成的圖像代表瞭所需的類,通過在整個原始數據集上訓練的深度學習模型對它們進行分類,並驗證預測的類是否與目標類匹配。在這項工作中,我們使用ResNet18模型[22]。結果如圖7所示。簡單的GAN對生成的圖像返回的精度最差。所提出的BAGAN方法總體上優於其他方法,並生成瞭ResNet-18模型能夠以最高精度分類的圖像。我們再次觀察到,強烈的不平衡會顯著惡化生成的圖像的質量,其準確性隨著下降圖像的百分比的增加而下降。當針對MNIST數據集時,這種現象對ACGAN來說最明顯。

生成圖像的可變性。采用結構圖像相似度SSIM[23]來度量兩幅圖像之間的相似度。該指標預測人類感知相似性判斷,當兩幅圖像相同時,返回1,並隨著差異的相關性增加而減少。為瞭驗證生成的圖像是多樣化的,對於每個類,我們重復生成一些圖像並測量它們的相似性SSIM。圖8顯示瞭所考慮的所有類別的平均數據集的多樣性分析。對於MNIST、CIFAR-10和Flowers,我們改變少數類圖像在集合{40,60,80,90,95,97.5}內下降的百分比,而對於GTSRB,我們使用原始不平衡數據集。在分析中,我們還包括一個參考值,即真實圖像對之間的平均SSIM

相對於訓練集的圖像多樣性。評估生成的圖像相對於訓練集中已有圖像的可變性。我們計算生成的圖像與其最近的真實鄰居之間的SSIM。我們將該值與訓練集中的圖像可變性進行比較,即真實圖像與其最近的真實鄰居之間的SSIM值。這些SSIM值彼此非常接近,這意味著沒有出現過擬合。這種說法適用於所有考慮過的方法。特別地,MNIST、CIFAR-10、Flowers和GTSRB的SSIM值分別約為0.8、0.25、0.05和0.5

最終分類的質量

最後評估瞭在增強數據集上訓練的深度學習分類器的準確性。對於MNIST, CIFAR10和Flowers,對於每個類我們:1)將該類選擇為少數類,2)通過從訓練集中刪除該類的一定比例的圖像來生成一個不平衡的數據集,3)訓練所考慮的生成模型,4)通過生成模型來增強不平衡數據集以恢復其平衡,5)為增強數據集訓練ResNet18分類器,6)測量測試集上少數類的分類器精度。由於GTSRB已經不平衡,對於該數據集,跳過步驟1)和2)。將生成模型獲得的增強與普通不平衡數據集和水平鏡像增強方法(mirror)進行比較,其中通過鏡像訓練集中可用的圖像來生成新的少數類圖像。

不同類別的平均精度結果如圖9所示。所提出的BAGAN方法為GTSRB返回瞭最好的精度,在大多數情況下對MNIST也是如此。這兩個數據集的特點是具有對圖像方向敏感的特征,正如預期的那樣,鏡像方法返回瞭最差的精度結果,因為它破壞瞭這些特征。對於CIFAR-10和Flowers,使用鏡像方法獲得瞭最好的精度。對這些數據集進行鏡像不會破壞任何特征,鏡像圖像的質量與原始圖像一樣好。與ACGAN和GAN相比,BAGAN方法仍然提供瞭最好的精度。

從這個分析中得出結論,在從不平衡數據集開始生成少數類圖像時,BAGAN優於其他最先進的對抗性生成網絡。此外,我們得出結論:當由於方向相關的特征而不容易用傳統技術來擴充數據集時,可以應用BAGAN來提高最終的分類精度

結論

本文提出一種方法,通過使用生成對抗網絡來恢復不平衡數據集的平衡。在提出的BAGAN框架中,生成器和鑒別器模塊通過自編碼器進行初始化,以從一個良好的解決方案開始對抗性訓練,並學習如何在潛空間中表示不同的類別。我們將所提出的方法與最新的方法進行瞭比較。實驗結果表明,在從不平衡的訓練集開始生成高質量圖像方面,BAGAN優於其他生成對抗網絡。這反過來導致在恢復平衡的增強數據集上訓練的深度學習分類器的精度更高。

发表回复

相关推荐

顏值進步明顯的十大NBA球員

NBA聯盟是全球最受關註的男子運動體育秀場,想方設法提升自己的球技自然是籃球運動員最為重要的自我要求。不過,在這麼多鎂光...

· 32秒前

國男之殤——傳統與束縛的思考

剛看瞭教師節改為孔子誕辰的新聞,給反對的人點瞭贊,再簡單說兩句。指望某種程度的封建文化復蘇,來對抗“自由思想入侵”可能...

· 6分钟前

衛生巾長度(規格)都有哪些?哪個牌子的衛生巾最好最安全(排行榜前十的)

衛生巾的長度一般為24cm、28cm、32cm不等。其中,24cm和28cm是比較常見的兩種長度。不同品牌和型號的衛生巾長度可能會有所不...

· 7分钟前

新能源智能主控單芯片的解決方案,鴻智電通主控SoC芯片EPC3020

前言近期充電寶市場受到多重要素利好,市場強勢增長,成為眾多3C配件率先復蘇突圍的品類。目前充電寶新國標開始實施發證、全...

· 10分钟前

你吃过很多桂林米粉,却可能一碗正宗的都没吃到过(广西三大粉之桂林米粉)

文 | 王枪枪 中国有3个嗦粉圣地——贵州、湖南和广西。这三省米粉流派众多,拥趸无数,但其中名气最大,走得最远的,还属桂林 ...

· 13分钟前