Keras 3 API 文件 / 層 API / 後端特定層 / TorchModuleWrapper 層

TorchModuleWrapper 層

[原始碼]

TorchModuleWrapper 類別

keras.layers.TorchModuleWrapper(module, name=None, **kwargs)

Torch 模組封裝層。

TorchModuleWrapper 是一個封裝類別,可以將任何 torch.nn.Module 轉換為 Keras 層,特別是使其參數可被 Keras 追蹤。

TorchModuleWrapper 僅與 PyTorch 後端相容,且不能與 TensorFlow 或 JAX 後端一起使用。

參數

  • moduletorch.nn.Module 實例。 如果它是 LazyModule 實例,則必須先初始化其參數,然後再將實例傳遞給 TorchModuleWrapper (例如,呼叫它一次)。
  • name:層的名稱 (字串)。

範例

以下是如何將 TorchModuleWrapper 與原始 PyTorch 模組一起使用的範例。

import torch
import torch.nn as nn
import torch.nn.functional as F

import keras
from keras.layers import TorchModuleWrapper

class Classifier(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Wrap `torch.nn.Module`s with `TorchModuleWrapper`
        # if they contain parameters
        self.conv1 = TorchModuleWrapper(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
        )
        self.conv2 = TorchModuleWrapper(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
        )
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.5)
        self.fc = TorchModuleWrapper(nn.Linear(1600, 10))

    def call(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return F.softmax(x, dim=1)


model = Classifier()
model.build((1, 28, 28))
print("# Output shape", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)