LayoutMap
類別keras.distribution.LayoutMap(device_mesh)
一個類似字典的物件,將字串對應到 TensorLayout
實例。
LayoutMap
使用字串作為鍵,並使用 TensorLayout
作為值。普通的 Python 字典和此類別之間存在行為差異。當檢索值時,字串鍵將被視為正規表示式。請參閱 get
的 docstring 以獲取更多詳細資訊。
請參閱以下的使用範例。您可以定義 TensorLayout
的命名架構,然後檢索對應的 TensorLayout
實例。
在正常情況下,要查詢的鍵通常是 variable.path
,它是變數的識別符。
作為快捷方式,當插入為值時,也允許使用軸名稱的 tuple 或 list,它們將被轉換為 TensorLayout
。
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel'] # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias'] # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias'] # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel'] # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias'] # layout_8 == None
引數
keras.distribution.DeviceMesh
實例。DeviceMesh
類別keras.distribution.DeviceMesh(shape, axis_names, devices=None)
用於分散式計算的計算裝置叢集。
此 API 與 jax.sharding.Mesh
和 tf.dtensor.Mesh
對齊,它們表示全域上下文中的計算裝置。
請參閱 jax.sharding.Mesh 和 tf.dtensor.Mesh 以獲取更多詳細資訊。
引數
DeviceMesh
的形狀,例如,僅用於資料並行分散式的 (8,)
,或者用於模型+資料並行分散式的 (4, 2)
。DeviceMesh
的每個軸的邏輯名稱。axis_names
的長度應與 shape
的秩匹配。當分散資料和變數時,axis_names
將用於匹配/建立 TensorLayout
。keras.distribution.list_devices()
的所有可用的本地裝置。TensorLayout
類別keras.distribution.TensorLayout(axes, device_mesh=None)
要套用至張量的版面配置。
此 API 與 jax.sharding.NamedSharding
和 tf.dtensor.Layout
對齊。
請參閱 jax.sharding.NamedSharding 和 tf.dtensor.Layout 以獲取更多詳細資訊。
引數
DeviceMesh
中 axis_names
的字串 tuple。對於任何不需要任何分片的維度,可以使用 None
作為佔位符。DeviceMesh
,將用於建立版面配置。在指定網格之前,張量到實際裝置的實際對應關係是未知的。distribute_tensor
函數keras.distribution.distribute_tensor(tensor, layout)
在 jit 函數執行中變更張量值的版面配置。
引數
TensorLayout
。返回
具有指定張量版面配置的新值。