(最近要使用KL散度計算損失函式,發現自己對KL散度還是一知半解,於是花了些時間去學校,使用pytorch也踩了些坑,做些筆記,以下是我個人對KL散度的理解,若有出錯還請大佬們指點指點~)

概念

KL散度可以用來衡量兩個機率分佈之間的相似性,兩個機率分佈越相近,KL散度越小。

D_{KL}(P||Q) = \sum_{i=1}^{N}{[p(x_{i})log p(x_{i}) - p(x_{i})log q(x_{i})]}

(1)

通常P為真實事件的機率分佈,Q為理論擬合出來的該事件的機率分佈。因為

D_{KL}(P||Q)

(P擬合Q)和

D_{KL}(Q||P)

(Q擬合P)是不一樣的。

舉個栗子

班裡男生人數佔40%,女生佔60%,則班裡隨機抽取一個人的性別的機率分佈是Q = [0。4, 0。6]。作為真實事件的機率分佈。

小明猜測班裡男生佔30%,女生佔70%,則小明擬合的機率分佈P1 = [0。3, 0。7]。

小紅猜測班裡男生佔20%,女生佔80%,則小紅擬合的機率分佈P2 = [0。2, 0。8]。

那麼現在,小明和小紅誰預測的機率分佈離真實分佈比較近?這時候就可以用KL散度來衡量P1與Q的相似性、P2與Q的相似性,然後對比可得誰更相似。

KL_{1} = [0.3\times log(0.3)-0.3\times log(0.4)] + [0.7\times log(0.7)-0.7\times log(0.6)]  \\ = 0.0216

KL_{2} = [0.2\times log(0.2)-0.2\times log(0.4)] + [0.8\times log(0.8)-0.8\times log(0.6)]  \\ = 0.0915

KL_{1}

KL_{2}

小,說明P1與Q更相近。

這個例子很直觀,不用計算就可以猜測出結果,但是當分佈複雜的情況下,用KL散度就比較好度量。如一個數據集分佈未知,想用數學公式來表達,比如高斯分佈、泊松分佈、韋伯分佈等,這些分佈哪個更適合用來表示資料集的分佈。則可以計算擬合曲線與資料集真實分佈的KL散度,選擇KL散度最小的作為資料集的機率分佈表示式。

如:用高斯分佈擬合數據集分佈時,統計均值μ,標準差σ,則可得到高斯分佈表示式:

f(x) = \frac{1}{\sigma \sqrt{2\pi}}e^{-\frac{(x-\mu)^{2}}{2\sigma^{2}}}

再用高斯分佈表示式不同自變數x1,x2,。。。計算出不同類別的機率p1,p2。。。,即機率分佈P=[p1, p2,。。。],與真實的機率分佈Q = [q1,q2,。。。]透過公式(1)計算得到KL散度。

同理,計算其他擬合分佈與真實分佈的KL散度,對比得到最優用來擬合真實資料的機率分佈表示式。

pytorch計算KL散度

現在,明白了什麼是KL散度,可以用pytorch自帶的庫函式來計算KL散度。

使用pytorch進行KL散度計算,可以使用pytorch的kl_div函式,小白的我經過不斷嘗試,才明白這個函式的正確開啟方式。

假設y為真實分佈,x為預測分佈。

import torch。nn。functional as F

kl = F。kl_div(x。softmax(dim=-1)。log(), y。softmax(dim=-1), reduction=‘sum’)

其中kl_div接收三個引數,第一個為預測分佈,第二個為真實分佈,第三個為reduction。(其實還有其他引數,只是基本用不到)

這裡有一些細節需要注意,第一個引數與第二個引數都要進行softmax(dim=-1),目的是使兩個機率分佈的所有值之和都為1,若不進行此操作,如果x或y機率分佈所有值的和大於1,則可能會使計算的KL為負數。softmax接收一個引數dim,dim=-1表示在最後一維進行softmax操作。除此之外,第一個引數還要進行log()操作(至於為什麼,大概是為了方便pytorch的程式碼組織,pytorch定義的損失函式都呼叫handle_torch_function函式,方便權重控制等),才能得到正確結果。

第三個引數reduction有三種取值,為 none 時,各點的損失單獨計算,輸出損失與輸入(x)形狀相同;為 mean 時,輸出為所有損失的平均值;為 sum 時,輸出為所有損失的總和。

我理解的交叉熵與KL

交叉熵作為深度學習常用的損失函式,可以理解為是KL散度的一個特例。當機率分佈中的值只取1或0時,可以看作KL散度。但是兩者又有區別,KL散度中機率分佈所有值之和為1,而交叉熵則可以大於1,如[0,1,0,1,0,0,]。

從概念上講,KL 散度通常用來度量兩個機率分佈之間的差異。

交叉熵用來求目標與預測值之間的差距,資料分佈不一定是機率分佈。

——————————

KL散度理解以及使用pytorch計算KL散度