set_random_seed
函式keras.utils.set_random_seed(seed)
設定所有隨機種子 (Python、NumPy 和後端框架,例如 TF)。
您可以使用此工具程式使幾乎所有 Keras 程式都完全具確定性。在涉及網路通訊 (例如參數伺服器分佈) 的情況下,以及在涉及某些非確定性 cuDNN 運算時,會有一些限制,這些情況會產生額外的隨機性來源。
呼叫此工具程式相當於以下操作
import random
random.seed(seed)
import numpy as np
np.random.seed(seed)
import tensorflow as tf # Only if TF is installed
tf.random.set_seed(seed)
import torch # Only if the backend is 'torch'
torch.manual_seed(seed)
請注意,即使您未使用 TensorFlow 作為後端框架,TensorFlow 種子也會被設定,因為許多工作流程利用 tf.data
管線 (其具有隨機洗牌功能)。同樣地,許多工作流程可能會利用 NumPy API。
引數
split_dataset
函式keras.utils.split_dataset(
dataset, left_size=None, right_size=None, shuffle=False, seed=None
)
將資料集分割為左半部和右半部 (例如訓練/測試)。
引數
tf.data.Dataset
、torch.utils.data.Dataset
物件,或具有相同長度的陣列/元組列表。[0, 1]
內),則表示要封裝在左側資料集中的資料比例。如果為整數,則表示要封裝在左側資料集中的樣本數。如果為 None
,則預設為 right_size
的補數。預設為 None
。[0, 1]
內),則表示要封裝在右側資料集中的資料比例。如果為整數,則表示要封裝在右側資料集中的樣本數。如果為 None
,則預設為 left_size
的補數。預設為 None
。回傳值
tf.data.Dataset
物件的元組:左側和右側分割。範例
>>> data = np.random.random(size=(1000, 4))
>>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
>>> int(left_ds.cardinality())
800
>>> int(right_ds.cardinality())
200
pack_x_y_sample_weight
函式keras.utils.pack_x_y_sample_weight(x, y=None, sample_weight=None)
將使用者提供的資料封裝到元組中。
這是一個方便的工具程式,用於將資料封裝為 Model.fit()
使用的元組格式。
範例
>>> x = ops.ones((10, 1))
>>> data = pack_x_y_sample_weight(x)
>>> isinstance(data, ops.Tensor)
True
>>> y = ops.ones((10, 1))
>>> data = pack_x_y_sample_weight(x, y)
>>> isinstance(data, tuple)
True
>>> x, y = data
引數
Model
的特徵。Model
的真實目標。回傳值
Model.fit()
中使用的格式的元組。
get_file
函式keras.utils.get_file(
fname=None,
origin=None,
untar=False,
md5_hash=None,
file_hash=None,
cache_subdir="datasets",
hash_algorithm="auto",
extract=False,
archive_format="auto",
cache_dir=None,
force_download=False,
)
如果快取中還沒有檔案,則從 URL 下載檔案。
預設情況下,URL origin
上的檔案會下載到快取目錄 ~/.keras
,放置在快取子目錄 datasets
中,並給定檔名 fname
。檔案 example.txt
的最終位置因此將為 ~/.keras/datasets/example.txt
。.tar
、.tar.gz
、.tar.bz
和 .zip
格式的檔案也可以解壓縮。
傳遞雜湊值將在下載後驗證檔案。命令列程式 shasum
和 sha256sum
可以計算雜湊值。
範例
path_to_downloaded_file = get_file(
origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
extract=True,
)
引數
None
,將使用 origin
上的檔案名稱。如果下載和解壓縮目錄封存檔,則提供的 fname
將用作解壓縮目錄名稱 (僅當它沒有副檔名時)。extract
引數。布林值,檔案是否為應解壓縮的 tar 封存檔。file_hash
引數。用於檔案完整性驗證的檔案 md5 雜湊值。"/path/to/folder"
,檔案將儲存在該位置。"md5'
、"sha256'
和 "auto'
。預設值 'auto' 會偵測使用中的雜湊演算法。True
,則解壓縮封存檔。僅適用於壓縮封存檔案,如 tar 或 zip。"auto'
、"tar'
、"zip'
和 None
。"tar"
包括 tar、tar.gz 和 tar.bz 檔案。預設值 "auto"
對應於 ["tar", "zip"]
。None
或空列表將回傳找不到相符項目。KERAS_HOME
環境變數,則預設為 $KERAS_HOME
,否則預設為 ~/.keras/
。True
,則無論快取狀態如何,都將始終重新下載檔案。回傳值
下載檔案的路徑。
⚠️ 惡意下載警告 ⚠️
從網際網路下載任何內容都帶有風險。如果您不信任來源,請永遠不要下載檔案/封存檔。我們建議您指定 file_hash
引數 (如果已知來源檔案的雜湊值),以確保您取得的檔案是您期望的檔案。
Progbar
類別keras.utils.Progbar(
target, width=20, verbose=1, interval=0.05, stateful_metrics=None, unit_name="step"
)
顯示進度列。
引數
PyDataset
類別keras.utils.PyDataset(workers=1, use_multiprocessing=False, max_queue_size=10)
使用 Python 程式碼定義平行資料集的基本類別。
每個 PyDataset
都必須實作 __getitem__()
和 __len__()
方法。如果您想在 epoch 之間修改資料集,您可以額外實作 on_epoch_end()
,或在每個 epoch 開始時呼叫的 on_epoch_begin
。__getitem__()
方法應回傳完整的批次 (而非單個樣本),而 __len__
方法應回傳資料集中的批次數 (而非樣本數)。
引數
True
表示您的資料集將在多個 fork 程序中複製。為了從平行處理中獲得運算層級 (而非 I/O 層級) 的好處,這是必要的。但是,只有當您的資料集可以安全地 pickle 化時,才能將其設定為 True
。注意事項
PyDataset
是一種更安全的多處理方式。此結構保證模型在每個 epoch 中只會針對每個樣本訓練一次,而 Python 產生器則不然。workers
、use_multiprocessing
和 max_queue_size
存在是為了設定 fit()
如何使用平行處理來迭代資料集。它們並非由 PyDataset
類別直接使用。當您手動迭代 PyDataset
時,不會套用平行處理。範例
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
class CIFAR10PyDataset(keras.utils.PyDataset):
def __init__(self, x_set, y_set, batch_size, **kwargs):
super().__init__(**kwargs)
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
# Return number of batches.
return math.ceil(len(self.x) / self.batch_size)
def __getitem__(self, idx):
# Return x, y for batch idx.
low = idx * self.batch_size
# Cap upper bound at array length; the last batch may be smaller
# if the total number of items is not a multiple of batch size.
high = min(low + self.batch_size, len(self.x))
batch_x = self.x[low:high]
batch_y = self.y[low:high]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
to_categorical
函式keras.utils.to_categorical(x, num_classes=None)
將類別向量 (整數) 轉換為二元類別矩陣。
例如,用於 categorical_crossentropy
。
引數
num_classes - 1
的整數) 的類陣列。None
,則會推斷為 max(x) + 1
。預設為 None
。回傳值
作為 NumPy 陣列的輸入的二元矩陣表示。類別軸位於最後。
範例
>>> a = keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
>>> print(a)
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]]
>>> b = np.array([.9, .04, .03, .03,
... .3, .45, .15, .13,
... .04, .01, .94, .05,
... .12, .21, .5, .17],
... shape=[4, 4])
>>> loss = keras.ops.categorical_crossentropy(a, b)
>>> print(np.around(loss, 5))
[0.10536 0.82807 0.1011 1.77196]
>>> loss = keras.ops.categorical_crossentropy(a, a)
>>> print(np.around(loss, 5))
[0. 0. 0. 0.]
normalize
函式keras.utils.normalize(x, axis=-1, order=2)
正規化陣列。
如果輸入是 NumPy 陣列,則將回傳 NumPy 陣列。如果它是後端張量,則將回傳後端張量。
引數
order=2
代表 L2 範數)。回傳值
陣列的正規化副本。