開發者指南 / 使用 Keras 3 進行分散式訓練

使用 Keras 3 進行分散式訓練

作者: Qianli Zhu
建立日期 2023/11/07
上次修改日期 2023/11/07
說明: 多後端 Keras 分散式 API 的完整指南。

在 Colab 中檢視 GitHub 原始碼


簡介

Keras 分散式 API 是一個新的介面,旨在促進跨多種後端(例如 JAX、TensorFlow 和 PyTorch)的分散式深度學習。這個強大的 API 引入了一套工具,可實現資料和模型平行化,進而允許在多個加速器和主機上有效擴展深度學習模型。無論是利用 GPU 或 TPU 的強大功能,該 API 都提供了一種簡化的方法來初始化分散式環境、定義裝置網格以及協調張量在計算資源中的佈局。透過諸如 DataParallelModelParallel 之類的類別,它抽象了平行計算中涉及的複雜性,讓開發人員更容易加速其機器學習工作流程。


運作方式

Keras 分散式 API 提供了一個全域程式設計模型,允許開發人員撰寫在全域環境中對張量進行操作的應用程式(如同在單一裝置上工作),同時自動管理跨多個裝置的分散。此 API 利用底層框架(例如 JAX)透過稱為單一程式、多個資料 (SPMD) 擴展的程序,根據分片指令來分散程式和張量。

透過將應用程式與分片指令分離,此 API 能夠在單一裝置、多個裝置甚至是多個用戶端上執行相同的應用程式,同時保留其全域語意。


設定

import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.

DeviceMeshTensorLayout

Keras 分散式 API 中的 keras.distribution.DeviceMesh 類別代表為分散式計算設定的計算裝置叢集。它與 jax.sharding.Meshtf.dtensor.Mesh 中的類似概念一致,在其中用於將實體裝置對應到邏輯網格結構。

接著,TensorLayout 類別指定張量如何在 DeviceMesh 上分散,詳細說明了張量沿著對應於 DeviceMesh 中軸名稱的指定軸進行分片。

您可以在 TensorFlow DTensor 指南中找到更詳細的概念說明。

# Retrieve the local available gpu devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.

# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout which could be used for data parallel of a image input.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)

分散

Keras 中的 Distribution 類別用作設計自訂分散策略的基本抽象類別。它封裝了將模型變數、輸入資料和中間計算分散到裝置網格所需的核心邏輯。作為最終使用者,您不需要直接與此類別互動,而是與其子類別(如 DataParallelModelParallel)互動。


DataParallel

Keras 分散式 API 中的 DataParallel 類別是為分散式訓練中的資料平行化策略而設計的,其中模型權重會複製到 DeviceMesh 中的所有裝置,且每個裝置都會處理一部分的輸入資料。

以下是此類別的範例用法。

# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregration of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349

0.842325747013092

ModelParallelLayoutMap

當模型權重太大而無法容納在單一加速器上時,ModelParallel 將會非常有用。此設定可讓您在 DeviceMesh 上所有裝置中分割模型權重或啟用張量,並為大型模型啟用水平擴展。

與所有權重完全複製的 DataParallel 模型不同,ModelParallel 下的權重佈局通常需要一些自訂才能獲得最佳效能。我們引入 LayoutMap 來讓您從全域角度為任何權重和中間張量指定 TensorLayout

LayoutMap 是一個類似字典的物件,可將字串對應至 TensorLayout 實例。它與一般的 Python 字典的行為不同,字串索引鍵在擷取值時會被視為 regex。此類別可讓您定義 TensorLayout 的命名架構,然後擷取對應的 TensorLayout 實例。通常,用於查詢的索引鍵是 variable.path 屬性,它是變數的識別碼。作為捷徑,也允許在插入值時使用軸名稱的 tuple 或清單,並且它會轉換為 TensorLayout

LayoutMap 也可以選擇性地包含 DeviceMesh,以在未設定時填入 TensorLayout.device_mesh。當擷取具有索引鍵的佈局時,如果沒有完全相符的項目,則佈局對應中的所有現有索引鍵都會被視為 regex,並再次與輸入索引鍵進行比對。如果有多次比對,則會引發 ValueError。如果找不到任何相符的項目,則會傳回 None

mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that for any weights that match with d1/kernel, it
# will be sharded with model dimensions (4 devices), same for the d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# The data will be sharded across the "data" dimension of the method, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3

/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.dev.org.tw/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"

 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
 8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381  

0.8502610325813293

也可以輕鬆變更網格結構,以調整更多資料平行或模型平行之間的計算。您可以透過調整網格的形狀來執行此動作。而且,任何其他程式碼都不需要變更。

full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)

延伸閱讀

  1. JAX 分散式陣列和自動平行化
  2. JAX 分片模組
  3. 使用 DTensors 的 TensorFlow 分散式訓練
  4. TensorFlow DTensor 概念
  5. 將 DTensors 與 tf.keras 搭配使用