Author: akensert
Date created: 2021/09/13
Last modified: 2021/12/26
Description: An implementation of a Graph Attention Network (GAT) for node classification.
Graph neural networks are the preferred neural network architecture for processing data structured as graphs (for example, social networks or molecular structures), and they produce better results than fully-connected or convolutional networks on such data.
In this tutorial, we will implement a specific graph neural network known as a Graph Attention Network (GAT) to predict labels of scientific papers based on the types of papers that cite them (using the Cora dataset).
For more information on GAT, see the original paper Graph Attention Networks as well as DGL's Graph Attention Networks documentation.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings
warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)
The preparation of the Cora dataset follows that of the Node classification with Graph Neural Networks tutorial. Refer to that tutorial for more details on the dataset and exploratory data analysis. In brief, the Cora dataset consists of two files: cora.cites, which contains directed links (citations) between papers, and cora.content, which contains the features of the corresponding papers along with one of seven labels (the subject of the paper).
zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)
class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}
papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
print(citations)
print(papers)
target source
0 0 21
1 0 905
2 0 906
... ... ...
5426 1874 2586
5427 1876 1874
5428 1897 2707
[5429 rows x 2 columns]
paper_id term_0 term_1 ... term_1431 term_1432 subject
0 462 0 0 ... 0 0 2
1 1911 0 0 ... 0 0 5
2 2002 0 0 ... 0 0 4
... ... ... ... ... ... ... ...
2705 2372 0 0 ... 0 0 1
2706 955 0 0 ... 0 0 0
2707 376 0 0 ... 0 0 2
[2708 rows x 1435 columns]
# Obtain random indices
random_indices = np.random.permutation(range(papers.shape[0]))
# 50/50 split
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]
# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()
# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()
# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])
# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)
Edges shape: (5429, 2)
Node features shape: (2708, 1433)
GAT takes as input a graph (namely an edge tensor and a node feature tensor) and outputs [updated] node states. For each target node, the node states are N-hop neighborhood-aggregated information, where N is determined by the number of GAT layers. Importantly, in contrast to the graph convolutional network (GCN), GAT makes use of an attention mechanism to aggregate information from neighboring nodes (source nodes). In other words, instead of simply averaging/summing the node states of the source nodes (source papers) into the target node (target paper), GAT first applies normalized attention scores to each source node state and then sums.
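To make that distinction concrete, here is a minimal, self-contained sketch (not part of the tutorial pipeline) that aggregates neighbor states on a tiny made-up graph, once with an unweighted mean (GCN-style) and once with per-target-node softmax weights (GAT-style). The edge logits are hypothetical stand-ins for the learned attention scores.

import tensorflow as tf

# Tiny made-up graph: 3 nodes, edges stored as (target, source) pairs,
# matching the layout of the Cora `edges` tensor above.
toy_states = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
toy_edges = tf.constant([[0, 1], [0, 2], [1, 2], [2, 0]])
targets, sources = toy_edges[:, 0], toy_edges[:, 1]
neighbor_states = tf.gather(toy_states, sources)

# GCN-style: every incoming neighbor contributes equally.
mean_agg = tf.math.unsorted_segment_mean(neighbor_states, targets, num_segments=3)

# GAT-style: hypothetical per-edge logits, softmax-normalized per target node,
# followed by a weighted sum of the neighbor states.
edge_logits = tf.constant([2.0, 0.5, 1.0, -1.0])
exp_logits = tf.exp(edge_logits)
denom = tf.gather(tf.math.unsorted_segment_sum(exp_logits, targets, 3), targets)
weights = exp_logits / denom  # incoming-edge weights sum to 1 for each target node
att_agg = tf.math.unsorted_segment_sum(neighbor_states * weights[:, None], targets, 3)

print(mean_agg.numpy())
print(att_agg.numpy())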
The GAT model implements multi-head graph attention layers. The MultiHeadGraphAttention layer is simply a concatenation (or averaging) of multiple graph attention layers (GraphAttention), each with separate learnable weights W. The GraphAttention layer does the following:

Consider input node states h^{l}, which are linearly transformed by W^{l}, resulting in z^{l}. For each target node i:

1. Compute pair-wise attention scores a^{l^{T}}(z^{l}_{i} || z^{l}_{j}) for all j, resulting in e_{ij}. Here || denotes concatenation, _{i} corresponds to the target node, and _{j} corresponds to a given 1-hop neighbor/source node.
2. Normalize e_{ij} via softmax, so that the attention scores of the edges incoming to target node i sum to 1 (sum_{k}{e_{norm}_{ik}} = 1).
3. Apply the attention scores e_{norm}_{ij} to z_{j} and add the result to the new target node state h^{l+1}_{i}, for all j.

class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out
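The least obvious part of the layer is step (2): attention scores live on edges, so the softmax denominator has to be computed per target node with a segment sum and then broadcast back onto the edges with tf.repeat/tf.math.bincount. Below is a standalone sketch of that trick on hypothetical values; note that this broadcast implicitly assumes the edge list is grouped by target id, which the Cora citation list as loaded above appears to satisfy.

# Hypothetical edge-level attention logits for a 2-node graph with 4 incoming edges.
scores = tf.constant([1.0, 2.0, 0.5, 3.0])
targets = tf.constant([0, 0, 1, 1])  # edges[:, 0], grouped by target id

exp_scores = tf.math.exp(tf.clip_by_value(scores, -2, 2))
# Per-target denominators ...
denominators = tf.math.unsorted_segment_sum(exp_scores, targets, num_segments=2)
# ... repeated back onto the edges (bincount gives the number of incoming edges per target).
denominators_per_edge = tf.repeat(denominators, tf.math.bincount(tf.cast(targets, "int32")))
normalized = exp_scores / denominators_per_edge

# Each target node's incoming attention weights now sum to 1.
print(tf.math.unsorted_segment_sum(normalized, targets, num_segments=2).numpy())  # ~[1. 1.]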
class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        atom_features, pair_indices = inputs

        # Obtain outputs from each attention head
        outputs = [
            attention_layer([atom_features, pair_indices])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)
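As a quick, hypothetical sanity check (not part of the tutorial's pipeline), the multi-head layer can be called on random toy inputs to see how merge_type affects the output width. The GAT model below uses the default concatenation for its hidden layers.

# Hypothetical shape check: 4 nodes with 16 features and 5 directed (target, source) edges.
dummy_states = tf.random.normal((4, 16))
dummy_edges = tf.constant([[0, 1], [0, 2], [1, 3], [2, 0], [3, 2]])

concat_heads = MultiHeadGraphAttention(units=8, num_heads=4, merge_type="concat")
averaged_heads = MultiHeadGraphAttention(units=8, num_heads=4, merge_type="average")

print(concat_heads([dummy_states, dummy_edges]).shape)    # (4, 32): units * num_heads
print(averaged_heads([dummy_states, dummy_edges]).shape)  # (4, 8): heads are averaged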
Implement training logic with custom train_step, test_step, and predict_step methods. Note that the GAT model operates on the entire graph (namely, node_states and edges) in all phases (training, validation and testing). Hence, node_states and edges are passed to the constructor of the keras.Model and used as attributes. The difference between the phases lies in the indices (and labels), which gather certain outputs (tf.gather(outputs, indices)).
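The gather pattern itself is simple; here is a toy illustration with hypothetical logits (the model below does exactly this against the full-graph outputs):

# Hypothetical logits for all 5 nodes of a small graph; only the sampled
# indices take part in the loss/metric for a given batch.
all_outputs = tf.constant([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6], [0.3, 0.7], [0.5, 0.5]])
batch_indices = tf.constant([0, 2, 4])
print(tf.gather(all_outputs, batch_indices).numpy())  # rows 0, 2 and 4 only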
class GraphAttentionNetwork(keras.Model):
    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}
# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)
NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 3e-1
MOMENTUM = 0.9
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
)
# Build model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)
# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])
gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)
_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)
print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")
Epoch 1/100
5/5 - 26s - loss: 1.8418 - acc: 0.2980 - val_loss: 1.5117 - val_acc: 0.4044 - 26s/epoch - 5s/step
Epoch 2/100
5/5 - 6s - loss: 1.2422 - acc: 0.5640 - val_loss: 1.0407 - val_acc: 0.6471 - 6s/epoch - 1s/step
Epoch 3/100
5/5 - 5s - loss: 0.7092 - acc: 0.7906 - val_loss: 0.8201 - val_acc: 0.7868 - 5s/epoch - 996ms/step
Epoch 4/100
5/5 - 5s - loss: 0.4768 - acc: 0.8604 - val_loss: 0.7451 - val_acc: 0.8088 - 5s/epoch - 934ms/step
Epoch 5/100
5/5 - 5s - loss: 0.2641 - acc: 0.9294 - val_loss: 0.7499 - val_acc: 0.8088 - 5s/epoch - 945ms/step
Epoch 6/100
5/5 - 5s - loss: 0.1487 - acc: 0.9663 - val_loss: 0.6803 - val_acc: 0.8382 - 5s/epoch - 967ms/step
Epoch 7/100
5/5 - 5s - loss: 0.0970 - acc: 0.9811 - val_loss: 0.6688 - val_acc: 0.8088 - 5s/epoch - 960ms/step
Epoch 8/100
5/5 - 5s - loss: 0.0597 - acc: 0.9934 - val_loss: 0.7295 - val_acc: 0.8162 - 5s/epoch - 981ms/step
Epoch 9/100
5/5 - 5s - loss: 0.0398 - acc: 0.9967 - val_loss: 0.7551 - val_acc: 0.8309 - 5s/epoch - 991ms/step
Epoch 10/100
5/5 - 5s - loss: 0.0312 - acc: 0.9984 - val_loss: 0.7666 - val_acc: 0.8309 - 5s/epoch - 987ms/step
Epoch 11/100
5/5 - 5s - loss: 0.0219 - acc: 0.9992 - val_loss: 0.7726 - val_acc: 0.8309 - 5s/epoch - 1s/step
----------------------------------------------------------------------------
Test Accuracy 76.5%
test_probs = gat_model.predict(x=test_indices)
mapping = {v: k for (k, v) in class_idx.items()}
for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    print(f"Example {i+1}: {mapping[label]}")
    for j, c in zip(probs, class_idx.keys()):
        print(f"\tProbability of {c: <24} = {j*100:7.3f}%")
    print("---" * 20)
Example 1: Probabilistic_Methods
Probability of Case_Based = 0.919%
Probability of Genetic_Algorithms = 0.180%
Probability of Neural_Networks = 37.896%
Probability of Probabilistic_Methods = 59.801%
Probability of Reinforcement_Learning = 0.705%
Probability of Rule_Learning = 0.044%
Probability of Theory = 0.454%
------------------------------------------------------------
Example 2: Genetic_Algorithms
Probability of Case_Based = 0.005%
Probability of Genetic_Algorithms = 99.993%
Probability of Neural_Networks = 0.001%
Probability of Probabilistic_Methods = 0.000%
Probability of Reinforcement_Learning = 0.000%
Probability of Rule_Learning = 0.000%
Probability of Theory = 0.000%
------------------------------------------------------------
Example 3: Theory
Probability of Case_Based = 8.151%
Probability of Genetic_Algorithms = 1.021%
Probability of Neural_Networks = 0.569%
Probability of Probabilistic_Methods = 40.220%
Probability of Reinforcement_Learning = 0.792%
Probability of Rule_Learning = 6.910%
Probability of Theory = 42.337%
------------------------------------------------------------
Example 4: Neural_Networks
Probability of Case_Based = 0.097%
Probability of Genetic_Algorithms = 0.026%
Probability of Neural_Networks = 93.539%
Probability of Probabilistic_Methods = 6.206%
Probability of Reinforcement_Learning = 0.028%
Probability of Rule_Learning = 0.010%
Probability of Theory = 0.094%
------------------------------------------------------------
Example 5: Theory
Probability of Case_Based = 25.259%
Probability of Genetic_Algorithms = 4.381%
Probability of Neural_Networks = 11.776%
Probability of Probabilistic_Methods = 15.053%
Probability of Reinforcement_Learning = 1.571%
Probability of Rule_Learning = 23.589%
Probability of Theory = 18.370%
------------------------------------------------------------
Example 6: Genetic_Algorithms
Probability of Case_Based = 0.000%
Probability of Genetic_Algorithms = 100.000%
Probability of Neural_Networks = 0.000%
Probability of Probabilistic_Methods = 0.000%
Probability of Reinforcement_Learning = 0.000%
Probability of Rule_Learning = 0.000%
Probability of Theory = 0.000%
------------------------------------------------------------
Example 7: Neural_Networks
Probability of Case_Based = 0.296%
Probability of Genetic_Algorithms = 0.291%
Probability of Neural_Networks = 93.419%
Probability of Probabilistic_Methods = 5.696%
Probability of Reinforcement_Learning = 0.050%
Probability of Rule_Learning = 0.072%
Probability of Theory = 0.177%
------------------------------------------------------------
Example 8: Genetic_Algorithms
Probability of Case_Based = 0.000%
Probability of Genetic_Algorithms = 100.000%
Probability of Neural_Networks = 0.000%
Probability of Probabilistic_Methods = 0.000%
Probability of Reinforcement_Learning = 0.000%
Probability of Rule_Learning = 0.000%
Probability of Theory = 0.000%
------------------------------------------------------------
Example 9: Theory
Probability of Case_Based = 4.103%
Probability of Genetic_Algorithms = 5.217%
Probability of Neural_Networks = 14.532%
Probability of Probabilistic_Methods = 66.747%
Probability of Reinforcement_Learning = 3.008%
Probability of Rule_Learning = 1.782%
Probability of Theory = 4.611%
------------------------------------------------------------
Example 10: Case_Based
Probability of Case_Based = 99.566%
Probability of Genetic_Algorithms = 0.017%
Probability of Neural_Networks = 0.016%
Probability of Probabilistic_Methods = 0.155%
Probability of Reinforcement_Learning = 0.026%
Probability of Rule_Learning = 0.192%
Probability of Theory = 0.028%
------------------------------------------------------------
The results look OK! The GAT model seems to correctly predict the subjects of the papers, based on what they cite, about 80% of the time. Further improvements could be made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers, the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout); or modify the preprocessing step. We could also try implementing self-loops (i.e., paper X cites paper X) and/or making the graph undirected, as sketched below.
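As a starting point for those last two suggestions, here is a minimal, untested sketch (an assumption about how one might do it, not part of the tutorial above) that adds reversed edges and self-loops to the edges tensor built earlier:

# Hypothetical preprocessing variation: make the citation graph undirected and add self-loops.
num_nodes = node_states.shape[0]
reversed_edges = tf.reverse(edges, axis=[1])  # flip (target, source) -> (source, target)
self_loops = tf.stack([tf.range(num_nodes, dtype=edges.dtype)] * 2, axis=1)  # (i, i) per node
augmented_edges = tf.concat([edges, reversed_edges, self_loops], axis=0)

# The attention normalization above broadcasts per-target sums with repeat/bincount,
# so keep the augmented edge list grouped by target id.
order = tf.argsort(augmented_edges[:, 0], stable=True)
augmented_edges = tf.gather(augmented_edges, order)
print(augmented_edges.shape)  # (2 * 5429 + 2708, 2), before de-duplicating edge pairs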