Keras 3 API 文件 / 內建小型資料集 / 路透社新聞專線分類資料集

路透社新聞專線分類資料集

[原始碼]

load_data 函數

keras.datasets.reuters.load_data(
    path="reuters.npz",
    num_words=None,
    skip_top=0,
    maxlen=None,
    test_split=0.2,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

載入路透社新聞專線分類資料集。

這是一個包含來自路透社的 11,228 篇新聞專線的資料集,標記了 46 個主題。

這最初是通過解析和預處理經典的 Reuters-21578 資料集而生成的,但預處理程式碼不再與 Keras 打包。 請參閱此 GitHub 討論 以取得更多資訊。

每篇新聞專線都編碼為單字索引(整數)的列表。為了方便起見,單字會根據其在資料集中出現的總體頻率進行索引,例如,整數「3」會編碼資料中第 3 個最常出現的單字。這允許快速篩選操作,例如:「僅考慮前 10,000 個最常見的單字,但排除前 20 個最常見的單字」。

按照慣例,「0」不代表特定的單字,而是用於編碼任何未知的單字。

引數

  • path:快取資料的位置(相對於 ~/.keras/dataset)。
  • num_words:整數或 None。單字會按照它們出現的頻率(在訓練集中)排序,並且僅保留 num_words 個最常出現的單字。任何較不頻繁的單字都會在序列資料中顯示為 oov_char 值。如果為 None,則保留所有單字。預設值為 None
  • skip_top:跳過前 N 個最常出現的單字(它們可能沒有提供太多資訊)。這些單字將會在資料集中顯示為 oov_char 值。0 表示不跳過任何單字。預設值為 0
  • maxlen:整數或 None。最大序列長度。任何更長的序列都將被截斷。None 表示不截斷。預設值為 None
  • test_split:介於 0.1. 之間的浮點數。用作測試資料的資料集部分。0.2 表示 20% 的資料集用作測試資料。預設值為 0.2
  • seed:整數。用於可重複資料洗牌的種子。
  • start_char:整數。序列的開頭將使用此字元標記。0 通常是填充字元。預設值為 1
  • oov_char:整數。詞彙外字元。由於 num_wordsskip_top 限制而被刪除的單字將會被此字元取代。
  • index_from:整數。使用此索引和更高的索引來實際索引單字。

返回值

  • Numpy 陣列的元組(x_train, y_train), (x_test, y_test)

x_trainx_test:序列列表,它們是索引(整數)的列表。如果指定了 num_words 引數,則最大可能的索引值為 num_words - 1。如果指定了 maxlen 引數,則最大可能的序列長度為 maxlen

y_trainy_test:整數標籤(1 或 0)的列表。

注意: 「詞彙外」字元僅用於訓練集中存在但由於未達到 num_words 限制而被排除的單字。在訓練集中未看到但在測試集中出現的單字只是被跳過了。


[原始碼]

get_word_index 函數

keras.datasets.reuters.get_word_index(path="reuters_word_index.json")

檢索將單字對應到它們在路透社資料集中的索引的字典。

實際的單字索引從 3 開始,其中 3 個索引保留用於:0(填充)、1(開始)、2(oov)。

例如,單字「the」的索引為 1,但在實際的訓練資料中,「the」的索引將為 1 + 3 = 4。反之亦然,要使用此對應將訓練資料中的單字索引轉譯回單字,索引需要減 3。

引數

  • path:快取資料的位置(相對於 ~/.keras/dataset)。

返回值

單字索引字典。鍵是單字字串,值是它們的索引。