RNN model for time series forecasting

torchts_rnn(
  formula,
  data,
  learn_rate = 0.001,
  hidden_units,
  dropout = FALSE,
  timesteps = 20,
  horizon = 1,
  jump = horizon,
  rnn_layer = nn_gru,
  initial_layer_size = NULL,
  optim = optim_adam(),
  validation = NULL,
  stateful = FALSE,
  batch_size = 1,
  epochs = 10,
  shuffle = TRUE,
  scale = TRUE,
  sample_frac = 0.5,
  loss_fn = nnf_mae,
  device = NULL
)

Arguments

formula

(formula) A formula describing how to use the data.

data

(data.frame) An input data.frame.

learn_rate

(numeric) Learning rate.

hidden_units

(integer) Number of hidden units.

dropout

(logical) Use dropout (default = FALSE).

timesteps

(integer) Number of timesteps used to produce a forecast.

horizon

(integer) Forecast horizon.

jump

(integer) Input window shift.

rnn_layer

(nn_rnn_base) A torch recurrent layer.

optim

(function) A function returning a torch optimizer (like optim_adam) or an R expression like optim_adam(amsgrad = TRUE). Such an expression will be handled and fed with the params and lr arguments.

validation

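(data.frame or numeric) Validation dataset, or a numeric percentage (TODO: confirm what the percentage refers to; presumably the share of data held out for validation).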

stateful

(logical or list) Determines the network behaviour: whether it is stateful or not.

batch_size

(integer) Batch size.

epochs

(integer) Number of epochs to train the network.

shuffle

(logical) A dataloader argument - shuffle rows or not?

scale

(logical or list)

sample_frac

(numeric) A fraction of time series to be sampled.

loss_fn

(function) A torch loss function.

device

(character) A torch device.

Examples

library(dplyr, warn.conflicts = FALSE)
library(torch)
library(torchts)
library(timetk)

# Preparing a dataset
tiny_m5_sample <-
  tiny_m5 %>%
  filter(item_id == "FOODS_3_586", store_id == "CA_1") %>%
  mutate(value = as.numeric(value))

tk_summary_diagnostics(tiny_m5_sample)
#> tk_augment_timeseries_signature(): Using the following .date_var variable: date
#> # A tibble: 1 × 12
#>   n.obs start      end        units scale tzone diff.minimum diff.q1 diff.median
#>   <int> <date>     <date>     <chr> <chr> <chr>        <dbl>   <dbl>       <dbl>
#> 1  1913 2011-01-29 2016-04-24 days  day   UTC          86400   86400       86400
#> # … with 3 more variables: diff.mean <dbl>, diff.q3 <dbl>, diff.maximum <dbl>
glimpse(tiny_m5_sample)
#> Rows: 1,913
#> Columns: 18
#> $ item_id      <chr> "FOODS_3_586", "FOODS_3_586", "FOODS_3_586", "FOODS_3_586…
#> $ dept_id      <chr> "FOODS_3", "FOODS_3", "FOODS_3", "FOODS_3", "FOODS_3", "F…
#> $ cat_id       <chr> "FOODS", "FOODS", "FOODS", "FOODS", "FOODS", "FOODS", "FO…
#> $ store_id     <chr> "CA_1", "CA_1", "CA_1", "CA_1", "CA_1", "CA_1", "CA_1", "…
#> $ state_id     <chr> "CA", "CA", "CA", "CA", "CA", "CA", "CA", "CA", "CA", "CA…
#> $ value        <dbl> 42, 36, 30, 23, 27, 34, 30, 59, 54, 37, 22, 38, 33, 38, 5…
#> $ date         <date> 2011-01-29, 2011-01-30, 2011-01-31, 2011-02-01, 2011-02-…
#> $ wm_yr_wk     <int> 11101, 11101, 11101, 11101, 11101, 11101, 11101, 11102, 1…
#> $ weekday      <chr> "Saturday", "Sunday", "Monday", "Tuesday", "Wednesday", "…
#> $ wday         <int> 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, …
#> $ month        <int> 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
#> $ year         <int> 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 2011, 201…
#> $ event_name_1 <chr> "", "", "", "", "", "", "", "", "SuperBowl", "", "", "", …
#> $ event_type_1 <chr> "", "", "", "", "", "", "", "", "Sporting", "", "", "", "…
#> $ event_name_2 <chr> "", "", "", "", "", "", "", "", "", "", "", "", "", "", "…
#> $ event_type_2 <chr> "", "", "", "", "", "", "", "", "", "", "", "", "", "", "…
#> $ snap         <int> 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, …
#> $ sell_price   <dbl> 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.48, 1.4…

TIMESTEPS <- 20

data_split <-
  time_series_split(
    tiny_m5_sample, date,
    initial = "4 years",
    assess  = "1 year",
    lag     = TIMESTEPS
  )

# Training
rnn_model <-
  torchts_rnn(
    value ~ date + value + sell_price + wday,
    data = training(data_split),
    hidden_units = 10,
    timesteps = TIMESTEPS,
    horizon   = 1,
    epochs = 10,
    batch_size = 32
  )
#> Categorical variables found (1): wday
#> 
#> Training started
#> Error in (function (self, other) {    .Call("_torch_cpp_torch_method_matmul_self_Tensor_other_Tensor",         PACKAGE = "torchpkg", self, other)})(self = <pointer: 0x557a73f02ba0>, other = <pointer: 0x557a59c8ae30>): mat1 and mat2 shapes cannot be multiplied (640x4 and 3x6)
#> Exception raised from addmm_impl_cpu_ at ../aten/src/ATen/native/LinearAlgebra.cpp:939 (most recent call first):
#> frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7f1d873c71d9 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libc10.so)
#> frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd2 (0x7f1d873c3812 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libc10.so)
#> frame #2: <unknown function> + 0x11848a3 (0x7f1cdff6d8a3 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #3: at::native::mm_cpu_out(at::Tensor const&, at::Tensor const&, at::Tensor&) + 0x100 (0x7f1cdff6ee40 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #4: at::native::mm_cpu(at::Tensor const&, at::Tensor const&) + 0x7e (0x7f1cdff6efae in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #5: <unknown function> + 0x1b17e13 (0x7f1ce0900e13 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #6: at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) const + 0x86 (0x7f1ce08b8db6 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #7: at::redispatch::mm(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) + 0x72 (0x7f1ce0740bb2 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #8: <unknown function> + 0x33c2e41 (0x7f1ce21abe41 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #9: <unknown function> + 0x33c31c6 (0x7f1ce21ac1c6 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #10: at::Tensor::mm(at::Tensor const&) const + 0x137 (0x7f1ce0c5d367 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #11: <unknown function> + 0x118bcfe (0x7f1cdff74cfe in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #12: at::native::matmul(at::Tensor const&, at::Tensor const&) + 0x4a (0x7f1cdff7537a in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #13: <unknown function> + 0x1cc85f3 (0x7f1ce0ab15f3 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #14: at::Tensor::matmul(at::Tensor const&) const + 0x137 (0x7f1ce0c5d077 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/./libtorch_cpu.so)
#> frame #15: _lantern_Tensor_matmul_tensor_tensor + 0x46 (0x7f1d87a53ac6 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/deps/liblantern.so)
#> frame #16: cpp_torch_method_matmul_self_Tensor_other_Tensor(XPtrTorchTensor, XPtrTorchTensor) + 0x35 (0x7f1d8c5d17a5 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/libs/torchpkg.so)
#> frame #17: _torch_cpp_torch_method_matmul_self_Tensor_other_Tensor + 0xa2 (0x7f1d8c37ea92 in /home/krzysztof/R/x86_64-pc-linux-gnu-library/4.1/torch/libs/torchpkg.so)
#> frame #18: <unknown function> + 0xf751c (0x7f1d97ba751c in /usr/lib/R/lib/libR.so)
#> frame #19: <unknown function> + 0xf7a65 (0x7f1d97ba7a65 in /usr/lib/R/lib/libR.so)
#> frame #20: <unknown function> + 0x132ba1 (0x7f1d97be2ba1 in /usr/lib/R/lib/libR.so)
#> frame #21: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)
#> frame #22: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #23: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #24: Rf_eval + 0x2af (0x7f1d97bfdb6f in /usr/lib/R/lib/libR.so)
#> frame #25: <unknown function> + 0xc2bcd (0x7f1d97b72bcd in /usr/lib/R/lib/libR.so)
#> frame #26: <unknown function> + 0x132ba1 (0x7f1d97be2ba1 in /usr/lib/R/lib/libR.so)
#> frame #27: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)
#> frame #28: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #29: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #30: <unknown function> + 0x13b296 (0x7f1d97beb296 in /usr/lib/R/lib/libR.so)
#> frame #31: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)
#> frame #32: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #33: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #34: <unknown function> + 0x13b296 (0x7f1d97beb296 in /usr/lib/R/lib/libR.so)
#> frame #35: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)
#> frame #36: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #37: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #38: Rf_eval + 0x2af (0x7f1d97bfdb6f in /usr/lib/R/lib/libR.so)
#> frame #39: <unknown function> + 0x1514f3 (0x7f1d97c014f3 in /usr/lib/R/lib/libR.so)
#> frame #40: Rf_eval + 0x57b (0x7f1d97bfde3b in /usr/lib/R/lib/libR.so)
#> frame #41: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #42: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #43: <unknown function> + 0x13b296 (0x7f1d97beb296 in /usr/lib/R/lib/libR.so)
#> frame #44: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)
#> frame #45: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #46: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #47: Rf_eval + 0x2af (0x7f1d97bfdb6f in /usr/lib/R/lib/libR.so)
#> frame #48: <unknown function> + 0x1514f3 (0x7f1d97c014f3 in /usr/lib/R/lib/libR.so)
#> frame #49: Rf_eval + 0x57b (0x7f1d97bfde3b in /usr/lib/R/lib/libR.so)
#> frame #50: <unknown function> + 0x14f80f (0x7f1d97bff80f in /usr/lib/R/lib/libR.so)
#> frame #51: Rf_applyClosure + 0x1a2 (0x7f1d97c00702 in /usr/lib/R/lib/libR.so)
#> frame #52: Rf_eval + 0x2af (0x7f1d97bfdb6f in /usr/lib/R/lib/libR.so)
#> frame #53: <unknown function> + 0x14e3dc (0x7f1d97bfe3dc in /usr/lib/R/lib/libR.so)
#> frame #54: Rf_eval + 0x39f (0x7f1d97bfdc5f in /usr/lib/R/lib/libR.so)
#> frame #55: <unknown function> + 0x14e3dc (0x7f1d97bfe3dc in /usr/lib/R/lib/libR.so)
#> frame #56: <unknown function> + 0x14e888 (0x7f1d97bfe888 in /usr/lib/R/lib/libR.so)
#> frame #57: <unknown function> + 0x13925c (0x7f1d97be925c in /usr/lib/R/lib/libR.so)
#> frame #58: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)
#> frame #59: <unknown function> + 0x14e3dc (0x7f1d97bfe3dc in /usr/lib/R/lib/libR.so)
#> frame #60: Rf_eval + 0x5f8 (0x7f1d97bfdeb8 in /usr/lib/R/lib/libR.so)
#> frame #61: <unknown function> + 0x11b305 (0x7f1d97bcb305 in /usr/lib/R/lib/libR.so)
#> frame #62: <unknown function> + 0x132ba1 (0x7f1d97be2ba1 in /usr/lib/R/lib/libR.so)
#> frame #63: Rf_eval + 0x88 (0x7f1d97bfd948 in /usr/lib/R/lib/libR.so)

# Prediction
cleared_new_data <-
  testing(data_split) %>%
  clear_outcome(date, value, TIMESTEPS)

forecast <-
  predict(rnn_model, cleared_new_data)
#> Error in predict(rnn_model, cleared_new_data): object 'rnn_model' not found