Author: fchollet
Date created: 2023/06/25
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in PyTorch.
import os
# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"
import torch
import keras
import numpy as np
Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.

If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit(). A minimal sketch of that pattern is shown below.
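For context, here's roughly what such a train_step() override can look like with the torch backend (an illustrative sketch only, simplified from the dedicated guide on customizing fit(); not part of this guide's code):

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        # Clear leftover gradients, run the forward pass, and compute the loss.
        self.zero_grad()
        y_pred = self(x, training=True)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Backward pass, then apply the gradients via the compiled optimizer.
        loss.backward()
        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)
        return {"loss": loss}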
Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.
Writing a custom training loop requires the following ingredients:

- A model to train, of course.
- An optimizer. You could either use a keras.optimizers optimizer, or a native PyTorch optimizer from torch.optim.
- A loss function. You could either use a keras.losses loss, or a native PyTorch loss from torch.nn.
- A dataset. You could use any format: a tf.data.Dataset, a PyTorch DataLoader, a Python generator, etc.

Let's line them up. We'll use torch-native objects in each case, except, of course, for the Keras model.
First, let's get the model and the MNIST dataset:
# Let's consider a simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
# Load the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val), torch.from_numpy(y_val)
)

# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)
Next, here's our PyTorch optimizer and our PyTorch loss function:
# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()
Let's train our model using mini-batch gradient descent with a custom training loop.

Calling loss.backward() on a loss tensor triggers backpropagation. Once that's done, your optimizer is magically aware of the gradients for each variable and can update them, which it does via optimizer.step(). Tensors, variables, and optimizers are all interconnected via hidden global state. Also, don't forget to call model.zero_grad() before loss.backward(), or you won't get the right gradients for your variables.
Here's our training loop, step by step:

- We open a for loop that iterates over epochs
- For each epoch, we open a for loop that iterates over the dataset, in batches
- For each batch, we run the forward pass and compute the loss, then call loss.backward() to compute the gradients and optimizer.step() to update the weights

epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
Training loss (for 1 batch) at step 0: 110.9115
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.9493
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 2.7383
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.6616
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.5927
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.0992
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.5425
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.3308
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.8231
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.5570
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.6321
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.4962
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 1.0833
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 1.3607
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 1.1250
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 1.2562
Seen so far: 48032 samples
Training loss (for 1 batch) at step 0: 0.5181
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.3939
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.3406
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1122
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.2015
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.1184
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.0702
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.4062
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.4570
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 1.2490
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.0714
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3677
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8291
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.8320
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1179
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.5390
Seen so far: 48032 samples
Training loss (for 1 batch) at step 0: 0.1309
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.4061
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.2734
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.2972
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4282
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.3504
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.3556
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.7834
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.2522
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.2056
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3259
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.5215
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8051
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.4423
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.0473
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1419
Seen so far: 48032 samples
As an alternative, let's look at what the loop looks like when using a Keras optimizer and a Keras loss function instead.

Important differences:

- You retrieve the gradients of the variables via v.value.grad, called on each trainable variable.
- You update your variables via optimizer.apply(), which must be called in a torch.no_grad() scope.

Also, a big gotcha: while all NumPy/TensorFlow/JAX/Keras APIs, as well as the Python unittest API, use the argument order convention fn(y_true, y_pred) (reference values first, predicted values second), PyTorch actually uses fn(y_pred, y_true) for its losses. So make sure to invert the order of logits and targets.
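As a quick illustration of that argument-order gotcha (a minimal sketch using any batch of logits and targets, not part of the loop below):

# PyTorch losses take the predictions first:
torch_loss = torch.nn.CrossEntropyLoss()(logits, targets)
# Keras losses take the reference values (y_true) first:
keras_loss = keras.losses.CategoricalCrossentropy(from_logits=True)(targets, logits)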
model = get_model()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
Start of epoch 0
Training loss (for 1 batch) at step 0: 98.9569
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 5.3304
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.3246
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.6745
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.0936
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.4159
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.2796
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 2.3532
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.7533
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 1.0432
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3959
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.4722
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3851
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.8599
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1237
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.4919
Seen so far: 48032 samples
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.8972
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.5844
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.1285
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.0671
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4296
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.1483
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.0230
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1368
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.1531
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.0472
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.2343
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.4449
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3942
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.3236
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.0717
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.9288
Seen so far: 48032 samples
Start of epoch 2
Training loss (for 1 batch) at step 0: 0.9393
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2383
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.1116
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6736
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.6713
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.3394
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.2385
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.4248
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.0200
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1259
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7566
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.0594
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.2821
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.2088
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.5654
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.0512
Seen so far: 48032 samples
Let's add metrics monitoring to this basic training loop.

You can readily reuse built-in Keras metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

- Instantiate the metric at the start of the loop
- Call metric.update_state() after each batch
- Call metric.result() when you need to display the current value of the metric
- Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch)

Let's use this knowledge to compute CategoricalAccuracy on training and validation data at the end of each epoch.
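Before the full loop, here's that flow on its own, with a dummy batch (a minimal sketch, not part of the guide's original code):

acc = keras.metrics.CategoricalAccuracy()      # instantiate the metric once
acc.update_state([[0.0, 1.0]], [[0.3, 0.7]])   # update after each batch: (y_true, y_pred)
print(float(acc.result()))                     # query the current value -> 1.0
acc.reset_state()                              # clear the state, e.g. at the end of an epoch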
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
Here's our training and evaluation loop:
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
Start of epoch 0
Training loss (for 1 batch) at step 0: 59.2206
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 8.9801
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 5.2990
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 3.6978
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.9965
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 2.1896
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.2416
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.9403
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.1838
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.5884
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7836
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.7015
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3335
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.2763
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4787
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2562
Seen so far: 48032 samples
Training acc over epoch: 0.8411
Validation acc: 0.8963
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.3417
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 1.1465
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.7274
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1273
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.6500
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2008
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.7483
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.5821
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.5696
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.3112
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1761
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.1811
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.2736
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.3848
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4627
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.3934
Seen so far: 48032 samples
Training acc over epoch: 0.9053
Validation acc: 0.9221
Start of epoch 2
Training loss (for 1 batch) at step 0: 0.5743
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.4448
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.9880
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.2268
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.5607
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.1178
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.4305
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1712
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3109
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1548
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1090
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.5169
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.3791
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6963
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6204
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1111
Seen so far: 48032 samples
Training acc over epoch: 0.9216
Validation acc: 0.9356
Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the model.losses property at the end of the forward pass.

If you want to use these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, which creates an activity regularization loss:
class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * torch.sum(inputs))
        return inputs
Let's build a really simple model that uses it:
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
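To see the tracking in action, here's a minimal check (an illustrative sketch, not part of the original guide): after a forward pass, model.losses holds the scalar terms created by add_loss(), which you can sum before adding them to your main loss.

# Run one forward pass on a small batch, then inspect the tracked losses.
_ = model(torch.from_numpy(x_train[:batch_size]))
print(model.losses)            # e.g. [tensor(...)] -- one entry per add_loss() call
reg_loss = sum(model.losses)   # sum of the components, to be added to the main loss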
Here's what our training loop should look like now:
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.CategoricalAccuracy()
val_acc_metric = keras.metrics.CategoricalAccuracy()
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(targets, logits)
        if model.losses:
            loss = loss + torch.sum(*model.losses)

        # Backward pass
        model.zero_grad()
        trainable_weights = [v for v in model.trainable_weights]

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            optimizer.apply(gradients, trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(targets, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataloader:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
Start of epoch 0
Training loss (for 1 batch) at step 0: 138.7979
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 4.4268
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 1.0779
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.7229
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.5801
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.4298
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.4717
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 1.3369
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 1.3239
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.5972
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1983
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.5228
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 1.0025
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.3424
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.5196
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.4287
Seen so far: 48032 samples
Training acc over epoch: 0.8089
Validation acc: 0.8947
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.2903
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.4118
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.6533
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.0402
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.3638
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.3313
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.5119
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1628
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.4793
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.2726
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.5721
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.5783
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.2533
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.2218
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.1232
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.6805
Seen so far: 48032 samples
Training acc over epoch: 0.8970
Validation acc: 0.9097
Start of epoch 2
Training loss (for 1 batch) at step 0: 0.4553
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.3975
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 1.2382
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.0927
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.3530
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.3842
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.6423
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1751
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.4769
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1854
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3130
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.1633
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.1446
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.4661
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.9977
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.3392
Seen so far: 48032 samples
Training acc over epoch: 0.9182
Validation acc: 0.9200
That's it!