M5 Forecasting competitions
m5_forecasting.Rmd
# Dataset
library(m5)
# Neural Networks
library(aion)
library(keras)
# Data wrangling
library(dplyr, warn.conflicts=FALSE)
library(data.table, warn.conflicts=FALSE)
library(recipes, warn.conflicts=FALSE)
Data preprocessing
# Chronological hold-out: rows dated before 2016-01-01 train the model,
# rows from 2016-01-01 onward are kept for evaluation.
train <- tiny_m5[date < '2016-01-01']
test <- tiny_m5[date >= '2016-01-01']

# Feature engineering, estimated (prep) on the training split only:
# * keep the raw item_id / store_id columns and add copies (*_idx) that are
#   then integer-encoded for use as categorical inputs;
# * integer-encode the calendar and event columns, with codes starting at 0;
# * remove rows with missing predictors (but see the note below).
m5_recipe <- recipe(value ~ ., data = train)
m5_recipe <- step_mutate(m5_recipe,
                         item_id_idx = item_id,
                         store_id_idx = store_id)
m5_recipe <- step_integer(m5_recipe,
                          item_id_idx, store_id_idx,
                          wday, month,
                          event_name_1, event_type_1,
                          event_name_2, event_type_2,
                          zero_based = TRUE)
# NOTE(review): step_naomit() defaults to skip = TRUE, and skipped steps are
# not applied when bake() is given explicit new_data -- so the bake() calls
# below may not drop any rows. Confirm whether that is intended.
m5_recipe <- step_naomit(m5_recipe, all_predictors())
m5_recipe <- prep(m5_recipe)

# Apply the fitted transformations to both splits, then convert the baked
# results to data.tables by reference.
train <- bake(m5_recipe, new_data = train)
test <- bake(m5_recipe, new_data = test)
setDT(train)
setDT(test)
Experiment config
# Column roles -----------------------------------------------------------------
TARGET <- 'value'
# Categorical inputs. The static ones are integer codes derived from the
# series key columns, so they are constant within a series; the dynamic ones
# are the per-date event columns.
STATIC_CAT <- c('item_id_idx', 'store_id_idx')
DYNAMIC_CAT <- c('event_name_1', 'event_type_1')
CATEGORICAL <- c(DYNAMIC_CAT, STATIC_CAT)
# Each numeric feature must be listed exactly once: a duplicated entry would
# be fed to the model twice and inflate length(NUMERIC), which is used to
# size the model's numeric inputs.
NUMERIC <- c('sell_price')
# Columns identifying a single time series, and the time index column.
KEY <- c('item_id', 'store_id')
INDEX <- 'date'

# Guard against accidentally repeated feature names.
stopifnot(
  anyDuplicated(NUMERIC) == 0L,
  anyDuplicated(CATEGORICAL) == 0L
)

# Windowing -------------------------------------------------------------------
LOOKBACK <- 28      # length of the history window (time steps)
HORIZON <- 14       # length of the forecast window (time steps)
STRIDE <- LOOKBACK  # window stride; presumably the offset between window starts
BATCH_SIZE <- 32
Creating generators
# Build a windowed batch generator for one data split.
#
# Both splits use the same experiment configuration (keys, index, window
# geometry, feature roles, batch size); only the data and whether batches are
# shuffled differ. Keeping the shared arguments in one place prevents the
# train and test generators from silently drifting apart.
#
# @param data    A preprocessed split (data.table).
# @param shuffle Logical; passed through to ts_generator().
# @return Whatever ts_generator() returns. It is destructured below into a
#   generator and a step count (presumably batches per epoch; confirm in aion).
make_m5_generator <- function(data, shuffle) {
  ts_generator(
    data = data,
    key = KEY,
    index = INDEX,
    lookback = LOOKBACK,
    horizon = HORIZON,
    stride = STRIDE,
    target = TARGET,
    static = STATIC_CAT,
    categorical = CATEGORICAL,
    numeric = NUMERIC,
    shuffle = shuffle,
    batch_size = BATCH_SIZE
  )
}

# Training batches are shuffled; test batches keep a fixed order.
c(train_generator, train_steps) %<-% make_m5_generator(train, shuffle = TRUE)
c(test_generator, test_steps) %<-% make_m5_generator(test, shuffle = FALSE)
Temporal Fusion Transformer (TFT) model
# Build the TFT with input sizes derived from the experiment configuration.
tft <- model_tft(
  # Window geometry -- must match the values given to the generators.
  lookback = LOOKBACK,
  horizon = HORIZON,
  # Numeric input widths. The past block is one wider than the future one;
  # presumably the extra column is the target's own history -- confirm
  # against what ts_generator() emits.
  past_numeric_size = 1 + length(NUMERIC),
  future_numeric_size = length(NUMERIC),
  # Time-varying categorical inputs, used for both past and future.
  past_categorical_size = length(DYNAMIC_CAT),
  future_categorical_size = length(DYNAMIC_CAT),
  # Vocabulary sizes of the categorical columns, computed on the training split.
  vocab_static_size = dict_size(train, STATIC_CAT),
  vocab_dynamic_size = dict_size(train, DYNAMIC_CAT),
  # Architecture hyperparameters.
  hidden_dim = 10,
  state_size = 5,
  num_heads = 10,
  dropout_rate = NULL,
  output_size = 1
)
#> Loaded Tensorflow version 2.10.0
# Configure training: Adam optimiser with a mean-squared-error loss.
# The result is not reassigned: keras compile() configures `tft` in place.
tft %>%
  compile(
    optimizer = 'adam',
    loss = 'mse'
  )