RNN model for time series forecasting
torchts_rnn(
  formula,
  data,
  learn_rate = 0.001,
  hidden_units,
  dropout = FALSE,
  timesteps = 20,
  horizon = 1,
  jump = horizon,
  rnn_layer = nn_gru,
  initial_layer_size = NULL,
  optim = optim_adam(),
  validation = NULL,
  stateful = FALSE,
  batch_size = 1,
  epochs = 10,
  shuffle = TRUE,
  scale = TRUE,
  sample_frac = 0.5,
  loss_fn = nnf_mae,
  device = NULL
)
formula (formula) A formula describing how to use the data.
data (data.frame) An input data.frame.
learn_rate (numeric) Learning rate.
hidden_units (integer) Number of hidden units.
dropout (logical) Use dropout (default = FALSE).
timesteps (integer) Number of timesteps used to produce a forecast.
horizon (integer) Forecast horizon.
jump (integer) Input window shift.
rnn_layer (nn_rnn_base) A torch recurrent layer.
optim (function) A function returning a torch optimizer (like optim_adam) or an R expression like optim_adam(amsgrad = TRUE). Such an expression will be handled and fed with the params and lr arguments.
validation (data.frame or numeric) Validation dataset or percent of TODO.
stateful (logical or list) Determines network behaviour: whether it is stateful or not.
batch_size (integer) Batch size.
epochs (integer) Number of epochs to train the network.
shuffle (logical) A dataloader argument: whether to shuffle rows or not.
scale (logical or list)
sample_frac (numeric) A fraction of the time series to be sampled.
loss_fn (function) A torch loss function.
device (character) A torch device.
library(dplyr, warn.conflicts = FALSE)
library(torch)
library(torchts)
library(timetk)
# Preparing a dataset ----
# Keep one series (a single item sold in a single store) and store the
# target column `value` as double.
tiny_m5_sample <- tiny_m5 %>%
  filter(item_id == "FOODS_3_586" & store_id == "CA_1") %>%
  mutate(value = as.numeric(value))

# Quick look at the time index and at the available columns.
tk_summary_diagnostics(tiny_m5_sample)
#> tk_augment_timeseries_signature(): Using the following .date_var variable: date
#> # A tibble: 1 × 12
#> n.obs start end units scale tzone diff.minimum diff.q1 diff.median
#> <int> <date> <date> <chr> <chr> <chr> <dbl> <dbl> <dbl>
#> 1 1913 2011-01-29 2016-04-24 days day UTC 86400 86400 86400
#> # … with 3 more variables: diff.mean <dbl>, diff.q3 <dbl>, diff.maximum <dbl>
glimpse(tiny_m5_sample)
#> Rows: 1,913
#> Columns: 18
#> $ item_id <chr> "FOODS_3_586", "FOODS_3_586", "FOODS_3_586", "FOODS_3_586…
#> $ dept_id <chr> "FOODS_3", "FOODS_3", "FOODS_3", "FOODS_3", "FOODS_3", "F…
#> $ cat_id <chr> "FOODS", "FOODS", "FOODS", "FOODS", "FOODS", "FOODS", "FO…
#> $ store_id <chr> "CA_1", "CA_1", "CA_1", "CA_1", "CA_1", "CA_1", "CA_1", "…
#> $ state_id <chr> "CA", "CA", "CA", "CA", "CA", "CA", "CA", "CA", "CA", "CA…
#> $ value <dbl> 42, 36, 30, 23, 27, 34, 30, 59, 54, 37, 22, 38, 33, 38, 5…
#> $ date <date> 2011-01-29, 2011-01-30, 2011-01-31, 2011-02-01, 2011-02-…
#> $ wm_yr_wk <int> 11101, 11101, 11101, 11101, 11101, 11101, 11101, 11102, 1…
#> $ weekday <chr> "Saturday", "Sunday", "Monday", "Tuesday", "Wednesday", "…
#> $ wday <int> 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, …
#> $ month <int> 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
#> $ year <int> 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 201…
#> $ event_name_1 <chr> "", "", "", "", "", "", "", "", "SuperBowl", "", "", "", …
#> $ event_type_1 <chr> "", "", "", "", "", "", "", "", "Sporting", "", "", "", "…
#> $ event_name_2 <chr> "", "", "", "", "", "", "", "", "", "", "", "", "", "", "…
#> $ event_type_2 <chr> "", "", "", "", "", "", "", "", "", "", "", "", "", "", "…
#> $ snap <int> 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, …
#> $ sell_price <dbl> 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.4…
# Train/test split ----
# Length of the input window: number of timesteps used to produce a forecast.
TIMESTEPS <- 20

# `initial` / `assess` set the training and testing window lengths.
# `lag = TIMESTEPS` adds TIMESTEPS observations of overlap between the two
# sets (see timetk::time_series_split), presumably so the first test window
# has a full history of inputs.
data_split <- tiny_m5_sample %>%
  time_series_split(
    date,
    initial = "4 years",
    assess = "1 year",
    lag = TIMESTEPS
  )
# Training ----
# NOTE: as the captured output below shows, this call currently fails.
# `wday` (an integer column) is reported as categorical. Training then stops
# with a tensor shape mismatch: mat1 is 640 x 4 (640 = 32 * 20, presumably
# batch_size * timesteps), while the layer weights expect 3 input columns.
# Presumably the categorical encoding and the initial layer disagree on the
# feature count; verify in torchts. As a consequence `rnn_model` is never
# assigned.
rnn_model <- torchts_rnn(
  formula = value ~ date + value + sell_price + wday,
  data = training(data_split),
  timesteps = TIMESTEPS,
  horizon = 1,
  hidden_units = 10,
  batch_size = 32,
  epochs = 10
)
#> Categorical variables found (1): wday
#>
#> Training started
#> Error in (function (self, other) { .Call("_torch_cpp_torch_method_matmul_self_Tensor_other_Tensor", PACKAGE = "torchpkg", self, other)})(self = <pointer: 0x557a73f02ba0>, other = <pointer: 0x557a59c8ae30>): mat1 and mat2 shapes cannot be multiplied (640x4 and 3x6)
#> Exception raised from addmm_impl_cpu_ at ../aten/src/ATen/native/LinearAlgebra.cpp:939 (most recent call first):
#> ... (native C++ stack frames #0-#63 omitted)
# Prediction ----
# `clear_outcome()` (torchts) prepares the test set for forecasting, given the
# date column, the outcome column and the window length; see its help page.
# This step needs `rnn_model` from the training step. Because training failed
# above, it errors with "object 'rnn_model' not found".
cleared_new_data <- clear_outcome(testing(data_split), date, value, TIMESTEPS)
forecast <- predict(rnn_model, cleared_new_data)
#> Error in predict(rnn_model, cleared_new_data): object 'rnn_model' not found