Keras 3 API 文件 / 回呼 API / TensorBoard

TensorBoard

[原始碼]

TensorBoard 類別

keras.callbacks.TensorBoard(
    log_dir="logs",
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    write_steps_per_second=False,
    update_freq="epoch",
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None,
)

啟用 TensorBoard 的視覺化功能。

TensorBoard 是 TensorFlow 提供的視覺化工具。使用此回呼需要安裝 TensorFlow。

此回呼會記錄 TensorBoard 的事件,包括

  • 指標摘要圖
  • 訓練圖形視覺化
  • 權重直方圖
  • 取樣分析

除了週期摘要之外,當在 model.evaluate() 或常規驗證中使用時,將會有一個摘要記錄評估指標與 model.optimizer.iterations 的關係。指標名稱將會加上 evaluation 前綴,而 model.optimizer.iterations 是視覺化的 TensorBoard 中的步驟。

如果您已透過 pip 安裝 TensorFlow,您應該可以從命令列啟動 TensorBoard

tensorboard --logdir=path_to_your_logs

您可以在這裡找到更多關於 TensorBoard 的資訊。

參數

  • log_dir: 要儲存日誌檔案的路徑,這些檔案將由 TensorBoard 解析。例如,log_dir = os.path.join(working_dir, 'logs')。此目錄不應被任何其他回呼重複使用。
  • histogram_freq: 計算模型層權重直方圖的頻率(以週期為單位)。如果設定為 0,則不會計算直方圖。必須指定驗證資料(或分割)以進行直方圖視覺化。
  • write_graph: (目前不支援) 是否在 TensorBoard 中視覺化圖形。請注意,當 write_graph 設定為 True 時,日誌檔案可能會變得非常大。
  • write_images: 是否寫入模型權重以在 TensorBoard 中視覺化為圖像。
  • write_steps_per_second: 是否將每秒訓練步驟記錄到 TensorBoard 中。這支援週期和批次頻率記錄。
  • update_freq: "batch""epoch" 或整數。當使用 "epoch" 時,會在每個週期後將損失和指標寫入 TensorBoard。如果使用整數,例如 1000,則所有指標和損失(包括由 Model.compile 新增的自訂指標和損失)將每 1000 個批次記錄到 TensorBoard。"batch" 是 1 的同義詞,表示它們將在每個批次寫入。但請注意,過於頻繁地寫入 TensorBoard 可能會減慢您的訓練速度,尤其是在與分散式策略一起使用時,因為這會產生額外的同步開銷。批次層級摘要寫入也可透過 train_step 覆寫使用。請參閱 TensorBoard Scalars 教學以取得更多詳細資訊。
  • profile_batch: 分析批次以取樣計算特性。profile_batch 必須是非負整數或整數元組。一對正整數表示要分析的批次範圍。預設情況下,分析已停用。
  • embeddings_freq: 視覺化嵌入層的頻率(以週期為單位)。如果設定為 0,則不會視覺化嵌入。
  • embeddings_metadata: 字典,將嵌入層名稱映射到檔案的檔案名稱,以在其中儲存嵌入層的元資料。如果所有嵌入層都使用相同的元資料檔案,則可以傳遞單個檔案名稱。

範例

tensorboard_callback = keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# Then run the tensorboard command to view the visualizations.

子類別模型中的自訂批次層級摘要

class MyModel(keras.Model):

    def build(self, _):
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        outputs = self.dense(x)
        tf.summary.histogram('outputs', outputs)
        return outputs

model = MyModel()
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches.  In addition to any [`tf.summary`](https://tensorflow.dev.org.tw/api_docs/python/tf/summary) contained in `model.call()`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

Functional API 模型中的自訂批次層級摘要

def my_summary(x):
    tf.summary.histogram('x', x)
    return x

inputs = keras.Input(10)
x = keras.layers.Dense(10)(inputs)
outputs = keras.layers.Lambda(my_summary)(x)
model = keras.Model(inputs, outputs)
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N
# batches. In addition to any [`tf.summary`](https://tensorflow.dev.org.tw/api_docs/python/tf/summary) contained in `Model.call`,
# metrics added in `Model.compile` will be logged every N batches.
tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

分析

# Profile a single batch, e.g. the 5th batch.
tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

# Profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=(10,20))
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])