ModelParallel
類別keras.distribution.ModelParallel(layout_map=None, batch_dim_name=None, **kwargs)
分散模型變數的分散式處理。
與在所有裝置上複製變數的 DataParallel
相比,ModelParallel
允許您除了輸入資料外,還能分割變數。
要建構 ModelParallel
分散式處理,您需要提供 DeviceMesh
和 LayoutMap
。
DeviceMesh
包含實體裝置資訊。網格中的軸名稱將用於映射變數和資料佈局。LayoutMap
包含變數路徑與其對應的 TensorLayout
之間的映射。範例
devices = list_devices() # Assume there are 8 devices.
# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
devices=devices)
# Create a layout map that shard the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
distribution = ModelParallel(
layout_map=layout_map,
batch_dim_name='batch',
)
# Set the global distribution, or via `with distribution.scope():`
set_distribution(distribution)
model = model_creation()
model.compile()
model.fit(data)
您可以快速更新裝置網格形狀,以更改變數的分散因子。例如:
# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each of the device.
device_mesh = DeviceMesh(
shape=(1, 8),
axis_names=('batch', 'model'),
devices=devices,
)
為了找出所有模型變數的正確佈局映射規則,您可以先列出所有模型變數路徑,這些路徑將用作將變數映射到 TensorLayout
的索引鍵。
例如:
model = create_model()
for v in model.variables:
print(v.path)
引數
LayoutMap
實例,將變數路徑映射到對應的張量佈局。layout_map
物件的)將用於分散資料的軸名稱。如果未指定,將使用裝置網格的第一個軸。