Keras 3 API 文件 / 多裝置分散式 / LayoutMap API

LayoutMap API

[原始碼]

LayoutMap 類別

keras.distribution.LayoutMap(device_mesh)

一個類似字典的物件,將字串對應到 TensorLayout 實例。

LayoutMap 使用字串作為鍵,並使用 TensorLayout 作為值。普通的 Python 字典和此類別之間存在行為差異。當檢索值時,字串鍵將被視為正規表示式。請參閱 get 的 docstring 以獲取更多詳細資訊。

請參閱以下的使用範例。您可以定義 TensorLayout 的命名架構,然後檢索對應的 TensorLayout 實例。

在正常情況下,要查詢的鍵通常是 variable.path,它是變數的識別符。

作為快捷方式,當插入為值時,也允許使用軸名稱的 tuple 或 list,它們將被轉換為 TensorLayout

layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

layout_1 = layout_map['dense_1.kernel']             # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias']               # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel']             # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias']               # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias']   # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel']   # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias']     # layout_8 == None

引數


[原始碼]

DeviceMesh 類別

keras.distribution.DeviceMesh(shape, axis_names, devices=None)

用於分散式計算的計算裝置叢集。

此 API 與 jax.sharding.Meshtf.dtensor.Mesh 對齊,它們表示全域上下文中的計算裝置。

請參閱 jax.sharding.Meshtf.dtensor.Mesh 以獲取更多詳細資訊。

引數

  • shape: 整數 list 的 tuple。整體 DeviceMesh 的形狀,例如,僅用於資料並行分散式的 (8,),或者用於模型+資料並行分散式的 (4, 2)
  • axis_names: 字串 list。DeviceMesh 的每個軸的邏輯名稱。axis_names 的長度應與 shape 的秩匹配。當分散資料和變數時,axis_names 將用於匹配/建立 TensorLayout
  • devices: 裝置的可選 list。預設為來自 keras.distribution.list_devices() 的所有可用的本地裝置。

[原始碼]

TensorLayout 類別

keras.distribution.TensorLayout(axes, device_mesh=None)

要套用至張量的版面配置。

此 API 與 jax.sharding.NamedShardingtf.dtensor.Layout 對齊。

請參閱 jax.sharding.NamedShardingtf.dtensor.Layout 以獲取更多詳細資訊。

引數

  • axes: 應對應到 DeviceMeshaxis_names 的字串 tuple。對於任何不需要任何分片的維度,可以使用 None 作為佔位符。
  • device_mesh: 可選的 DeviceMesh,將用於建立版面配置。在指定網格之前,張量到實際裝置的實際對應關係是未知的。

[原始碼]

distribute_tensor 函數

keras.distribution.distribute_tensor(tensor, layout)

在 jit 函數執行中變更張量值的版面配置。

引數

  • tensor: 要變更版面配置的張量。
  • layout: 要套用至該值的 TensorLayout

返回

具有指定張量版面配置的新值。