Learning to Compare : Relation Network
之前的one-shot方法例如siamese network等是直接利用像歐氏距離或者餘弦距離這種pre-defined fixed distance metric learning的方法來計算的樣本相似度。但是其實我們不知道這些固定的預先設定好的評價是不是最合適的,因此這篇文章[1] (手動@作者@Flood Sung) 的想法是如何學習一種非線性的表達可以更加準確的評估這之間的關係。
網路結構如圖1和圖2,要注意的地方是在sample/query set和support/test set的構建過程中。首先得注意區分文中的幾個概念,本筆記中的引數是基於Omniglot資料集中的設定。
圖1、網路結構流程圖
圖2、模型結構示意圖
對於包含c個不同的類別,每個類別有k個樣本的support set,文中稱之為
c-way k-shot
。如圖2就是一個5-way 1-shot的結構。
在訓練過程中構建的是
sample set
和
query set,
用來模擬
測試時的support set和test set
。在training階段構建這兩個set的時候,是先從1200個用來訓練的類裡隨機挑選出5個類,再在每個類所對應的20個樣本中隨機挑選出一個作為sample set,也就是對應的測試時的support set,剩下的19個樣本在訓練的時候也都用來做query set。而在測試階段,test set是隻用了一個樣本的。
即我們可以這麼理解,我們在訓練的時候,將這19×5個樣本與隨機挑選出來的5個樣本兩兩之間進行比較,會有95×5=475對兒比較。這讓我們的模型在判斷相似度的時候,可以減小同類之間的距離而加大不同類之間的距離。最終選擇間距最小的類別作為預測類別。
網路在訓練過程中是相當於一個episode分段式的策略。先將sample set和query set的樣本經過embedding module得到feature,再將5個sample類的樣本的feature分別和19個query set的19*5個樣本的feature在depth維度上兩兩concat到一起。第二步是將concat之後的feature經過Relation Network輸出關係得分,把輸出的relation score看做是一個從0到1的數值。0就代表極不相似,而1則代表完全相似。因此就非常直接的採用平方差MSE作為網路訓練的loss。
將class semantic vector作為support set,可以將這個問題轉化為一個zsl的問題,同樣的網路也可以解決。
可能存在的一個小bug是作者沒有在訓練階段和測試階段使用。train()/。val(),如果加上的話準確率相對於論文中會有些許的下降~
最後記錄構建set時的一些引數
task
:{
character_folder
:[
1200
path_to_images_classes
],
test_labels
:[
19
*
5
(
0
-
1
)],
num_classes
=
5
;
test_num
=
19
;
train_num
=
1
;
test_roots
:[
19
*
5
path_to_images
],
train_labels
:[
0
,
1
,
2
,
3
,
4
],
train_roots
:[
5
path_to_images
]
}
# 相當於驗證階段:TEST_EPISODE
task
:{
character_folder
:[
423
path_to_images_classes
],
test_labels
:[
1
*
5
(
0
-
1
)],
num_classes
=
5
;
test_num
=
1
;
# 訓練時挑19個,測試時挑一個。
train_num
=
1
;
test_roots
:[
1
*
5
path_to_images
],
train_labels
:[
0
,
1
,
2
,
3
,
4
],
train_roots
:[
5
path_to_images
]
}
sample_images
=
[
5
,
1
,
28
,
28
]
;
sample_labels
=
[
5
]
test_images
=
[
5
,
1
,
28
,
28
]
;
test_labels
=
[
5
]
然後樓下小賣鋪去哪兒了?
randperm功能是隨機打亂一個數字序列
。
語法格式:
y
=
torch
。
randperm
(
n
)
y是把1到n這些數隨機打亂得到的一個數字序列
。
=============================================================
pytorch
:
如何優雅的將
int
list
轉成
one
-
hot形式
:
使用交叉熵損失函式的時候會自動把
label轉化成onehot
,所以不用手動轉化,
而使用
MSE需要手動轉化成onehot編碼
。
# LongTensor的shape剛好與x的shape對應,也就是LongTensor每個index指定x中一個數據的填充位置。
dim
=
0
,表示按行填充,主要理解按行填充。舉例
LongTensor中的第0行第2列index
=
2
,
表示在第
2
行(從
0
開始)進行填充,對應到
zeros
(
3
,
5
)
中就是位置(
2
,
2
)。
所以此處要求
zeros
(
3
,
5
)
的列數要與
x列數相同
,而
LongTensor中的index最大值應與zeros
(
3
,
5
)
行數相一致。
>>>
torch
。
zeros
(
3
,
5
)
。
scatter_
(
0
,
torch
。
LongTensor
([[
0
,
1
,
2
,
0
,
0
],
[
2
,
0
,
0
,
1
,
2
]]),
x
)
# out[index[i, j], j] = value[i, j] dim=0
# out[i,index[i, j]] = value[i, j]] dim=1
參考資料:
[1]、Learning to Compare: Relation Network for Few-Shot Learning