Keras 3 API 文件 / 多裝置分散式 / LayoutMap API

LayoutMap API

[原始碼]

LayoutMap 類別

keras.distribution.LayoutMap(device_mesh)

一個類似字典的物件,將字串對應到 TensorLayout 實例。

LayoutMap 使用字串作為鍵,TensorLayout 作為值。一般的 Python 字典和此類別之間存在行為差異。當檢索值時,字串鍵將被視為正則表達式。詳情請參閱 get 的 docstring。

請參閱下方的使用範例。您可以定義 TensorLayout 的命名模式,然後檢索對應的 TensorLayout 實例。

在一般情況下,要查詢的鍵通常是 variable.path,它是變數的識別符。

作為快捷方式,當作為值插入時,也允許使用軸名稱的元組或列表,並將轉換為 TensorLayout

layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

layout_1 = layout_map['dense_1.kernel']             # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias']               # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel']             # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias']               # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias']   # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel']   # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias']     # layout_8 == None

引數


[原始碼]

DeviceMesh 類別

keras.distribution.DeviceMesh(shape, axis_names, devices=None)

用於分散式計算的運算裝置叢集。

此 API 與 jax.sharding.Meshtf.dtensor.Mesh 對齊,它們代表全域上下文中的運算裝置。

更多詳細資訊請參閱 jax.sharding.Meshtf.dtensor.Mesh

引數

  • shape: 整數列表的元組。整體 DeviceMesh 的形狀,例如,資料並行分佈為 (8,),或模型+資料並行分佈為 (4, 2)
  • axis_names: 字串列表。DeviceMesh 中每個軸的邏輯名稱。axis_names 的長度應與 shape 的秩相符。當分佈資料和變數時,axis_names 將用於匹配/建立 TensorLayout
  • devices: 裝置的可選列表。預設為從 keras.distribution.list_devices() 本機可用的所有裝置。

[原始碼]

TensorLayout 類別

keras.distribution.TensorLayout(axes, device_mesh=None)

要應用於張量的佈局。

此 API 與 jax.sharding.NamedShardingtf.dtensor.Layout 對齊。

更多詳細資訊請參閱 jax.sharding.NamedShardingtf.dtensor.Layout

引數

  • axes: 應對應到 DeviceMeshaxis_names 的字串元組。對於任何不需要分片的維度,可以使用 None 作為佔位符。
  • device_mesh: 可選的 DeviceMesh,將用於建立佈局。在指定 mesh 之前,張量到物理裝置的實際映射是未知的。

[原始碼]

distribute_tensor 函數

keras.distribution.distribute_tensor(tensor, layout)

在 jit 函數執行中變更張量值的佈局。

引數

  • tensor: 要變更佈局的張量。
  • layout: 要應用於值的 TensorLayout

返回

具有指定張量佈局的新值。