【技術實現步驟摘要】
本專利技術涉及人工智能和圖神經網絡領域,特別地,涉及一種基于虛擬環境推理的不變圖學習的方法。
技術介紹
1、近年來,圖神經網絡在涉及圖結構的機器學習應用,如ai輔助制藥、推薦系統等領域,取得了很大的成功。然而,大部分現有的圖機器學習算法都依賴于數據的獨立同分布假設,即測試和訓練圖數據是從同一分布中獨立抽取的。中國申請cn117523215a(申請號202311483128.3,申請日2023.11.08)公開的一種基于因果特征分離的分布外泛化方法及系統,所述方法包括:首先,根據圖像的多維度特征和真實標簽構建無向因果框架;接著,將與圖像真實標簽無選擇偏差的非因果特征過濾掉,得到中間無向因果框架;然后,將與圖像真實標簽具有選擇偏差的非因果特征與因果特征進行分離;最后根據分離出來的因果特征進行后續的目標圖像分類。采用本專利技術能夠極大地提升圖像分類的準確性、穩定性和實時性。然而在實際應用中,訓練數據和測試數據的分布通常表現出不一致性,即分布偏移,使得算法的性能大幅下降。旨在提升模型在分布偏移下性能的問題通常被稱為分布外泛化問題,近年涌現了諸多方法借助因果推斷中的不變性原理在圖像分布外泛化問題上取得了一定的進展,但在圖數據上的研究依然有限。這是因為圖數據結構本身的復雜性,使得圖數據的分布外泛化比傳統的歐式數據更普遍且更具挑戰性,從而給圖機器學習帶來了更多的挑戰。同時,因為圖數據結構的復雜性,導致圖數據的分布外泛化相比于歐式數據更普遍且更具挑戰性。解決以上問題的難點在于如何充分利用訓練數據使得圖神經網絡能夠學習到圖中不受分布偏移影響的部分
技術實現思路
1、本專利技術提供了一種基于虛擬環境推理的不變圖學習的方法,通過構造基于類的虛擬環境,提升圖神經網絡的不變圖學習能力,從而改善圖神經網絡的分布外泛化性能。
2、本專利技術的技術方案如下:
3、本專利技術的基于虛擬環境推理的不變圖學習的方法,包括以下步驟:s1.不變圖生成器:利用圖神經網絡對輸入的圖數據進行分析,得到不變圖部分和可變圖部分;s2.不變圖增強器:通過對可變圖的邊進行隨機采樣,根據邊數將其平均分為兩部分,分別與不變圖合并得到兩個增強的不變圖視圖;s3.編碼器:利用圖卷積結構和非線性結構搭建編碼器,將不變圖及其兩個增強的不變圖視圖變換到特征空間中,得到相應的特征;s4.分類器:利用圖卷積結構和歸一化結構搭建分類器,充分融合不變圖及其兩個增強的不變圖視圖的特征,得到圖分類的結果;s5.基于類的虛擬環境推理:選定基準類,將其他類中的圖特征與基準類中的圖特征進行余弦相似度計算,根據相似度將其他類中的圖平均分為“相似”和“不相似”兩部分,分別與基準類中的圖合并,構成基于類的兩個虛擬環境;s6.聯合優化器:計算分類器的分類結果和真實分類結果的交叉熵損失,同時根據構建得到的基于類的虛擬環境,計算每個虛擬環境中的對比損失;通過超參數控制兩部分損失的比重,加權相加得到最終的損失,對其反向傳播,實現對不變圖生成器、編碼器和分類器的聯合優化。
4、可選地,在上述基于虛擬環境推理的不變圖學習的方法中,在步驟s1中,給定圖g=(v,e)和對應的鄰接矩陣a,其中v表示圖中節點的集合,e表示圖中邊的集合,a表示圖中的連接關系。首先采用了一個圖神經網絡對鄰接矩陣a進行處理,得到相應的表示邊重要程度的掩碼矩陣m。然后選取最重要的k條邊作為不變圖對應的邊的集合,這k條邊對應的頂點相應的則構成不變圖對應的點的集合。最后,整個圖除去不變圖的部分則構成可變圖部分,從而最終得到圖g對應的不變圖部分gc和可變圖部分gs。
5、可選地,在上述基于虛擬環境推理的不變圖學習的方法中,在步驟s2中,將可變圖gs的邊的集合平均分為兩部分,分別與不變圖gc進行合并,從而得到不變圖的兩個增強視圖和
6、可選地,在上述基于虛擬環境推理的不變圖學習的方法中,在步驟s3中,利用圖卷積結構和非線性結構搭建編碼器,將不變圖gc及其兩個增強視圖和變換到特征空間中,得到相應的特征z,za和zb。
7、可選地,在上述基于虛擬環境推理的不變圖學習的方法中,在步驟s4中,利用圖卷積結構和歸一化結構搭建分類器,將步驟3得到的特征z,za和zb進行融合處理,最終輸出每個類對應的概率,選取概率最高的類作為最終分類結果。
8、可選地,在上述基于虛擬環境推理的不變圖學習的方法中,在步驟s5中,利用步驟s3中得到的特征來構建虛擬環境。從訓練數據集中選定基準類類canchor,通過計算其他類中圖的特征與基準類中的圖的特征的余弦相似度,從而根據相似度高低將其他類中的圖平均分為“相似”和“不相似”兩部分,進而分別與基準類中的圖進行合并,最終得到基準類類canchor對應的兩個虛擬環境。假設訓練數據集中有c個類,將每個類都作為基準類重復上述構建虛擬環境的步驟,從而得到所有的2*c個虛擬環境。
9、可選地,在上述基于虛擬環境推理的不變圖學習的方法中,在步驟s6中,基于步驟s5得到虛擬環境,對于每個虛擬環境,將基準類中圖對應的特征z,za和zb作為正樣本,而其他類中圖的特征作為負樣本,計算得到與每個虛擬環境對應的對比損失,最后將所有虛擬環境的對比損失相加得到虛擬環境對應的損失。此外,還結合計算分類器的分類結果和真實分類結果的交叉熵損失,將兩者加權相加得到最終的損失,對其反向傳播,實現對不變圖生成器、編碼器和分類器的聯合優化。
10、根據本專利技術的技術方案,產生的有益效果是:
11、本專利技術基于圖獨特的結構及其不變部分和可變部分的關系,提出了一種基于虛擬環境推理的不變圖學習的方法。本專利技術方法結合圖神經網絡技術,通過構造基于類的虛擬環境,將屬于不同類但是特征相似的類劃分到同一個虛擬環境中,結合與之相匹配的聯合優化器,使得編碼器能夠更好地學習圖的不變部分。具體地從基于類的角度出發,以編碼器得到的特征作為相似度比較的標準,借助于訓練數據集中的標簽數據,用對比學習的思想來構造基于類的虛擬環境,將類別不同但在特征維度相似的圖劃分到同一個虛擬環境中,同時設計相應的損失函數來更好地識別出不變部分和可變部分,讓編碼器更好地學習到與分布偏移無關的不變圖部分,從而提升分類器的分類效果,提高圖神經網絡的分布外泛化的能力。
12、為了更好地理解和說明本專利技術的構思、工作原理和專利技術效果,下面結合附圖,通過具體實施例,對本專利技術進行詳細說明如下:
本文檔來自技高網...【技術保護點】
1.一種基于虛擬環境推理的不變圖學習的方法,包括以下步驟:
2.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
3.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
4.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
5.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
6.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
7.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
8.根據權利要求5所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,此外,還結合計算分類器的分類結果和真實分類結果的交叉熵損失,將兩者加權相加得到最終的損失,對其反向傳播,實現對不變圖生成器、編碼器和分類器的聯合優化。
【技術特征摘要】
1.一種基于虛擬環境推理的不變圖學習的方法,包括以下步驟:
2.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
3.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
4.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
5.根據權利要求1所述的基于虛擬環境推理的不變圖學習的方法,其特征在于,
【專利技術屬性】
技術研發人員:李革,徐強,高偉,王榮剛,
申請(專利權)人:北京大學深圳研究生院,
類型:發明
國別省市:
還沒有人留言評論。發表了對其他瀏覽者有用的留言會獲得科技券。