【技術實現步驟摘要】
本專利技術屬于神經網絡,尤其涉及異構知識蒸餾中的輔助神經網絡模型訓練方法及裝置。
技術介紹
1、知識蒸餾是一種將大且復雜的神經網絡模型(可稱為教師神經網絡模型)的知識(神經網絡模型通過對數據的訓練所得到的對數據預測的能力)轉移到小且簡單的神經網絡模型(可稱為學生神經網絡模型)中的模型壓縮方法,能讓一個小而簡單的神經網絡模型的預測性能在經過蒸餾之后,得到很大提升。但這種原始知識蒸餾的方法僅適用于教師模型與學生模型屬于同一種類型的神經網絡模型且尺寸差距還不是特別大的情況。
2、為了解決教師與學生模型尺寸相差過大而導致的學生模型性能下降的情況,助教知識蒸餾(teacherassistant?knowledge?distillation,takd)提出了在教師神經網絡模型(簡稱為教師模型)與學生神經網絡模型(簡稱為學生模型)之間增加一個尺寸介于兩者之間的輔助神經網絡模型,將單步蒸餾變為多步蒸餾,減少了相同結構的師生模型尺寸差距過大造成的知識損耗。但是目前的助教知識蒸餾方法無法解決異構師生模型(即不同網絡結構的教師模型與學生模型)在尺寸差距過大的情況下導致的知識損耗問題。
技術實現思路
1、有鑒于此,本專利技術的目的在于提供一種異構知識蒸餾中的輔助神經網絡模型訓練方法及裝置,以解決上述的至少部分問題,其提供的技術方案如下:
2、第一方面,本申請提供了一種異構知識蒸餾中的輔助神經網絡模型訓練方法,包括:
3、分別構建教師神經網絡模型、學生神經網絡模型和輔助神
4、從教師神經網絡模型蒸餾知識,并通過減少輔助神經網絡模型的蒸餾損失更新輔助神經網絡模型的權重ωa;
5、從輔助神經網絡模型蒸餾知識,并通過減少學生神經網絡模型的蒸餾損失更新學生神經網絡模型的權重參數ωs;
6、通過梯度下降最小化學生神經網絡模型的損失函數更新輔助神經網絡模型的結構參數aa。
7、可選地,所述分別構建教師神經網絡模型、學生神經網絡模型和輔助神經網絡模型對應的函數,包括:
8、所述教師神經網絡模型表示為其中,ft是教師神經網絡模型的函數,ωt表示教師神經網絡模型的模型權重,at是教師神經網絡模型的結構參數;
9、所述學生神經網絡模型表示為其中,fs是學生神經網絡模型的函數,ωs表示學生神經網絡模型的模型權重,as是學生神經網絡模型的結構參數;
10、所述輔助神經網絡模型表示為其中,fa是輔助神經網絡模型的函數,ωa表示輔助神經網絡模型的模型權重,aa是輔助神經網絡模型的結構參數。
11、可選地,所述通過減少輔助神經網絡模型的蒸餾損失更新輔助神經網絡模型的權重ωa為:通過公式更新ωa。
12、可選地,所述通過減少學生神經網絡模型的蒸餾損失更新學生神經網絡模型的權重參數ωs為:
13、通過更新ωs。
14、可選地,通過梯度下降最小化學生神經網絡模型的損失函數更新輔助神經網絡模型的結構參數aa,包括:
15、通過梯度下降法使得損失函數滿足來更新aa。
16、可選地,還包括:
17、訓練過程的目標函數如下:
18、
19、
20、
21、其中,ω*表示已經訓練好的神經網絡模型的權重參數。
22、第二方面,本申請還提供了一種異構知識蒸餾中的輔助神經網絡模型訓練裝置,包括:
23、模型函數構建模塊,用于分別構建教師神經網絡模型、學生神經網絡模型和輔助神經網絡模型,模型的輸出由模型權重ω,結構參數a,以及第i個輸入x(i)決定;
24、輔助模型權重優化模塊,用于從教師神經網絡模型蒸餾知識,并通過減少教輔助神經網絡模型的蒸餾損失更新輔助神經網絡模型的權重ωa;
25、學生模型權重優化模塊,用于從輔助神經網絡模型蒸餾知識,并通過減少學生神經網絡模型的蒸餾損失更新學生神經網絡模型的權重參數ωs;
26、輔助模型結構參數更新模塊,用于通過梯度下降最小化學生神經網絡模型的損失函數,更新輔助神經網絡模型的結構參數aa。
27、可選地,所述輔助模型權重優化模塊具體用于:通過公式更新ωa;
28、所述學生模型權重優化模塊具體用于:通過更新ωs。
29、可選地,還包括:
30、訓練過程的目標函數如下:
31、
32、
33、
34、其中,ω*表示已經訓練好的神經網絡模型的權重參數。
35、第三方面,本申請還提供了一種計算設備,其特征在于,包括:處理器、存儲器及存儲在存儲器上并可在處理器上運行的程序,處理器執行程序時實現第一方面任一項所述的異構知識蒸餾中的輔助神經網絡模型訓練方法。
36、本申請實施例提供了一種異構知識蒸餾中的輔助神經網絡模型訓練方法,將輔助神經網絡模型的結構進行參數化,再通過nas技術搜索最優的輔助神經網絡模型的結構,然后以學生神經網絡模型學習的結果為導向的損失函數自動更新輔助神經網絡模型的結構參數,最終找到最優的輔助神經網絡模型結構。該方法可以在不使用人類專家知識的情況下自動尋找對學生神經網絡模型性能減少最優的輔助神經網絡模型。利用該方法可以找到一個最優的輔助神經網絡模型,提升異構知識蒸餾的性能,使得性能良好但占內存太大的教師神經網絡模型,能在最大限度減少知識損耗的情況下把知識轉移到尺寸較小的學生神經網絡模型上,以便將神經網絡模型部署到移動設備上使用。
本文檔來自技高網...【技術保護點】
1.一種異構知識蒸餾中的輔助神經網絡模型訓練方法,其特征在于,包括:
2.根據權利要求1所述的方法,其特征在于,所述分別構建教師神經網絡模型、學生神經網絡模型和輔助神經網絡模型對應的函數,包括:
3.根據權利要求1所述的方法,其特征在于,所述通過減少輔助神經網絡模型的蒸餾損失更新輔助神經網絡模型的權重ωA為:通過公式更新ωA。
4.根據權利要求1所述的方法,其特征在于,所述通過減少學生神經網絡模型的蒸餾損失更新學生神經網絡模型的權重參數ωs為:
5.根據權利要求1所述的方法,其特征在于,通過梯度下降最小化學生神經網絡模型的損失函數更新輔助神經網絡模型的結構參數aA,包括:
6.根據權利要求2-5任一項所述的方法,其特征在于,還包括:
7.一種異構知識蒸餾中的輔助神經網絡模型訓練裝置,其特征在于,包括:
8.根據權利要求7所述的裝置,其特征在于,所述輔助模型權重優化模塊具體用于:通過公式更新ωA;
9.根據權利要求7或8所述的裝置,其特征在于,還包括:
10.一種計算設備,其
...【技術特征摘要】
1.一種異構知識蒸餾中的輔助神經網絡模型訓練方法,其特征在于,包括:
2.根據權利要求1所述的方法,其特征在于,所述分別構建教師神經網絡模型、學生神經網絡模型和輔助神經網絡模型對應的函數,包括:
3.根據權利要求1所述的方法,其特征在于,所述通過減少輔助神經網絡模型的蒸餾損失更新輔助神經網絡模型的權重ωa為:通過公式更新ωa。
4.根據權利要求1所述的方法,其特征在于,所述通過減少學生神經網絡模型的蒸餾損失更新學生神經網絡模型的權重參數ωs為:
5.根據權利要求1所述的方法,其特征在于,通過梯度下降最小化學生神經網絡模型的損...
【專利技術屬性】
技術研發人員:楊丹,
申請(專利權)人:中國農業銀行股份有限公司,
類型:發明
國別省市:
還沒有人留言評論。發表了對其他瀏覽者有用的留言會獲得科技券。