程式碼範例 / 結構化資料 / 從零開始的結構化資料分類

從零開始的結構化資料分類

作者: fchollet
建立日期 2020/06/09
上次修改日期 2020/06/09
描述: 結構化資料(包括數值和類別特徵)的二元分類。

ⓘ 此範例使用 Keras 3

在 Colab 中檢視 GitHub 來源


簡介

此範例示範如何從原始 CSV 檔案開始進行結構化資料分類。我們的資料包含數值和類別特徵。我們將使用 Keras 預處理層來正規化數值特徵並向量化類別特徵。

請注意,此範例應在 TensorFlow 2.5 或更高版本上執行。

資料集

我們的資料集由克里夫蘭醫學中心基金會提供,用於心臟病研究。這是一個 CSV 檔案,包含 303 列。每列包含有關患者的資訊(一個樣本),每欄描述患者的一個屬性(一個特徵)。我們使用這些特徵來預測患者是否患有心臟病(二元分類)。

以下是每個特徵的描述

描述 特徵類型
年齡 年齡(年) 數值
性別 (1 = 男性;0 = 女性) 類別
CP 胸痛類型 (0, 1, 2, 3, 4) 類別
Trestbpd 靜息血壓(入院時以毫米汞柱為單位) 數值
Chol 血清膽固醇(毫克/分升) 數值
FBS 空腹血糖(120 毫克/分升)(1 = 真;0 = 假) 類別
RestECG 靜息心電圖結果 (0, 1, 2) 類別
Thalach 達到的最大心率 數值
Exang 運動引起的心絞痛(1 = 是;0 = 否) 類別
Oldpeak 運動相對於靜息引起的 ST 段下降 數值
Slope 峰值運動 ST 段的斜率 數值
CA 螢光透視著色的主要血管數量 (0-3) 數值和類別
Thal 3 = 正常;6 = 固定缺陷;7 = 可逆缺陷 類別
Target 心臟病診斷(1 = 真;0 = 假) Target

設定

import os

os.environ["KERAS_BACKEND"] = "torch"  # or torch, or tensorflow

import pandas as pd
import keras
from keras import layers

準備資料

讓我們下載資料並將其載入到 Pandas 資料框架中

file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
dataframe = pd.read_csv(file_url)

資料集包含 303 個樣本,每個樣本有 14 欄(13 個特徵,加上目標標籤)

dataframe.shape
(303, 14)

以下是一些樣本的預覽

dataframe.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal target
0 63 1 1 145 233 1 2 150 0 2.3 3 0 fixed 0
1 67 1 4 160 286 0 2 108 1 1.5 2 3 normal 1
2 67 1 4 120 229 0 2 129 1 2.6 2 2 reversible 0
3 37 1 3 130 250 0 0 187 0 3.5 3 0 normal 0
4 41 0 2 130 204 0 2 172 0 1.4 1 0 normal 0

最後一欄「target」表示患者是否患有心臟病 (1) 或沒有 (0)。

讓我們將資料分割為訓練集和驗證集

val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
train_dataframe = dataframe.drop(val_dataframe.index)

print(
    f"Using {len(train_dataframe)} samples for training "
    f"and {len(val_dataframe)} for validation"
)
Using 242 samples for training and 61 for validation

定義資料集元資料

在這裡,我們定義資料集的元資料,這對於將資料讀取和剖析為輸入特徵,以及根據輸入特徵的類型對其進行編碼非常有用。

COLUMN_NAMES = [
    "age",
    "sex",
    "cp",
    "trestbps",
    "chol",
    "fbs",
    "restecg",
    "thalach",
    "exang",
    "oldpeak",
    "slope",
    "ca",
    "thal",
    "target",
]
# Target feature name.
TARGET_FEATURE_NAME = "target"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
# Categorical features and their vocabulary lists.
# Note that we add 'v=' as a prefix to all categorical feature values to make
# sure that they are treated as strings.

CATEGORICAL_FEATURES_WITH_VOCABULARY = {
    feature_name: sorted(
        [
            # Integer categorcal must be int and string must be str
            value if dataframe[feature_name].dtype == "int64" else str(value)
            for value in list(dataframe[feature_name].unique())
        ]
    )
    for feature_name in COLUMN_NAMES
    if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
}
# All features names.
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
    CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
)

使用 Keras 層進行特徵預處理

以下特徵是以整數編碼的類別特徵

  • sex
  • cp
  • fbs
  • restecg
  • exang
  • ca

我們將使用獨熱編碼來編碼這些特徵。我們在這裡有兩個選項

  • 使用 CategoryEncoding(),這需要知道輸入值的範圍,並且會對超出範圍的輸入產生錯誤。
  • 使用 IntegerLookup(),它將為輸入建立一個查找表,並為未知的輸入值保留一個輸出索引。

對於此範例,我們需要一個簡單的解決方案,可以在推論時處理超出範圍的輸入,因此我們將使用 IntegerLookup()

我們還有一個以字串編碼的類別特徵:thal。我們將建立所有可能特徵的索引,並使用 StringLookup() 層編碼輸出。

最後,以下特徵是連續數值特徵

  • age
  • trestbps
  • chol
  • thalach
  • oldpeak
  • slope

對於每個這些特徵,我們將使用 Normalization() 層來確保每個特徵的平均值為 0,標準差為 1。

下面,我們定義了 2 個實用函數來執行操作

  • encode_numerical_feature 將特徵式正規化應用於數值特徵。
  • process 對字串或整數類別特徵進行獨熱編碼。
# Tensorflow required for tf.data.Dataset
import tensorflow as tf


# We process our datasets elements here (categorical) and convert them to indices to avoid this step
# during model training since only tensorflow support strings.
def encode_categorical(features, target):
    for feature_name in features:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            lookup_class = (
                layers.StringLookup
                if features[feature_name].dtype == "string"
                else layers.IntegerLookup
            )
            vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
            # Create a lookup to convert a string values to an integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and  num_oov_indices to 0.
            index = lookup_class(
                vocabulary=vocabulary,
                mask_token=None,
                num_oov_indices=0,
                output_mode="binary",
            )
            # Convert the string input values into integer indices.
            value_index = index(features[feature_name])
            features[feature_name] = value_index

        else:
            pass

    # Change features from OrderedDict to Dict to match Inputs as they are Dict.
    return dict(features), target


def encode_numerical_feature(feature, name, dataset):
    # Create a Normalization layer for our feature
    normalizer = layers.Normalization()
    # Prepare a Dataset that only yields our feature
    feature_ds = dataset.map(lambda x, y: x[name])
    feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
    # Learn the statistics of the data
    normalizer.adapt(feature_ds)
    # Normalize the input feature
    encoded_feature = normalizer(feature)
    return encoded_feature

讓我們為每個資料框架產生 tf.data.Dataset 物件

def dataframe_to_dataset(dataframe):
    dataframe = dataframe.copy()
    labels = dataframe.pop("target")
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels)).map(
        encode_categorical
    )
    ds = ds.shuffle(buffer_size=len(dataframe))
    return ds


train_ds = dataframe_to_dataset(train_dataframe)
val_ds = dataframe_to_dataset(val_dataframe)

每個 Dataset 產生一個元組 (input, target),其中 input 是一個特徵字典,而 target 是值 01

for x, y in train_ds.take(1):
    print("Input:", x)
    print("Target:", y)
Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=45>, 'sex': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, 'cp': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 0, 1])>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=142>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=309>, 'fbs': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, 'restecg': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=147>, 'exang': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.0>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'ca': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 0, 0, 1])>, 'thal': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 0, 1])>}
Target: tf.Tensor(1, shape=(), dtype=int64)

讓我們批次處理資料集

train_ds = train_ds.batch(32)
val_ds = val_ds.batch(32)

建立模型

完成此操作後,我們可以建立我們的端對端模型

# Categorical features have different shapes after the encoding, dependent on the
# vocabulary or unique values of each feature. We create them accordinly to match the
# input data elements generated by tf.data.Dataset after pre-processing them
def create_model_inputs():
    inputs = {}

    # This a helper function for creating categorical features
    def create_input_helper(feature_name):
        num_categories = len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name])
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(num_categories,), dtype="int64"
        )
        return inputs

    for feature_name in FEATURE_NAMES:
        if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
            # Categorical features
            create_input_helper(feature_name)
        else:
            # Make them float32, they are Real numbers
            feature_input = layers.Input(name=feature_name, shape=(1,), dtype="float32")
            # Process the Inputs here
            inputs[feature_name] = encode_numerical_feature(
                feature_input, feature_name, train_ds
            )
    return inputs


# This Layer defines the logic of the Model to perform the classification
class Classifier(keras.layers.Layer):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = layers.Dense(32, activation="relu")
        self.dropout = layers.Dropout(0.5)
        self.dense_2 = layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        all_features = layers.concatenate(list(inputs.values()))
        x = self.dense_1(all_features)
        x = self.dropout(x)
        output = self.dense_2(x)
        return output

    # Surpress build warnings
    def build(self, input_shape):
        self.built = True


# Create the Classifier model
def create_model():
    all_inputs = create_model_inputs()
    output = Classifier()(all_inputs)
    model = keras.Model(all_inputs, output)
    return model


model = create_model()
model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'age' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor> which has name 'keras_tensor'. Change the tensor name to 'age' (via `Input(..., name='age')`)
  warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'trestbps' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_1> which has name 'keras_tensor_1'. Change the tensor name to 'trestbps' (via `Input(..., name='trestbps')`)
  warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'thalach' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_2> which has name 'keras_tensor_2'. Change the tensor name to 'thalach' (via `Input(..., name='thalach')`)
  warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'oldpeak' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_3> which has name 'keras_tensor_3'. Change the tensor name to 'oldpeak' (via `Input(..., name='oldpeak')`)
  warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'slope' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_4> which has name 'keras_tensor_4'. Change the tensor name to 'slope' (via `Input(..., name='slope')`)
  warnings.warn(
/home/humbulani/tensorflow-env/env/lib/python3.11/site-packages/keras/src/models/functional.py:106: UserWarning: When providing `inputs` as a dict, all keys in the dict must match the names of the corresponding tensors. Received key 'chol' mapping to value <KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_5> which has name 'keras_tensor_5'. Change the tensor name to 'chol' (via `Input(..., name='chol')`)
  warnings.warn(

讓我們視覺化我們的連線圖

# `rankdir='LR'` is to make the graph horizontal.
keras.utils.plot_model(model, show_shapes=True, rankdir="LR")

png


訓練模型

model.fit(train_ds, epochs=50, validation_data=val_ds)
Epoch 1/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 102 毫秒/步 - 準確度:0.4688 - 損失:8.0563



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 8 毫秒/步 - 準確度:0.4732 - 損失:7.9796



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.4725 - 損失:7.9848 - 驗證準確度:0.2295 - 驗證損失:12.0816

Epoch 2/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 105 毫秒/步 - 準確度:0.5000 - 損失:6.6368



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 8 毫秒/步 - 準確度:0.4532 - 損失:7.8320



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 18 毫秒/步 - 準確度:0.4547 - 損失:7.8310 - 驗證準確度:0.2459 - 驗證損失:6.2543

Epoch 3/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 91 毫秒/步 - 準確度:0.5000 - 損失:7.6558



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.5041 - 損失:7.3378



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 19 毫秒/步 - 準確度:0.5087 - 損失:7.2802 - 驗證準確度:0.6885 - 驗證損失:2.1633

Epoch 4/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 93 毫秒/步 - 準確度:0.4375 - 損失:8.9030



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 8 毫秒/步 - 準確度:0.4815 - 損失:8.0109



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 19 毫秒/步 - 準確度:0.4858 - 損失:7.9351 - 驗證準確度:0.7705 - 驗證損失:3.3916

Epoch 5/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 88 毫秒/步 - 準確度:0.4688 - 損失:8.1279



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.5049 - 損失:7.4815



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.5117 - 損失:7.4054 - 驗證準確度:0.7705 - 驗證損失:3.6911

Epoch 6/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 107 毫秒/步 - 準確度:0.4688 - 損失:7.8832



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.4940 - 損失:7.4615



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.5121 - 損失:7.1851 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 7/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 101 毫秒/步 - 準確度:0.5312 - 損失:6.9446



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 12 毫秒/步 - 準確度:0.5357 - 損失:6.5511



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.5497 - 損失:6.3711 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 8/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 110 毫秒/步 - 準確度:0.5938 - 損失:6.3905



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6192 - 損失:5.9601



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6101 - 損失:6.0728 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 9/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 108 毫秒/步 - 準確度:0.5938 - 損失:6.5442



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6006 - 損失:6.3309



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 21 毫秒/步 - 準確度:0.5949 - 損失:6.3647 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 10/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 113 毫秒/步 - 準確度:0.5625 - 損失:6.8250



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 10 毫秒/步 - 準確度:0.5675 - 損失:6.5020



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.5764 - 損失:6.3308 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 11/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 116 毫秒/步 - 準確度:0.6250 - 損失:4.3582



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.6053 - 損失:5.4824



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6076 - 損失:5.4500 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 12/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 118 毫秒/步 - 準確度:0.5625 - 損失:7.0064



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.5740 - 損失:6.4431



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 21 毫秒/步 - 準確度:0.5787 - 損失:6.3510 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 13/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 115 毫秒/步 - 準確度:0.7500 - 損失:3.7382



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 10 毫秒/步 - 準確度:0.6812 - 損失:4.7893



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 21 毫秒/步 - 準確度:0.6712 - 損失:4.9453 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 14/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 114 毫秒/步 - 準確度:0.6562 - 損失:5.5498



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.6580 - 損失:5.4636



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 20 毫秒/步 - 準確度:0.6578 - 損失:5.4379 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 15/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 113 毫秒/步 - 準確度:0.5938 - 損失:5.8118



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 8 毫秒/步 - 準確度:0.5978 - 損失:5.9295



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 21 毫秒/步 - 準確度:0.6045 - 損失:5.8426 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 16/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 115 毫秒/步 - 準確度:0.6562 - 損失:4.4893



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.5763 - 損失:5.9135



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.5814 - 損失:5.8590 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 17/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 127 毫秒/步 - 準確度:0.5625 - 損失:7.0281



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6071 - 損失:6.0424



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.6179 - 損失:5.8262 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 18/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 130 毫秒/步 - 準確度:0.6562 - 損失:5.3547



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6701 - 損失:5.0648



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 25 毫秒/步 - 準確度:0.6713 - 損失:5.0607 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 19/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 121 毫秒/步 - 準確度:0.7500 - 損失:4.0295



5/8 ━━━━━━━━━━━━ [37m━━━━━━━━ 0 秒 13 毫秒/步 - 準確度:0.7157 - 損失:4.3995



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 25 毫秒/步 - 準確度:0.7077 - 損失:4.4886 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 20/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 129 毫秒/步 - 準確度:0.6250 - 損失:6.0278



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6479 - 損失:5.4982



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.6461 - 損失:5.4898 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 21/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 134 毫秒/步 - 準確度:0.5938 - 損失:5.8592



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6782 - 損失:4.7529



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.6627 - 損失:5.0219 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 22/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 127 毫秒/步 - 準確度:0.6875 - 損失:5.0149



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6342 - 損失:5.5898



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 25 毫秒/步 - 準確度:0.6290 - 損失:5.6701 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 23/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 121 毫秒/步 - 準確度:0.5938 - 損失:6.0783



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6259 - 損失:5.6908



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.6352 - 損失:5.5719 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 24/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 112 毫秒/步 - 準確度:0.7812 - 損失:3.1021



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 12 毫秒/步 - 準確度:0.7353 - 損失:3.8725



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.7163 - 損失:4.1637 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 25/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 112 毫秒/步 - 準確度:0.5625 - 損失:6.9224



5/8 ━━━━━━━━━━━━ [37m━━━━━━━━ 0 秒 13 毫秒/步 - 準確度:0.6331 - 損失:5.5663



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.6416 - 損失:5.4024 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 26/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 117 毫秒/步 - 準確度:0.6875 - 損失:4.4043



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6668 - 損失:5.0742



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6743 - 損失:4.9986 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 27/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 104 毫秒/步 - 準確度:0.6562 - 損失:5.3405



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 8 毫秒/步 - 準確度:0.6868 - 損失:4.7990



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 20 毫秒/步 - 準確度:0.6838 - 損失:4.8458 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 28/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 116 毫秒/步 - 準確度:0.6562 - 損失:4.8092



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.7061 - 損失:4.3996



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 20 毫秒/步 - 準確度:0.7053 - 損失:4.4297 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 29/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 114 毫秒/步 - 準確度:0.6250 - 損失:5.6655



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 10 毫秒/步 - 準確度:0.6536 - 損失:5.3912



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 21 毫秒/步 - 準確度:0.6589 - 損失:5.3014 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 30/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 116 毫秒/步 - 準確度:0.7812 - 損失:3.5258



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.6900 - 損失:4.7711



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 20 毫秒/步 - 準確度:0.6882 - 損失:4.8074 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 31/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 123 毫秒/步 - 準確度:0.5938 - 損失:6.5425



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 10 毫秒/步 - 準確度:0.6346 - 損失:5.6779



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6423 - 損失:5.5672 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 32/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 120 毫秒/步 - 準確度:0.6250 - 損失:5.6215



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6451 - 損失:5.2140



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.6556 - 損失:5.0993 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 33/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 115 毫秒/步 - 準確度:0.7188 - 損失:4.2096



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.7218 - 損失:4.3075



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.7143 - 損失:4.4143 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 34/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 114 毫秒/步 - 準確度:0.5625 - 損失:7.0242



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6608 - 損失:5.3428



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.6675 - 損失:5.2031 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 35/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 105 毫秒/步 - 準確度:0.6875 - 損失:5.0369



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6601 - 損失:5.2386



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.6675 - 損失:5.0972 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 36/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 114 毫秒/步 - 準確度:0.6562 - 損失:4.8957



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.7086 - 損失:4.4144



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.6980 - 損失:4.5912 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 37/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 115 毫秒/步 - 準確度:0.6250 - 損失:6.0333



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6438 - 損失:5.6852



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 23 毫秒/步 - 準確度:0.6551 - 損失:5.4504 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 38/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 119 毫秒/步 - 準確度:0.5938 - 損失:6.4043



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6659 - 損失:5.2220



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6751 - 損失:5.0637 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 39/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 122 毫秒/步 - 準確度:0.5625 - 損失:7.0517



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 10 毫秒/步 - 準確度:0.6782 - 損失:5.0396



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6854 - 損失:4.9129 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 40/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 121 毫秒/步 - 準確度:0.6562 - 損失:5.4278



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.6575 - 損失:5.2183



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6676 - 損失:5.0430 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 41/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 120 毫秒/步 - 準確度:0.7500 - 損失:3.9611



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 10 毫秒/步 - 準確度:0.7322 - 損失:4.2233



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.7325 - 損失:4.2274 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 42/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 127 毫秒/步 - 準確度:0.8438 - 損失:2.5075



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.7483 - 損失:3.8605



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 26 毫秒/步 - 準確度:0.7305 - 損失:4.1423 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 43/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 132 毫秒/步 - 準確度:0.7188 - 損失:4.5277



5/8 ━━━━━━━━━━━━ [37m━━━━━━━━ 0 秒 15 毫秒/步 - 準確度:0.6698 - 損失:5.2541



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 27 毫秒/步 - 準確度:0.6831 - 損失:4.9995 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 44/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 1 秒 149 毫秒/步 - 準確度:0.7188 - 損失:4.3368



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 12 毫秒/步 - 準確度:0.6884 - 損失:4.8941



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 26 毫秒/步 - 準確度:0.6877 - 損失:4.9237 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 45/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 113 毫秒/步 - 準確度:0.7188 - 損失:3.6048



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.6953 - 損失:4.5189



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 24 毫秒/步 - 準確度:0.6914 - 損失:4.6078 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 46/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 120 毫秒/步 - 準確度:0.7188 - 損失:4.5277



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.7298 - 損失:4.2710



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 25 毫秒/步 - 準確度:0.7214 - 損失:4.4175 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 47/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 117 毫秒/步 - 準確度:0.7500 - 損失:4.0295



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.6962 - 損失:4.8892



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 26 毫秒/步 - 準確度:0.6981 - 損失:4.8478 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 48/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 122 毫秒/步 - 準確度:0.7812 - 損失:3.4540



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 9 毫秒/步 - 準確度:0.7095 - 損失:4.5553



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 21 毫秒/步 - 準確度:0.7080 - 損失:4.5585 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 49/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 117 毫秒/步 - 準確度:0.6875 - 損失:4.5707



7/8 ━━━━━━━━━━━━━━━━━ [37m━━━ 0 秒 10 毫秒/步 - 準確度:0.6914 - 損失:4.7756



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6939 - 損失:4.7445 - 驗證準確度:0.7705 - 驗證損失:3.6992

Epoch 50/50

1/8 ━━ [37m━━━━━━━━━━━━━━━━━━ 0 秒 124 毫秒/步 - 準確度:0.7188 - 損失:4.0735



6/8 ━━━━━━━━━━━━━━━ [37m━━━━━ 0 秒 11 毫秒/步 - 準確度:0.7049 - 損失:4.3802



8/8 ━━━━━━━━━━━━━━━━━━━━ 0 秒 22 毫秒/步 - 準確度:0.6987 - 損失:4.5132 - 驗證準確度:0.7705 - 驗證損失:3.6992

<keras.src.callbacks.history.History at 0x747bef08e590>

我們很快達到 80% 的驗證準確度。


新資料的推論

若要取得新樣本的預測,您可以直接呼叫 model.predict()。您只需要做兩件事

  1. 將純量包裝到清單中,以便具有批次維度(模型僅處理批次資料,而不是單個樣本)
  2. 在每個特徵上呼叫 convert_to_tensor
sample = {
    "age": 60,
    "sex": 1,
    "cp": 1,
    "trestbps": 145,
    "chol": 233,
    "fbs": 1,
    "restecg": 2,
    "thalach": 150,
    "exang": 0,
    "oldpeak": 2.3,
    "slope": 3,
    "ca": 0,
    "thal": "fixed",
}


# Given the category (in the sample above - key) and the category value (in the sample above - value),
# we return its one-hot encoding
def get_cat_encoding(cat, cat_value):
    # Create a list of zeros with the same length as categories
    encoding = [0] * len(cat)
    # Find the index of category_value in categories and set the corresponding position to 1
    if cat_value in cat:
        encoding[cat.index(cat_value)] = 1
    return encoding


for name, value in sample.items():
    if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
        sample.update(
            {
                name: get_cat_encoding(
                    CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
                )
            }
        )
# Convert inputs to tensors
input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
predictions = model.predict(input_dict)

print(
    f"This particular patient had a {100 * predictions[0][0]:.1f} "
    "percent probability of having a heart disease, "
    "as evaluated by our model."
)

1/1 ━━━━━━━━━━━━━━━━━━━━ 0 秒 77 毫秒/步



1/1 ━━━━━━━━━━━━━━━━━━━━ 0 秒 79 毫秒/步

This particular patient had a 0.0 percent probability of having a heart disease, as evaluated by our model.

結論

  • 原始模型(僅在 tensorflow 上運行的模型)快速收斂到約 80%,並在很長一段時間內保持在那裡,有時達到 85%
  • 更新後的模型(後端不可知)模型可能會在 78% 到 83% 之間波動,有時達到 86% 的驗證準確度,並且也收斂在 80% 左右。