作者: Kenneth Borup
建立日期 2020/09/01
上次修改日期 2020/09/01
描述: 經典知識蒸餾的實作。
知識蒸餾是一種模型壓縮的程序,其中訓練一個小型(學生)模型,使其與大型預訓練(教師)模型匹配。透過最小化旨在匹配軟化的教師 logits 以及真實標籤的損失函數,將知識從教師模型轉移到學生模型。
logits 通過在 softmax 中應用「溫度」縮放函數來軟化,有效地平滑機率分佈並揭示教師學習的類別間關係。
參考文獻
import os
import keras
from keras import layers
from keras import ops
import numpy as np
Distiller()
類別自訂的 Distiller()
類別會覆寫 Model
方法 compile
、compute_loss
和 call
。為了使用蒸餾器,我們需要
temperature
,用於計算軟化學生預測和軟化教師標籤之間的差異alpha
因子,用於加權學生損失和蒸餾損失在 compute_loss
方法中,我們執行教師和學生的前向傳遞,並分別通過 alpha
和 1 - alpha
對 student_loss
和 distillation_loss
進行加權來計算損失。注意:只有學生的權重會被更新。
class Distiller(keras.Model):
def __init__(self, student, teacher):
super().__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
"""Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between student
predictions and ground-truth
distillation_loss_fn: Loss function of difference between soft
student predictions and soft teacher predictions
alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
temperature: Temperature for softening probability distributions.
Larger temperature gives softer distributions.
"""
super().compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
def compute_loss(
self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
):
teacher_pred = self.teacher(x, training=False)
student_loss = self.student_loss_fn(y, y_pred)
distillation_loss = self.distillation_loss_fn(
ops.softmax(teacher_pred / self.temperature, axis=1),
ops.softmax(y_pred / self.temperature, axis=1),
) * (self.temperature**2)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
return loss
def call(self, x):
return self.student(x)
首先,我們建立一個教師模型和一個較小的學生模型。這兩個模型都是使用 Sequential()
建立的卷積神經網路,但也可以是任何 Keras 模型。
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
layers.Flatten(),
layers.Dense(10),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
用於訓練教師和蒸餾教師的資料集是 MNIST,對於任何其他資料集(例如 CIFAR-10),該程序都是等效的,但需要適當的選擇模型。學生和教師都在訓練集上進行訓練,並在測試集上進行評估。
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
在知識蒸餾中,我們假設教師已經被訓練好並且是固定的。因此,我們首先以通常的方式在訓練集上訓練教師模型。
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760
[0.09044107794761658, 0.978100061416626]
我們已經訓練了教師模型,只需要初始化一個 Distiller(student, teacher)
實例,使用所需的損失、超參數和優化器 compile()
它,並將教師蒸餾到學生。
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629
[0.017046602442860603, 0.969200074672699]
我們也可以從頭開始訓練一個等效的學生模型,而無需教師的指導,以便評估透過知識蒸餾獲得的效能提升。
# Train student as doen usually
student_scratch.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737
[0.0629437193274498, 0.9778000712394714]
如果教師模型訓練了 5 個完整 epoch,並且學生模型在此教師模型的基礎上進行了 3 個完整 epoch 的蒸餾,那麼在此範例中,您應該會體驗到效能提升,相較於從頭開始訓練相同的學生模型,甚至相較於教師模型本身。您應該預期教師模型的準確率約為 97.6%,從頭開始訓練的學生模型應約為 97.6%,而經過蒸餾的學生模型應約為 98.1%。移除或嘗試不同的隨機種子以使用不同的權重初始化。