作者: fchollet
建立日期 2020/04/28
上次修改日期 2023/06/29
描述: 使用 TensorFlow 進行 Keras 模型多 GPU 訓練的指南。
一般來說,有兩種方法可以在多個裝置之間分配計算
資料平行處理:將單一模型複製到多個裝置或多部機器上。每個裝置處理不同的資料批次,然後合併結果。這種設定有多種變體,差異在於不同的模型複本如何合併結果、它們是否在每個批次保持同步,或者它們的耦合程度是否較鬆散等等。
模型平行處理:單一模型不同的部分在不同的裝置上執行,一起處理單一批次的資料。這對於具有自然平行架構的模型效果最佳,例如具有多個分支的模型。
本指南重點介紹資料平行處理,特別是同步資料平行處理,模型複本在處理完每個批次後保持同步。同步性使模型收斂行為與您在單裝置訓練中看到的行為相同。
具體而言,本指南教您如何使用 tf.distribute
API 在多個 GPU 上訓練 Keras 模型,只需對您的程式碼進行少量變更,這些 GPU 通常安裝在單一機器上 (單主機,多裝置訓練),數量為 2 到 16 個。這是研究人員和小規模工業工作流程最常見的設定。
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
在這種設定中,您有一部機器,上面有多個 GPU (通常是 2 到 16 個)。每個裝置都會執行一個模型副本 (稱為複本)。為了簡單起見,在接下來的內容中,我們假設我們處理的是 8 個 GPU,這不會影響一般性。
運作方式
在訓練的每個步驟中
實際上,同步更新模型複本權重的過程是在每個個別權重變數的層級處理的。這是透過鏡像變數物件完成的。
如何使用
若要使用 Keras 模型執行單主機、多裝置同步訓練,您可以使用 tf.distribute.MirroredStrategy
API。以下是其運作方式
MirroredStrategy
,選擇性地設定您要使用的特定裝置 (依預設,策略將使用所有可用的 GPU)。fit()
也可能會建立變數,因此最好將 fit()
呼叫也放在範圍內。fit()
訓練模型。重要的是,我們建議您使用 tf.data.Dataset
物件在多裝置或分散式工作流程中載入資料。
簡略的表示,它看起來像這樣
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
# Everything that creates variables should be under the strategy scope.
# In general this is only model construction & `compile()`.
model = Model(...)
model.compile(...)
# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)
# Test the model on all available devices.
model.evaluate(test_dataset)
以下是一個簡單的端對端可執行範例
def get_compiled_model():
# Make a simple 2-layer densely-connected neural network.
inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(256, activation="relu")(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
return model
def get_dataset():
batch_size = 32
num_val_samples = 10000
# Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://tensorflow.dev.org.tw/api_docs/python/tf/data/Dataset).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Preprocess the data (these are Numpy arrays)
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")
# Reserve num_val_samples samples for validation
x_val = x_train[-num_val_samples:]
y_val = y_train[-num_val_samples:]
x_train = x_train[:-num_val_samples]
y_train = y_train[:-num_val_samples]
return (
tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
)
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
# Everything that creates variables should be under the strategy scope.
# In general this is only model construction & `compile()`.
model = get_compiled_model()
# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)
# Test the model on all available devices.
model.evaluate(test_dataset)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices: 1
Epoch 1/2
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.3830 - sparse_categorical_accuracy: 0.8884 - val_loss: 0.1361 - val_sparse_categorical_accuracy: 0.9574
Epoch 2/2
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9671 - val_loss: 0.0894 - val_sparse_categorical_accuracy: 0.9724
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0988 - sparse_categorical_accuracy: 0.9673
當使用分散式訓練時,您應始終確保您有策略可以從故障中恢復 (容錯能力)。處理此問題最簡單的方法是將 ModelCheckpoint
回呼傳遞給 fit()
,以定期儲存您的模型 (例如,每 100 個批次或每個 epoch)。然後,您可以從您儲存的模型重新開始訓練。
以下是一個簡單的範例
# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
def make_or_restore_model():
# Either restore the latest model, or create a fresh one
# if there is no checkpoint available.
checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
if checkpoints:
latest_checkpoint = max(checkpoints, key=os.path.getctime)
print("Restoring from", latest_checkpoint)
return keras.models.load_model(latest_checkpoint)
print("Creating a new model")
return get_compiled_model()
def run_training(epochs=1):
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
# Open a strategy scope and create/restore the model
with strategy.scope():
model = make_or_restore_model()
callbacks = [
# This callback saves a SavedModel every epoch
# We include the current epoch in the folder name.
keras.callbacks.ModelCheckpoint(
filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
save_freq="epoch",
)
]
model.fit(
train_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=val_dataset,
verbose=2,
)
# Running the first time creates the model
run_training(epochs=1)
# Calling the same function again will resume from where we left off
run_training(epochs=1)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Creating a new model
1563/1563 - 7s - 4ms/step - loss: 0.2275 - sparse_categorical_accuracy: 0.9320 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9571
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Restoring from ./ckpt/ckpt-1.keras
1563/1563 - 6s - 4ms/step - loss: 0.0944 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0972 - val_sparse_categorical_accuracy: 0.9710
tf.data
效能秘訣在執行分散式訓練時,載入資料的效率通常會變得至關重要。以下是一些秘訣,可確保您的 tf.data
管線盡可能快速地執行。
關於資料集批次的注意事項
在建立資料集時,請確保它已使用全域批次大小進行批次處理。例如,如果您的 8 個 GPU 中的每一個都能夠執行 64 個樣本的批次,您可以使用 512 的全域批次大小。
呼叫 dataset.cache()
如果您在資料集上呼叫 .cache()
,其資料將在執行第一次資料迭代後進行快取。每個後續迭代都將使用快取的資料。快取可以在記憶體中 (預設),也可以在您指定的本機檔案中。
當符合以下情況時,這可以提高效能
呼叫 dataset.prefetch(buffer_size)
您幾乎應該在建立資料集後始終呼叫 .prefetch(buffer_size)
。這表示您的資料管線將與您的模型非同步執行,在目前批次的樣本用於訓練模型的同時,新的樣本將被預先處理並儲存在緩衝區中。在目前批次結束時,下一個批次將會被預取到 GPU 記憶體中。
就這樣!