Skip to contents

Temporal Fusion Transformer model

Usage

model_tft(...)

Arguments

lookback

Number of past timesteps fed to the model (length of the lookback window).

horizon

Forecast length (number of timesteps)

past_numeric_size

Number of numeric features observed in the past (lookback) window.

past_categorical_size

Number of categorical features observed in the past (lookback) window.

future_numeric_size

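Number of numeric features available for the future (forecast horizon) window.

future_categorical_size

Number of categorical features available for the future (forecast horizon) window.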

output_size

Number of model outputs. For a simple point estimate, set this to 1.

Examples

library(keras)
library(aion)

# Example dimensions, defined once so the model configuration and the dummy
# input arrays below cannot drift out of sync.
batch_size              <- 32
lookback                <- 28
horizon                 <- 14
past_numeric_size       <- 5
past_categorical_size   <- 2
future_numeric_size     <- 4
future_categorical_size <- 2
vocab_static_size       <- c(5, 5)
vocab_dynamic_size      <- c(4, 4)

tft <- model_tft(
  lookback                = lookback,
  horizon                 = horizon,
  past_numeric_size       = past_numeric_size,
  past_categorical_size   = past_categorical_size,
  future_numeric_size     = future_numeric_size,
  future_categorical_size = future_categorical_size,
  vocab_static_size       = vocab_static_size,
  vocab_dynamic_size      = vocab_dynamic_size,
  hidden_dim              = 12,
  state_size              = 7,
  num_heads               = 10,
  dropout_rate            = 0.1,
  output_size             = 3
  # quantiles             = 0.5
)

# Uniform random array of shape `dims`. The data length always equals
# prod(dims), so array() never silently recycles or truncates it.
random_numeric <- function(dims) {
  array(runif(prod(dims)), dims)
}

# Random zero-based category codes in 0..(n_levels - 1), shaped as `dims`.
random_codes <- function(n_levels, dims) {
  array(sample.int(n_levels, prod(dims), replace = TRUE) - 1, dims)
}

# Static inputs: one row per series. min() keeps every code valid for
# every column's vocabulary.
x_static_cat <- random_codes(
  min(vocab_static_size),
  c(batch_size, length(vocab_static_size))
)
# NOTE(review): model_tft() takes no static numeric size argument in this
# example; confirm the model expects exactly one static numeric feature.
x_static_num <- random_numeric(c(batch_size, 1))

# Past (lookback) window inputs.
x_past_num <- random_numeric(c(batch_size, lookback, past_numeric_size))
x_past_cat <- random_codes(
  min(vocab_dynamic_size),
  c(batch_size, lookback, past_categorical_size)
)

# Future (forecast horizon) window inputs.
x_fut_num <- random_numeric(c(batch_size, horizon, future_numeric_size))
x_fut_cat <- random_codes(
  min(vocab_dynamic_size),
  c(batch_size, horizon, future_categorical_size)
)

# Keras models take multiple inputs as a single list. Passing the arrays as
# separate positional arguments appears to be what triggered the
# "unused arguments" error previously shown here.
# NOTE(review): confirm this input order matches what model_tft() expects.
tft(list(
  x_past_num, x_past_cat,
  x_fut_num, x_fut_cat,
  x_static_num, x_static_cat
))