LayoutMap
類別keras.distribution.LayoutMap(device_mesh)
一個類似字典的物件,將字串對應到 TensorLayout
實例。
LayoutMap
使用字串作為鍵,TensorLayout
作為值。一般的 Python 字典和此類別之間存在行為差異。當檢索值時,字串鍵將被視為正則表達式。詳情請參閱 get
的 docstring。
請參閱下方的使用範例。您可以定義 TensorLayout
的命名模式,然後檢索對應的 TensorLayout
實例。
在一般情況下,要查詢的鍵通常是 variable.path
,它是變數的識別符。
作為快捷方式,當作為值插入時,也允許使用軸名稱的元組或列表,並將轉換為 TensorLayout
。
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel'] # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias'] # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias'] # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel'] # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias'] # layout_8 == None
引數
keras.distribution.DeviceMesh
實例。DeviceMesh
類別keras.distribution.DeviceMesh(shape, axis_names, devices=None)
用於分散式計算的運算裝置叢集。
此 API 與 jax.sharding.Mesh
和 tf.dtensor.Mesh
對齊,它們代表全域上下文中的運算裝置。
更多詳細資訊請參閱 jax.sharding.Mesh 和 tf.dtensor.Mesh。
引數
DeviceMesh
的形狀,例如,資料並行分佈為 (8,)
,或模型+資料並行分佈為 (4, 2)
。DeviceMesh
中每個軸的邏輯名稱。axis_names
的長度應與 shape
的秩相符。當分佈資料和變數時,axis_names
將用於匹配/建立 TensorLayout
。keras.distribution.list_devices()
本機可用的所有裝置。TensorLayout
類別keras.distribution.TensorLayout(axes, device_mesh=None)
要應用於張量的佈局。
此 API 與 jax.sharding.NamedSharding
和 tf.dtensor.Layout
對齊。
更多詳細資訊請參閱 jax.sharding.NamedSharding 和 tf.dtensor.Layout。
引數
DeviceMesh
中 axis_names
的字串元組。對於任何不需要分片的維度,可以使用 None
作為佔位符。DeviceMesh
,將用於建立佈局。在指定 mesh 之前,張量到物理裝置的實際映射是未知的。distribute_tensor
函數keras.distribution.distribute_tensor(tensor, layout)
在 jit 函數執行中變更張量值的佈局。
引數
TensorLayout
。返回
具有指定張量佈局的新值。