【Graph Neural Network】GCN: 演算法原理,實現和應用
半年前寫過一系列關於Graph Embedding技術的介紹文章:
【Graph Embedding】DeepWalk:演算法原理,實現和應用
【Graph Embedding】LINE:演算法原理,實現和應用
【Graph Embedding】node2vec:演算法原理,實現和應用
【Graph Embedding】SDNE:演算法原理,實現和應用
【Graph Embedding】Struc2Vec:演算法原理,實現和應用
簡單來說,Graph Embedding技術一般透過特定的策略對圖中的頂點進行遊走取樣進而學習到圖中的頂點的相似性,可以看做是一種將圖的拓撲結構進行向量表示的方法。
然而現實世界中,圖中的頂點還包含若干的屬性資訊,如社交網路中的使用者畫像資訊,引文網路中的文字資訊等,對於這類資訊,基於GraphEmbedding的方法通常是將屬性特徵拼接到頂點向量中提供給後續任務使用。
本文介紹的GCN則可以直接透過對圖的拓撲結構和頂點的屬性資訊進行學習來得到任務結果。
GCN演算法原理
首先,如果想要完整了解GCN的理論基礎,我們還需要去了解空間域卷積,譜圖卷積,傅立葉變換,Laplacian運算元這些,本文不涉及這些內容,感興趣的同學可以自行查閱相關資料。
我們現在先記住一個結論,GCN是譜圖卷積的一階區域性近似,是一個
多層的圖卷積神經網路,每一個卷積層僅處理一階鄰域資訊,透過疊加若干卷積層可以實現多階鄰域的資訊傳遞
。
每一個
卷積層的傳播規則
如下:
其中
是無向圖G的鄰接矩陣加上自連線(就是每個頂點和自身加一條邊),
是單位矩陣。
是
的度矩陣,即
是第
層的啟用單元矩陣,
是每一層的引數矩陣
簡單解釋下,GCN的每一層透過鄰接矩陣
和特徵矩陣
相乘得到每個頂點鄰居特徵的彙總,然後再乘上一個引數矩陣
加上啟用函式
做一次非線性變換得到聚合鄰接頂點特徵的矩陣
。
之所以鄰接矩陣
要加上一個單位矩陣
,是因為我們希望在進行資訊傳播的時候頂點自身的特徵資訊也得到保留。
而對鄰居矩陣
進行歸一化操作
是為了資訊傳遞的過程中保持特徵矩陣
的原有分佈,防止一些度數高的頂點和度數低的頂點在特徵分佈上產生較大的差異。
GCN的實現
GCN卷積層實現
output
=
tf
。
matmul
(
tf
。
sparse_tensor_dense_matmul
(
A
,
features
),
self
。
kernel
)
if
self
。
bias
:
output
+=
self
。
bias
act
=
self
。
activation
(
output
)
上述程式碼片段對應的就是
,只不過多了一個偏置項。
GCN的實現
def
GCN
(
adj_dim
,
num_class
,
feature_dim
,
dropout_rate
=
0。5
,
l2_reg
=
0
,
feature_less
=
True
,
):
Adjs
=
[
Input
(
shape
=
(
None
,),
sparse
=
True
)]
if
feature_less
:
X_in
=
Input
(
shape
=
(
1
,),
)
emb
=
Embedding
(
adj_dim
,
feature_dim
,
embeddings_initializer
=
Identity
(
1。0
),
trainable
=
False
)
X_emb
=
emb
(
X_in
)
H
=
Reshape
([
X_emb
。
shape
[
-
1
]])(
X_emb
)
else
:
X_in
=
Input
(
shape
=
(
feature_dim
,),
)
H
=
X_in
H
=
GraphConvolution
(
16
,
activation
=
‘relu’
,
dropout_rate
=
dropout_rate
,
l2_reg
=
l2_reg
)(
[
H
]
+
Adjs
)
Y
=
GraphConvolution
(
num_class
,
activation
=
‘softmax’
,
dropout_rate
=
dropout_rate
,
l2_reg
=
0
)(
[
H
]
+
Adjs
)
model
=
Model
(
inputs
=
[
X_in
]
+
Adjs
,
outputs
=
Y
)
return
model
這裡
feature_less
的作用是告訴模型我們是否有額外的頂點特徵輸入,當
feature_less
為
True
的時候,我們直接輸入一個單位矩陣作為特徵矩陣,相當於對每個頂點進行了onehot表示。
GCN應用
本例中的訓練,評測和視覺化的完整程式碼在下面的git倉庫中,後面還會陸續更新一些其他GNN演算法
我們使用論文引用網路資料集Cora進行測試,Cora資料集包含2708個頂點, 5429條邊,每個頂點包含1433個特徵,共有7個類別。
按照論文的設定,從每個類別中選取20個共140個頂點作為訓練,500個頂點作為驗證集合,1000個頂點作為測試集。DeepWalk從全體頂點集合中進行取樣,最後使用同樣的140個頂點訓練一個LR模型進行分類。
頂點分類任務結果
從分類任務結果可以看到,在使用較少訓練樣本的條件下GCN的效果是高於DeepWalk的,而不含頂點特徵的GCN的效果則會變差很多。
不含頂點特徵的GCN相當於僅僅在學習圖的拓撲結構,而對於圖的拓撲結構的學習GraphEmbedding方法也能做到,這也說明了GCN的優勢在於能夠同時融入了圖的拓撲結構和頂點的特徵進行學習。
頂點向量視覺化
從對得到的頂點向量的視覺化結果來看,GCN得到的向量相比於DeepWalk產出的向量確實更加能夠將同類的頂點聚集在一起,不同類的頂點區分開來。
DeepWalk視覺化
GCN視覺化
最後,作為一個系列的文章,我也會陸續進行更新,歡迎大家關注我的專欄~
相關文章
參考資料
Kipf T N, Welling M。 Semi-supervised classification with graph convolutional networks[J]。 arXiv preprint arXiv:1609。02907, 2016。(
https://
arxiv。org/pdf/1609。0290
7
)
為了方便大家學習,我把一些相關的經典文章和程式碼實現進行了打包彙總並放在了github倉庫裡,感興趣的同學可以看看~