New features, such as categorical feature handling, will be added in the near future.
model_rnn(
rnn_layer = nn_gru,
input_size,
output_size,
hidden_size,
horizon = 1,
embedding = NULL,
initial_layer = nn_nonlinear,
last_timesteps = 1,
final_layer = nn_linear,
dropout = 0,
batch_first = TRUE
)
rnn_layer (nn_rnn_base) A recurrent torch layer.
input_size (integer) Input size.
output_size (integer) Output size (number of target variables).
hidden_size (integer) Size of the recurrent hidden layer.
horizon (integer) Horizon size: how many steps ahead to predict from the last n steps.
embedding (embedding_spec) List with two values: num_embeddings and embedding_dim.
initial_layer (nn_module) A torch module to preprocess numeric features before the recurrent layer.
final_layer (nn_module) If not NULL, applied instead of the default linear layer.
dropout (logical) Use dropout.
batch_first (logical) Channel order.
library(dplyr, warn.conflicts = FALSE)
library(torch)
library(torchts)
# Preparing data
# Missing precipitation types are encoded as an explicit "NA" category so the
# categorical embedding receives a regular level instead of a missing value.
weather_data <-
  weather_pl %>%
  filter(station == "TRN") %>%
  select(date, tmax_daily, rr_type) %>%
  mutate(rr_type = coalesce(rr_type, "NA"))

weather_dl <-
  weather_data %>%
  as_ts_dataloader(
    tmax_daily ~ date + tmax_daily + rr_type,
    timesteps = 30,
    categorical = "rr_type",
    batch_size = 32
  )
#> Categorical variables found (1): rr_type

unique(weather_data$rr_type)
#> [1] "" "W" "S" "NA"

# One embedding row per distinct rr_type level. The embedding dimension is
# computed once so the embedding spec and the model input size always agree.
n_unique_values <- n_distinct(weather_data$rr_type)
embedding_dim <- embedding_size_google(n_unique_values)

.embedding_spec <-
  embedding_spec(
    num_embeddings = n_unique_values,
    embedding_dim = embedding_dim
  )

# Input features: tmax_daily (numeric) + the rr_type embedding columns
input_size <- 1 + embedding_dim

# Creating a model
rnn_net <-
  model_rnn(
    input_size = input_size,
    output_size = 1, # one target variable: tmax_daily
    hidden_size = 10,
    horizon = 10,
    embedding = .embedding_spec
  )

print(rnn_net)

# Prediction example on non-trained neural network
batch <-
  dataloader_next(dataloader_make_iter(weather_dl))

rnn_net(batch$x_num, batch$x_cat)