TorchModuleWrapper
類別keras.layers.TorchModuleWrapper(module, name=None, **kwargs)
Torch 模組封裝層。
TorchModuleWrapper
是一個封裝類別,可以將任何 torch.nn.Module
轉換為 Keras 層,特別是使其參數可被 Keras 追蹤。
TorchModuleWrapper
僅與 PyTorch 後端相容,且不能與 TensorFlow 或 JAX 後端一起使用。
參數
torch.nn.Module
實例。 如果它是 LazyModule
實例,則必須先初始化其參數,然後再將實例傳遞給 TorchModuleWrapper
(例如,呼叫它一次)。範例
以下是如何將 TorchModuleWrapper
與原始 PyTorch 模組一起使用的範例。
import torch
import torch.nn as nn
import torch.nn.functional as F
import keras
from keras.layers import TorchModuleWrapper
class Classifier(keras.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Wrap `torch.nn.Module`s with `TorchModuleWrapper`
# if they contain parameters
self.conv1 = TorchModuleWrapper(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
)
self.conv2 = TorchModuleWrapper(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
)
self.pool = nn.MaxPool2d(kernel_size=(2, 2))
self.flatten = nn.Flatten()
self.dropout = nn.Dropout(p=0.5)
self.fc = TorchModuleWrapper(nn.Linear(1600, 10))
def call(self, inputs):
x = F.relu(self.conv1(inputs))
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = self.flatten(x)
x = self.dropout(x)
x = self.fc(x)
return F.softmax(x, dim=1)
model = Classifier()
model.build((1, 28, 28))
print("# Output shape", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"]
)
model.fit(train_loader, epochs=5)