Note
Go to the end to download the full example code
4. Transformer-based models
from ai4water import Model
from ai4water.utils.utils import TrainTestSplit
from ai4water.models.utils import gen_cat_vocab
from ai4water.models import FTTransformer, TabTransformer
from utils import make_data, evaluate_model
Tab Transformer
# Load the dataset; make_data returns additional objects that this example
# does not use, so they are discarded.
data, *_ = make_data()

# Remove the excluded rows in a single call instead of five chained drops.
# The trailing comments keep the original annotation for each index label.
rows_to_drop = [
    800,       # 'GSAC-Ce-1'
    921, 923,  # 'AB25'
    920, 922,  # 'TRAC'
    801,       # 'GSAC'
]
data = data.drop(labels=rows_to_drop, axis=0)

# The first ten columns are treated as numeric inputs; the two categorical
# inputs and the target column are referenced by name.
NUMERIC_FEATURES = data.columns.to_list()[0:10]
CAT_FEATURES = ["Adsorbent", "Dye"]
LABEL = "Adsorption"

splitter = TrainTestSplit(seed=313)

# Cast every column group explicitly; the target is addressed through LABEL
# so the column name is spelled in exactly one place.
data[NUMERIC_FEATURES] = data[NUMERIC_FEATURES].astype(float)
data[CAT_FEATURES] = data[CAT_FEATURES].astype(str)
data[LABEL] = data[LABEL].astype(float)

train_data, test_data, _, _ = splitter.split_by_random(data)

# create vocabulary of unique values of categorical features
cat_vocabulary = gen_cat_vocab(data)
Make a list of input arrays for the training and test data
# Inputs are passed as [numeric array, categorical array]; the model summary
# names the matching input layers Input_num and Input_cat.
train_x = [train_data[NUMERIC_FEATURES].values, train_data[CAT_FEATURES].values]
test_x = [test_data[NUMERIC_FEATURES].values, test_data[CAT_FEATURES].values]

# TabTransformer hyper-parameters. The variable name now matches the
# `final_mlp_units` keyword it feeds (it was misspelled `final_mpl_units`).
depth = 3
num_heads = 4
hidden_units = 16
final_mlp_units = [84, 42]
num_numeric_features = len(NUMERIC_FEATURES)

model = Model(
    model=TabTransformer(
        cat_vocabulary=cat_vocabulary,
        num_numeric_features=num_numeric_features,
        hidden_units=hidden_units,
        final_mlp_units=final_mlp_units,
        depth=depth,
        num_heads=num_heads,
    )
)
building DL model for
regression problem using Model
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input_cat (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
cat_embeddings (CatEmbeddings) (None, 2, 16) 960 Input_cat[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 2, 16) 32 cat_embeddings[0][0]
__________________________________________________________________________________________________
multi_head_attention (MultiHead ((None, 2, 16), (Non 4304 layer_normalization_1[0][0]
layer_normalization_1[0][0]
__________________________________________________________________________________________________
add (Add) (None, 2, 16) 0 layer_normalization_1[0][0]
multi_head_attention[0][0]
__________________________________________________________________________________________________
sequential (Sequential) (None, 2, 16) 304 add[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 2, 16) 0 sequential[0][0]
add[0][0]
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 2, 16) 32 add_1[0][0]
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 2, 16) 32 layer_normalization_3[0][0]
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe ((None, 2, 16), (Non 4304 layer_normalization_4[0][0]
layer_normalization_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 2, 16) 0 layer_normalization_4[0][0]
multi_head_attention_1[0][0]
__________________________________________________________________________________________________
sequential_1 (Sequential) (None, 2, 16) 304 add_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 2, 16) 0 sequential_1[0][0]
add_2[0][0]
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 2, 16) 32 add_3[0][0]
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 2, 16) 32 layer_normalization_6[0][0]
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe ((None, 2, 16), (Non 4304 layer_normalization_7[0][0]
layer_normalization_7[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 2, 16) 0 layer_normalization_7[0][0]
multi_head_attention_2[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential) (None, 2, 16) 304 add_4[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 2, 16) 0 sequential_2[0][0]
add_4[0][0]
__________________________________________________________________________________________________
Input_num (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
layer_normalization_9 (LayerNor (None, 2, 16) 32 add_5[0][0]
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 10) 20 Input_num[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 32) 0 layer_normalization_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 42) 0 layer_normalization[0][0]
flatten[0][0]
__________________________________________________________________________________________________
MLP (Sequential) (None, 42) 7350 concatenate[0][0]
__________________________________________________________________________________________________
Dense_out (Dense) (None, 1) 43 MLP[0][0]
==================================================================================================
Total params: 22,389
Trainable params: 22,305
Non-trainable params: 84
__________________________________________________________________________________________________
dot plot of model could not be plotted due to ('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
# Train the TabTransformer for 500 epochs with verbose=0, using the test
# split as validation data.
model.fit(
    x=train_x,
    y=train_data[LABEL].values,
    validation_data=(test_x, test_data[LABEL].values),
    epochs=500,
    verbose=0,
)

assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
********** Successfully loaded weights from weights_293_5989.81787.hdf5 file **********
<keras.callbacks.History object at 0x7f7f53bb54d0>
# Predictions of the trained TabTransformer on the training inputs.
train_p = model.predict(x=train_x)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
1/33 [..............................] - ETA: 11s
19/33 [================>.............] - ETA: 0s
33/33 [==============================] - 0s 3ms/step
# Compare the observed training targets with the predictions; evaluate_model
# is imported from the local utils module.
# NOTE(review): the reported mape is inf, presumably because some observed
# target values are zero -- confirm.
evaluate_model(train_data[LABEL].values, train_p)
mse 8614.777600479118
rmse 92.81582623927407
r2 0.9562427570186125
r2_score 0.9528200924934528
mape inf
mae 34.05553273263597
# Predictions of the trained TabTransformer on the held-out test inputs.
test_p = model.predict(x=test_x)
1/15 [=>............................] - ETA: 0s
15/15 [==============================] - 0s 3ms/step
# Compare the observed test targets with the TabTransformer predictions.
evaluate_model(test_data[LABEL].values, test_p)
mse 5989.8178327641845
rmse 77.39391340902839
r2 0.9638683804476681
r2_score 0.9628952401356854
mape inf
mae 33.684673766334534
FT Transformer
# Build the FTTransformer model. The first two arguments stay positional, as
# in the original call: the number of numeric features and the categorical
# vocabulary.
ft_transformer = FTTransformer(
    len(NUMERIC_FEATURES),
    cat_vocabulary,
    hidden_units=16,
    num_heads=8,
)
model = Model(model=ft_transformer)
building DL model for
regression problem using Model
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input_num (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
Input_cat (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
numerical_embeddings (Numerical (None, 10, 16) 170 Input_num[0][0]
__________________________________________________________________________________________________
cat_embeddings (CatEmbeddings) (None, 2, 16) 960 Input_cat[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 12, 16) 0 numerical_embeddings[0][0]
cat_embeddings[0][0]
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 12, 16) 32 concatenate[0][0]
__________________________________________________________________________________________________
multi_head_attention (MultiHead ((None, 12, 16), (No 8592 layer_normalization[0][0]
layer_normalization[0][0]
__________________________________________________________________________________________________
add (Add) (None, 12, 16) 0 layer_normalization[0][0]
multi_head_attention[0][0]
__________________________________________________________________________________________________
sequential (Sequential) (None, 12, 16) 544 add[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 12, 16) 0 sequential[0][0]
add[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 12, 16) 32 add_1[0][0]
__________________________________________________________________________________________________
layer_normalization_2 (LayerNor (None, 12, 16) 32 layer_normalization_1[0][0]
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe ((None, 12, 16), (No 8592 layer_normalization_2[0][0]
layer_normalization_2[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 12, 16) 0 layer_normalization_2[0][0]
multi_head_attention_1[0][0]
__________________________________________________________________________________________________
sequential_1 (Sequential) (None, 12, 16) 544 add_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 12, 16) 0 sequential_1[0][0]
add_2[0][0]
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 12, 16) 32 add_3[0][0]
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 12, 16) 32 layer_normalization_3[0][0]
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe ((None, 12, 16), (No 8592 layer_normalization_4[0][0]
layer_normalization_4[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 12, 16) 0 layer_normalization_4[0][0]
multi_head_attention_2[0][0]
__________________________________________________________________________________________________
sequential_2 (Sequential) (None, 12, 16) 544 add_4[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 12, 16) 0 sequential_2[0][0]
add_4[0][0]
__________________________________________________________________________________________________
layer_normalization_5 (LayerNor (None, 12, 16) 32 add_5[0][0]
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 12, 16) 32 layer_normalization_5[0][0]
__________________________________________________________________________________________________
multi_head_attention_3 (MultiHe ((None, 12, 16), (No 8592 layer_normalization_6[0][0]
layer_normalization_6[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 12, 16) 0 layer_normalization_6[0][0]
multi_head_attention_3[0][0]
__________________________________________________________________________________________________
sequential_3 (Sequential) (None, 12, 16) 544 add_6[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 12, 16) 0 sequential_3[0][0]
add_6[0][0]
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 12, 16) 32 add_7[0][0]
__________________________________________________________________________________________________
lambda (Lambda) (None, 16) 0 layer_normalization_7[0][0]
__________________________________________________________________________________________________
layer_normalization_8 (LayerNor (None, 16) 32 lambda[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 16) 272 layer_normalization_8[0][0]
__________________________________________________________________________________________________
Dense_out (Dense) (None, 1) 17 dense_8[0][0]
==================================================================================================
Total params: 38,251
Trainable params: 38,251
Non-trainable params: 0
__________________________________________________________________________________________________
dot plot of model could not be plotted due to ('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
# Train the FTTransformer for 500 epochs with verbose=0, using the test
# split as validation data.
model.fit(
    x=train_x,
    y=train_data[LABEL].values,
    validation_data=(test_x, test_data[LABEL].values),
    epochs=500,
    verbose=0,
)

assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
********** Successfully loaded weights from weights_194_2034.58508.hdf5 file **********
<keras.callbacks.History object at 0x7f7f4f71d810>
# Predictions of the trained FTTransformer on the training inputs.
train_p = model.predict(x=train_x)
assigning name Input_num to IteratorGetNext:0 with shape (None, 10)
assigning name Input_cat to IteratorGetNext:1 with shape (None, 2)
1/33 [..............................] - ETA: 16s
7/33 [=====>........................] - ETA: 0s
13/33 [==========>...................] - ETA: 0s
19/33 [================>.............] - ETA: 0s
25/33 [=====================>........] - ETA: 0s
31/33 [===========================>..] - ETA: 0s
33/33 [==============================] - 1s 9ms/step
# Compare the observed training targets with the FTTransformer predictions.
evaluate_model(train_data[LABEL].values, train_p)
mse 2038.273404116556
rmse 45.14724137881024
r2 0.9924287429049051
r2_score 0.9888371406507435
mape inf
mae 24.095337528113994
# Predictions of the trained FTTransformer on the held-out test inputs.
test_p = model.predict(x=test_x)
1/15 [=>............................] - ETA: 0s
7/15 [=============>................] - ETA: 0s
13/15 [=========================>....] - ETA: 0s
15/15 [==============================] - 0s 8ms/step
# Compare the observed test targets with the FTTransformer predictions.
evaluate_model(test_data[LABEL].values, test_p)
mse 2034.5851668911816
rmse 45.106376122353055
r2 0.9884544229206549
r2_score 0.9873964791336308
mape inf
mae 24.676201901728298
Total running time of the script: (6 minutes 58.223 seconds)