4. Transformer-based models

from ai4water import Model
from ai4water.utils.utils import TrainTestSplit
from ai4water.models.utils import gen_cat_vocab
from ai4water.models import FTTransformer, TabTransformer

from utils import make_data, evaluate_model

Tab Transformer

# Load the dataset; only the first value returned by make_data is used here.
data, *_ = make_data()

# Remove selected rows by index label in a single call. The quoted names are
# the annotations attached to these labels in the original script.
data = data.drop(
    labels=[
        800,       # 'GSAC-Ce-1'
        921, 923,  # 'AB25'
        920, 922,  # 'TRAC'
        801,       # 'GSAC'
    ],
    axis=0,
)

# Column roles: the first ten columns are the numeric inputs, two named
# columns are the categorical inputs and LABEL is the target column.
NUMERIC_FEATURES = data.columns[:10].to_list()
CAT_FEATURES = ["Adsorbent", "Dye"]
LABEL = "Adsorption"

# Enforce dtypes: float for the numeric inputs and the target, str for the
# categorical inputs.
data[NUMERIC_FEATURES] = data[NUMERIC_FEATURES].astype(float)
data[CAT_FEATURES] = data[CAT_FEATURES].astype(str)
data[LABEL] = data[LABEL].astype(float)

# Random train/test split with a fixed seed; the last two returned values
# are not used.
splitter = TrainTestSplit(seed=313)
train_data, test_data, _, _ = splitter.split_by_random(data)

# create vocabulary of unique values of categorical features
cat_vocabulary = gen_cat_vocab(data)

Make lists of input arrays for the training and test data

# Inputs are passed as a list of two arrays per split: numeric features
# first, then categorical features.
train_x = [train_data[NUMERIC_FEATURES].values, train_data[CAT_FEATURES].values]
test_x = [test_data[NUMERIC_FEATURES].values, test_data[CAT_FEATURES].values]

# Hyperparameters handed to TabTransformer below.
depth = 3
num_heads = 4
hidden_units = 16
final_mlp_units = [84, 42]
num_numeric_features = len(NUMERIC_FEATURES)

# Wrap the TabTransformer configuration in an ai4water Model.
model = Model(model=TabTransformer(
    cat_vocabulary=cat_vocabulary,
    num_numeric_features=num_numeric_features,
    hidden_units=hidden_units,
    final_mlp_units=final_mlp_units,
    depth=depth,
    num_heads=num_heads,
))
            building DL model for
            regression problem using Model
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
Input_cat (InputLayer)          [(None, 2)]          0
__________________________________________________________________________________________________
cat_embeddings (CatEmbeddings)  (None, 2, 16)        960         Input_cat[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 2, 16)        32          cat_embeddings[0][0]
__________________________________________________________________________________________________
multi_head_attention (MultiHead ((None, 2, 16), (Non 4304        layer_normalization_1[0][0]
                                                                 layer_normalization_1[0][0]
__________________________________________________________________________________________________
add (Add)                       (None, 2, 16)        0           layer_normalization_1[0][0]
                                                                 multi_head_attention[0][0]
__________________________________________________________________________________________________
sequential (Sequential)         (None, 2, 16)        304         add[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, 2, 16)        0           sequential[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 2, 16)        32          add_1[0][0]
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 2, 16)        32          layer_normalization_3[0][0]
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe ((None, 2, 16), (Non 4304        layer_normalization_4[0][0]
                                                                 layer_normalization_4[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, 2, 16)        0           layer_normalization_4[0][0]
                                                                 multi_head_attention_1[0][0]
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 2, 16)        304         add_2[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, 2, 16)        0           sequential_1[0][0]
                                                                 add_2[0][0]
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 2, 16)        32          add_3[0][0]
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 2, 16)        32          layer_normalization_6[0][0]
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe ((None, 2, 16), (Non 4304        layer_normalization_7[0][0]
                                                                 layer_normalization_7[0][0]
__________________________________________________________________________________________________
add_4 (Add)                     (None, 2, 16)        0           layer_normalization_7[0][0]
                                                                 multi_head_attention_2[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 2, 16)        304         add_4[0][0]
__________________________________________________________________________________________________
add_5 (Add)                     (None, 2, 16)        0           sequential_2[0][0]
                                                                 add_4[0][0]
__________________________________________________________________________________________________
Input_num (InputLayer)          [(None, 10)]         0
__________________________________________________________________________________________________
layer_normalization_9 (LayerNor (None, 2, 16)        32          add_5[0][0]
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 10)           20          Input_num[0][0]
__________________________________________________________________________________________________
flatten (Flatten)               (None, 32)           0           layer_normalization_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 42)           0           layer_normalization[0][0]
                                                                 flatten[0][0]
__________________________________________________________________________________________________
MLP (Sequential)                (None, 42)           7350        concatenate[0][0]
__________________________________________________________________________________________________
Dense_out (Dense)               (None, 1)            43          MLP[0][0]
==================================================================================================
Total params: 22,389
Trainable params: 22,305
Non-trainable params: 84
__________________________________________________________________________________________________
dot plot of model could not be plotted due to ('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
# Train for 500 epochs without progress output; the test split is supplied
# as the validation data.
model.fit(
    x=train_x,
    y=train_data[LABEL].values,
    validation_data=(test_x, test_data[LABEL].values),
    epochs=500,
    verbose=0,
)
transformer models
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
********** Successfully loaded weights from weights_293_5989.81787.hdf5 file **********

<keras.callbacks.History object at 0x7f7f53bb54d0>
# Predictions of the fitted model on the training inputs.
train_p = model.predict(x=train_x)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)

 1/33 [..............................] - ETA: 11s
19/33 [================>.............] - ETA: 0s 
33/33 [==============================] - 0s 3ms/step
# Compare training-set predictions with the true target values.
evaluate_model(train_data[LABEL].values, train_p)
mse 8614.777600479118
rmse 92.81582623927407
r2 0.9562427570186125
r2_score 0.9528200924934528
mape inf
mae 34.05553273263597
# Predictions of the fitted model on the test inputs.
test_p = model.predict(x=test_x)
 1/15 [=>............................] - ETA: 0s
15/15 [==============================] - 0s 3ms/step
# Compare test-set predictions with the true target values.
evaluate_model(test_data[LABEL].values, test_p)
mse 5989.8178327641845
rmse 77.39391340902839
r2 0.9638683804476681
r2_score 0.9628952401356854
mape inf
mae 33.684673766334534

FT Transformer

# Build the FTTransformer model. It receives the number of numeric features
# and the categorical vocabulary; only hidden_units and num_heads are set
# explicitly here.
model = Model(
    model=FTTransformer(
        len(NUMERIC_FEATURES),
        cat_vocabulary,
        hidden_units=16,
        num_heads=8,
    )
)
            building DL model for
            regression problem using Model
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
Input_num (InputLayer)          [(None, 10)]         0
__________________________________________________________________________________________________
Input_cat (InputLayer)          [(None, 2)]          0
__________________________________________________________________________________________________
numerical_embeddings (Numerical (None, 10, 16)       170         Input_num[0][0]
__________________________________________________________________________________________________
cat_embeddings (CatEmbeddings)  (None, 2, 16)        960         Input_cat[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 12, 16)       0           numerical_embeddings[0][0]
                                                                 cat_embeddings[0][0]
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 12, 16)       32          concatenate[0][0]
__________________________________________________________________________________________________
multi_head_attention (MultiHead ((None, 12, 16), (No 8592        layer_normalization[0][0]
                                                                 layer_normalization[0][0]
__________________________________________________________________________________________________
add (Add)                       (None, 12, 16)       0           layer_normalization[0][0]
                                                                 multi_head_attention[0][0]
__________________________________________________________________________________________________
sequential (Sequential)         (None, 12, 16)       544         add[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, 12, 16)       0           sequential[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 12, 16)       32          add_1[0][0]
__________________________________________________________________________________________________
layer_normalization_2 (LayerNor (None, 12, 16)       32          layer_normalization_1[0][0]
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe ((None, 12, 16), (No 8592        layer_normalization_2[0][0]
                                                                 layer_normalization_2[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, 12, 16)       0           layer_normalization_2[0][0]
                                                                 multi_head_attention_1[0][0]
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 12, 16)       544         add_2[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, 12, 16)       0           sequential_1[0][0]
                                                                 add_2[0][0]
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 12, 16)       32          add_3[0][0]
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 12, 16)       32          layer_normalization_3[0][0]
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe ((None, 12, 16), (No 8592        layer_normalization_4[0][0]
                                                                 layer_normalization_4[0][0]
__________________________________________________________________________________________________
add_4 (Add)                     (None, 12, 16)       0           layer_normalization_4[0][0]
                                                                 multi_head_attention_2[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 12, 16)       544         add_4[0][0]
__________________________________________________________________________________________________
add_5 (Add)                     (None, 12, 16)       0           sequential_2[0][0]
                                                                 add_4[0][0]
__________________________________________________________________________________________________
layer_normalization_5 (LayerNor (None, 12, 16)       32          add_5[0][0]
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 12, 16)       32          layer_normalization_5[0][0]
__________________________________________________________________________________________________
multi_head_attention_3 (MultiHe ((None, 12, 16), (No 8592        layer_normalization_6[0][0]
                                                                 layer_normalization_6[0][0]
__________________________________________________________________________________________________
add_6 (Add)                     (None, 12, 16)       0           layer_normalization_6[0][0]
                                                                 multi_head_attention_3[0][0]
__________________________________________________________________________________________________
sequential_3 (Sequential)       (None, 12, 16)       544         add_6[0][0]
__________________________________________________________________________________________________
add_7 (Add)                     (None, 12, 16)       0           sequential_3[0][0]
                                                                 add_6[0][0]
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 12, 16)       32          add_7[0][0]
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 16)           0           layer_normalization_7[0][0]
__________________________________________________________________________________________________
layer_normalization_8 (LayerNor (None, 16)           32          lambda[0][0]
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 16)           272         layer_normalization_8[0][0]
__________________________________________________________________________________________________
Dense_out (Dense)               (None, 1)            17          dense_8[0][0]
==================================================================================================
Total params: 38,251
Trainable params: 38,251
Non-trainable params: 0
__________________________________________________________________________________________________
dot plot of model could not be plotted due to ('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
# Train for 500 epochs without progress output; the test split is supplied
# as the validation data.
model.fit(
    x=train_x,
    y=train_data[LABEL].values,
    validation_data=(test_x, test_data[LABEL].values),
    epochs=500,
    verbose=0,
)
transformer models
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
********** Successfully loaded weights from weights_194_2034.58508.hdf5 file **********

<keras.callbacks.History object at 0x7f7f4f71d810>
# Predictions of the fitted model on the training inputs.
train_p = model.predict(x=train_x)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)

 1/33 [..............................] - ETA: 16s
 7/33 [=====>........................] - ETA: 0s 
13/33 [==========>...................] - ETA: 0s
19/33 [================>.............] - ETA: 0s
25/33 [=====================>........] - ETA: 0s
31/33 [===========================>..] - ETA: 0s
33/33 [==============================] - 1s 9ms/step
# Compare training-set predictions with the true target values.
evaluate_model(train_data[LABEL].values, train_p)
mse 2038.273404116556
rmse 45.14724137881024
r2 0.9924287429049051
r2_score 0.9888371406507435
mape inf
mae 24.095337528113994
# Predictions of the fitted model on the test inputs.
test_p = model.predict(x=test_x)
 1/15 [=>............................] - ETA: 0s
 7/15 [=============>................] - ETA: 0s
13/15 [=========================>....] - ETA: 0s
15/15 [==============================] - 0s 8ms/step
# Compare test-set predictions with the true target values.
evaluate_model(test_data[LABEL].values, test_p)
mse 2034.5851668911816
rmse 45.106376122353055
r2 0.9884544229206549
r2_score 0.9873964791336308
mape inf
mae 24.676201901728298

Total running time of the script: (6 minutes 58.223 seconds)

Gallery generated by Sphinx-Gallery