
ModelParallel API

[Source]

ModelParallel class

keras.distribution.ModelParallel(layout_map=None, batch_dim_name=None, **kwargs)

Distribution that shards model variables.

Compared to DataParallel, which replicates the variables across all devices, ModelParallel allows you to shard the variables in addition to the input data.
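
For contrast, here is a minimal sketch (not part of the original docs) of the purely data-parallel setup, assuming the JAX backend where the keras.distribution API is available:

import keras

# DataParallel replicates every variable on all available devices and only
# splits the input batch across them; ModelParallel additionally shards the
# variables themselves according to a layout map.
devices = keras.distribution.list_devices()
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)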

To construct a ModelParallel distribution, you need to provide a DeviceMesh and a LayoutMap.

  1. DeviceMesh contains the information about the physical devices. The axis names in the mesh will be used to map the variable and data layouts.
  2. LayoutMap contains the mapping between variable paths and their corresponding TensorLayout.

Example

devices = list_devices()    # Assume there are 8 devices.

# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                         devices=devices)
# Create a layout map that shards the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

distribution = ModelParallel(
    layout_map=layout_map,
    batch_dim_name='batch',
)

# Set the global distribution, or via `with distribution.scope():`
set_distribution(distribution)

model = model_creation()
model.compile()
model.fit(data)

You can quickly update the device mesh shape to change the sharding factor of the variables. For example:

# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each device.
device_mesh = DeviceMesh(
    shape=(1, 8),
    axis_names=('batch', 'model'),
    devices=devices,
)
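
As a follow-up sketch, the reshaped mesh can be wired back into a new LayoutMap and ModelParallel instance; the layout keys below are reused from the example above and remain illustrative:

# With the (1, 8) mesh, the 'model' axis now spans all 8 devices, so every
# variable matching a key below is split 8 ways instead of 4.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)

distribution = ModelParallel(layout_map=layout_map, batch_dim_name='batch')
set_distribution(distribution)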

To figure out the proper layout mapping rules for all the model variables, you can first list out all the model variable paths, which will be used as the keys to map the variables to TensorLayout.

For example:

model = create_model()
for v in model.variables:
    print(v.path)
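
The printed paths can then be used, either verbatim or via a regex, as keys in the layout map. The paths below are hypothetical; the exact prefixes depend on how the model is built:

# Hypothetical output of the loop above for a small model:
#   my_model/dense/kernel
#   my_model/dense/bias
#   my_model/dense_1/kernel
#   my_model/dense_1/bias
# A regex key such as 'dense.*kernel' matches all dense kernels at once,
# or a single variable can be targeted by its full path:
layout_map['my_model/dense/kernel'] = (None, 'model')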

Arguments

  • layout_map: LayoutMap instance which maps the variable paths to the corresponding tensor layouts.
  • batch_dim_name: optional string, the axis name in the device mesh (of the layout_map object) that will be used to distribute the data. If not specified, the first axis from the device mesh will be used.
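
A minimal sketch of the default behaviour described above, assuming the 2 x 4 mesh from the earlier example where 'batch' is the first axis name:

# Since 'batch' is the first axis of the device mesh, ModelParallel falls back
# to it for the data dimension when batch_dim_name is not passed explicitly.
distribution = ModelParallel(layout_map=layout_map)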