Author: Qianli Zhu
Date created: 2023/11/07
Last modified: 2023/11/07
Description: Complete guide to the distribution API for multi-backend Keras.
The Keras distribution API is a new interface designed to facilitate distributed deep learning across a variety of backends such as JAX, TensorFlow, and PyTorch. This powerful API introduces a suite of tools enabling data and model parallelism, allowing deep learning models to scale efficiently across multiple accelerators and hosts. Whether leveraging the power of GPUs or TPUs, the API provides a streamlined approach to initializing distributed environments, defining device meshes, and orchestrating the layout of tensors across computational resources. Through classes like DataParallel and ModelParallel, it abstracts the complexity involved in parallel computation, making it easier for developers to accelerate their machine learning workflows.

The Keras distribution API provides a global programming model that allows developers to compose applications that operate on tensors in a global context (as if working with a single device), while automatically managing distribution across many devices. It leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding hints through a procedure called single program, multiple data (SPMD) expansion.

By decoupling the application from sharding hints, the API enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.
import os
# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data # For dataset input.
DeviceMesh and TensorLayout

The keras.distribution.DeviceMesh class in the Keras distribution API represents a cluster of computational devices configured for distributed computation. It aligns with similar concepts in jax.sharding.Mesh and tf.dtensor.Mesh, where it is used to map physical devices to a logical mesh structure.

The TensorLayout class then specifies how tensors are distributed across the DeviceMesh, detailing the sharding of tensors along specified axes that correspond to the axis names in the DeviceMesh.

You can find more detailed concept explainers in the TensorFlow DTensor guide.
# Retrieve the local available gpu devices.
devices = jax.devices("gpu") # Assume it has 8 local GPUs.
# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(
shape=(2, 4), axis_names=["data", "model"], devices=devices
)
# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)
# A 4D layout which could be used for data parallel of a image input.
replicated_layout_4d = keras.distribution.TensorLayout(
axes=("data", None, None, None), device_mesh=mesh
)
Distribution

The Distribution class in Keras serves as a foundational abstract class designed for developing custom distribution strategies. It encapsulates the core logic needed to distribute a model's variables, input data, and intermediate computations across a device mesh. As an end user, you won't need to interact with this class directly; instead, you work with its subclasses such as DataParallel or ModelParallel.
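As a quick, hedged illustration of how these pieces fit together (not part of the original guide), the sketch below builds a Distribution subclass, installs it as the global distribution, and reads it back; it assumes the keras.distribution.list_devices(), set_distribution(), and distribution() helpers from the Keras 3 API.

# Illustrative sketch: any Distribution subclass (here DataParallel) can be
# installed as the process-wide default distribution.
all_devices = keras.distribution.list_devices()  # all locally visible accelerators
keras.distribution.set_distribution(
    keras.distribution.DataParallel(devices=all_devices)
)
# The currently active global distribution can be inspected afterwards.
print(keras.distribution.distribution())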
DataParallel

The DataParallel class in the Keras distribution API is designed for the data parallelism strategy in distributed training, where the model weights are replicated across all the devices in the DeviceMesh, and each device processes a portion of the input data.

Here is a sample usage of this class.
# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)
# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
# Set the global distribution.
keras.distribution.set_distribution(data_parallel)
# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349
0.842325747013092
ModelParallel and LayoutMap

ModelParallel is mostly useful when model weights are too large to fit on a single accelerator. This setting allows you to split your model weights or activation tensors across all the devices on the DeviceMesh, and enables horizontal scaling for large models.

Unlike the DataParallel model, where all weights are fully replicated, the weights layout under ModelParallel usually needs some customization for best performance. We introduce LayoutMap to let you specify the TensorLayout for any weights and intermediate tensors from a global perspective.

LayoutMap is a dict-like object that maps strings to TensorLayout instances. It behaves differently from a normal Python dict in that the string keys are treated as regexes when retrieving values. The class allows you to define a naming schema for TensorLayout and then retrieve the corresponding TensorLayout instance. Typically, the key used to query is the variable.path attribute, which is the identifier of the variable. As a shortcut, a tuple or list of axis names is also allowed when inserting a value, and it will be converted to a TensorLayout.

The LayoutMap can also optionally contain a DeviceMesh to populate the TensorLayout.device_mesh if it is not set. When retrieving a layout with a key, if there isn't an exact match, all existing keys in the layout map are treated as regexes and matched against the input key again. If there are multiple matches, a ValueError is raised. If no match is found, None is returned.
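To make that lookup behavior concrete, here is a small hypothetical sketch (the keys and variable paths are made up for illustration): values inserted as tuples of axis names are converted to TensorLayout, and retrieval falls back to regex matching against the stored keys.

# Hypothetical sketch of the dict-like LayoutMap behavior described above.
demo_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
demo_map = keras.distribution.LayoutMap(demo_mesh)
# A tuple/list of axis names is converted to a TensorLayout on insertion.
demo_map["d1.*kernel"] = (None, "model")
# Retrieval: an exact match is returned directly; otherwise the stored keys
# are treated as regexes and matched against the query (e.g. variable.path).
print(demo_map["d1/kernel"])  # matches the "d1.*kernel" rule
print(demo_map["d9/bias"])    # no rule matches -> None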
mesh_2d = keras.distribution.DeviceMesh(
shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rules below mean that any weights matching d1/kernel will be sharded
# along the "model" dimension (4 devices); the same applies to d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)
# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)
model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")
keras.distribution.set_distribution(model_parallel)
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)
# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.dev.org.tw/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381
0.8502610325813293
It is also easy to change the mesh structure to tune the computation between more data parallelism and more model parallelism. You can do this by adjusting the shape of the mesh; no changes to any other code are needed.
full_data_parallel_mesh = keras.distribution.DeviceMesh(
shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
shape=(1, 8), axis_names=["data", "model"], devices=devices
)
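As a hedged sketch of what such a switch could look like (reusing the layer names d1/d2 and the layout rules from the earlier example as assumptions), only the mesh behind the LayoutMap changes; the model-building and training code stays exactly the same.

# Sketch: rebuild the layout map on a different mesh, e.g. the 4x2
# "more data parallel" mesh, and reset the global distribution. The model
# definition and the model.fit(...)/model.evaluate(...) calls are unchanged.
layout_map = keras.distribution.LayoutMap(more_data_parallel_mesh)
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)
layout_map["d2/output"] = ("data", None)
model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")
keras.distribution.set_distribution(model_parallel)
# Now the "data" axis spans 4 devices and the "model" axis spans 2.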