Keras 2 API 文件 / 層 API / 核心層 / 輸入物件

輸入物件

[原始碼]

Input 函數

tf_keras.Input(
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=None,
    tensor=None,
    ragged=None,
    type_spec=None,
    **kwargs
)

Input() 用於實例化 TF-Keras 張量。

TF-Keras 張量是一個符號張量類型的物件,我們透過一些屬性來增強它,這些屬性讓我們僅透過知道模型的輸入和輸出就可以建立 TF-Keras 模型。

例如,如果 abc 是 TF-Keras 張量,則可以執行以下操作: model = Model(input=[a, b], output=c)

參數

  • shape:形狀元組(整數),不包含批次大小。例如,shape=(32,) 表示預期的輸入將是 32 維向量的批次。此元組的元素可以是 None;'None' 元素表示形狀未知維度。
  • batch_size:可選的靜態批次大小(整數)。
  • name:層的可選名稱字串。在模型中應該是唯一的(不要重複使用相同的名稱)。如果未提供,它將自動產生。
  • dtype:輸入預期的資料類型,以字串形式(float32float64int32...)。
  • sparse:一個布林值,指定要建立的佔位符是否為稀疏的。'ragged' 和 'sparse' 只能有一個為 True。請注意,如果 sparse 為 False,稀疏張量仍然可以傳遞到輸入中,它們將被以預設值 0 進行密集化。
  • tensor:可選的現有張量以包裝到 Input 層中。如果設定,該層將使用此張量的 tf.TypeSpec 而不是建立新的佔位符張量。
  • ragged:一個布林值,指定要建立的佔位符是否為不規則的。'ragged' 和 'sparse' 只能有一個為 True。在這種情況下,'shape' 參數中的 'None' 值表示不規則的維度。有關 RaggedTensors 的更多資訊,請參閱本指南
  • type_spec:一個 tf.TypeSpec 物件,用於從中建立輸入佔位符。如果提供,則除了名稱之外的所有其他參數都必須為 None。
  • **kwargs:已棄用的參數支援。支援 batch_shapebatch_input_shape

傳回

一個 tensor

範例

# this is a logistic regression in Keras
x = Input(shape=(32,))
y = Dense(16, activation='softmax')(x)
model = Model(x, y)

請注意,即使啟用了 eager 執行,Input 也會產生一個符號張量類型的物件(即佔位符)。此符號張量類型的物件可以與將張量作為輸入的較低層級 TensorFlow 運算一起使用,如下所示

x = Input(shape=(32,))
y = tf.square(x)  # This op will be treated like a layer
model = Model(x, y)

(此行為不適用於較高階的 TensorFlow API,例如控制流程,並且不能直接被 tf.GradientTape 監控)。

但是,產生的模型不會追蹤任何被用作 TensorFlow 運算輸入的變數。所有變數的使用都必須在 TF-Keras 層中進行,以確保它們將被模型的權重追蹤。

TF-Keras Input 也可以從任意的 tf.TypeSpec 建立佔位符,例如

x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None],
                                        dtype=tf.float32, ragged_rank=1))
y = x.values
model = Model(x, y)

當傳遞任意的 tf.TypeSpec 時,它必須表示整個批次的簽名,而不僅僅是一個範例。

引發

  • ValueError:如果同時提供 sparseragged
  • ValueError:如果同時提供 shape 和 (batch_input_shapebatch_shape)。
  • ValueError:如果 shapetensortype_spec 都是 None。
  • ValueError:如果在傳遞 type_spec 的同時,除了 type_spec 之外的參數為非 None。
  • ValueError:如果提供了任何無法識別的參數。