A quick shortcut to create a torch dataloader from the given dataset.

as_ts_dataloader(
  data,
  formula,
  index = NULL,
  key = NULL,
  predictors = NULL,
  outcomes = NULL,
  categorical = NULL,
  timesteps,
  horizon = 1,
  sample_frac = 1,
  scale = TRUE,
  batch_size,
  shuffle = FALSE,
  jump = 1,
  drop_last = TRUE,
  ...
)

Arguments

data

(data.frame) An input data.frame object. For now, only single data frames without categorical features are handled.

formula

(formula) A formula describing how to use the data.

index

(character) The index column name.

key

(character) The key column name(s). Use only if formula was not specified.

predictors

(character) Input variable names. Use only if formula was not specified.

outcomes

(character) Target variable names. Use only if formula was not specified.

categorical

(character) Categorical features.

timesteps

(integer) The time series chunk length.

horizon

(integer) Forecast horizon.

sample_frac

(numeric) Sample a fraction of rows (default: 1, i.e.: all the rows).

scale

(logical or list) Scale feature columns. Either a logical value or a two-element list with values (mean, std).

batch_size

(numeric) Batch size.

shuffle

(logical) Shuffle examples.

drop_last

(logical) Set to TRUE to drop the last incomplete batch if the dataset size is not divisible by the batch size. If FALSE, the last batch will be smaller. (default: TRUE)
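
As a rough sketch of how the windowing arguments interact (assuming jump = 1 and complete windows only), a series of n rows yields about n - timesteps - horizon + 1 windows, each pairing timesteps input steps with the following horizon target steps. The call below is a hypothetical equivalent of the formula interface, naming the columns explicitly instead; the list shape passed to scale is an assumption based on the (mean, std) description above, not a confirmed API.

# Hypothetical equivalent of as_ts_dataloader(data, temp ~ date, ...),
# spelling out the columns instead of using a formula
train_dl <-
  as_ts_dataloader(
    data       = training(data_split),
    index      = "date",                  # time index column
    outcomes   = "temp",                  # target variable
    timesteps  = 20,                      # length of each input window
    horizon    = 10,                      # number of future steps to predict
    batch_size = 32,
    scale      = list(mean = 0, std = 1)  # assumed (mean, std) list shape
  )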

Examples

library(torch)
library(rsample)
library(dplyr, warn.conflicts = FALSE)

suwalki_temp <-
   weather_pl %>%
   filter(station == "SWK") %>%
   select(date, temp = tmax_daily)

# Splitting on training and test
data_split <- initial_time_split(suwalki_temp)

train_dl <-
 training(data_split) %>%
 as_ts_dataloader(temp ~ date, timesteps = 20, horizon = 10, batch_size = 32)

train_dl
#> <dataloader>
#>   Public:
#>     .auto_collation: active binding
#>     .dataset_kind: map
#>     .has_getbatch: FALSE
#>     .index_sampler: active binding
#>     .iter: function () 
#>     .length: function () 
#>     batch_sampler: utils_sampler_batch, utils_sampler, R6
#>     batch_size: 32
#>     clone: function (deep = FALSE) 
#>     collate_fn: function (batch) 
#>     dataset: ts_dataset, dataset, R6
#>     drop_last: TRUE
#>     generator: NULL
#>     initialize: function (dataset, batch_size = 1, shuffle = FALSE, sampler = NULL, 
#>     multiprocessing_context: NULL
#>     num_workers: 0
#>     pin_memory: FALSE
#>     sampler: utils_sampler_sequential, utils_sampler, R6
#>     timeout: -1
#>     worker_globals: NULL
#>     worker_init_fn: NULL
#>     worker_packages: NULL

# Retrieve the first batch (requires the torch package to be attached)
dataloader_next(dataloader_make_iter(train_dl))
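
To consume every batch in a training loop, the usual torch-for-R idiom is coro::loop() over the dataloader. The sketch below assumes each batch is a list of torch tensors; the exact field names depend on the ts_dataset implementation.

library(coro)

# Iterate over all batches produced by the dataloader
coro::loop(for (batch in train_dl) {
  # `batch` is a list of torch tensors; inspect their shapes
  print(lapply(batch, dim))
})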