Time series models with torch
You can install the released version of torchts from CRAN with `install.packages("torchts")`,
and the development version from GitHub with:
# devtools is needed for the GitHub install; uncomment if it is missing:
# install.packages("devtools")
devtools::install_github(repo = "krzjoa/torchts")
library(torchts)
library(torch)
library(rsample)
library(dplyr, warn.conflicts = FALSE)
library(parsnip)
library(timetk)
library(ggplot2)
# Daily maximum temperature series of the "TRN" station from `weather_pl`
# (Tarnow, judging by the variable name). Only the date and value are kept.
tarnow_temp <- weather_pl %>%
  filter(station == "TRN") %>%
  select(date, tmax_daily)
# Params ----
EPOCHS <- 3      # training epochs passed to rnn()
HORIZON <- 1     # forecast horizon passed to rnn()
TIMESTEPS <- 28  # input window length; also used as the split lag below

# Splitting on training and test ----
# The final 2 years form the assessment (test) set. `lag` inserts a lag of
# TIMESTEPS observations between the analysis and assessment sets, presumably
# so the first test window has a full TIMESTEPS of history.
data_split <- tarnow_temp %>%
  time_series_split(
    date,
    initial = "18 years",
    assess  = "2 years",
    lag     = TIMESTEPS
  )
# Training ----
# Fit an RNN forecaster from torchts on the training part of the split,
# running on the CPU. No engine is set explicitly, which is presumably why
# the "Engine set to `torchts`" warning is printed.
rnn_model <- rnn(
  timesteps    = TIMESTEPS,
  horizon      = HORIZON,
  epochs       = EPOCHS,
  learn_rate   = 0.01,
  hidden_units = 20,
  batch_size   = 32,
  scale        = TRUE
) %>%
  set_device("cpu") %>%
  fit(
    tmax_daily ~ date,
    data = training(data_split)
  )
#> Warning: Engine set to `torchts`.
#>
#> Training started
#> | train: 0.36100
#> | train: 0.30690
#> | train: 0.27990
# Forecast the held-out period with the fitted model
prediction <- predict(rnn_model, new_data = testing(data_split))

# Plot the forecast together with the test data (printed at top level)
plot_forecast(
  data     = testing(data_split),
  forecast = prediction,
  outcome  = tmax_daily
)
In the `as_tensor()` function we can specify the columns that are used to create a tensor out of the input `data.frame`.
The listed column names are only used to determine dimension sizes — they are removed afterwards and are not present in the final tensor.
# Keep the two index columns (station, date) and the value column
temperature_pl <- weather_pl %>%
  select(station, date, tmax_daily)

# Expected shape: number of stations x number of dates x 1 remaining
# value column (tmax_daily)
with(
  temperature_pl,
  c(n_distinct(station), n_distinct(date), 1)
)
#> [1] 2 7305 1
# `station` and `date` only determine the first two tensor dimensions;
# they are not part of the tensor values themselves
temperature_tensor <- as_tensor(temperature_pl, station, date)
dim(temperature_tensor)
#> [1] 2 7305 1

# First station along dim 1, first 10 dates along dim 2
temperature_tensor[1, 1:10]
#> torch_tensor
#> -0.2000
#> -1.4000
#> 0.4000
#> 1.0000
#> 0.6000
#> 3.0000
#> 4.0000
#> 1.0000
#> 1.2000
#> 1.4000
#> [ CPUFloatType{10,1} ]
# Cross-check against the raw data: the first 10 "SWK" values by date are the
# same numbers as the tensor slice `temperature_tensor[1, 1:10]`
temperature_pl %>%
  filter(station == "SWK") %>%
  arrange(date) %>%
  head(n = 10)
#> station date tmax_daily
#> 1140 SWK 2001-01-01 -0.2
#> 1230 SWK 2001-01-02 -1.4
#> 2330 SWK 2001-01-03 0.4
#> 2630 SWK 2001-01-04 1.0
#> 2730 SWK 2001-01-05 0.6
#> 2830 SWK 2001-01-06 3.0
#> 2930 SWK 2001-01-07 4.0
#> 3030 SWK 2001-01-08 1.0
#> 3130 SWK 2001-01-09 1.2
#> 2140 SWK 2001-01-10 1.4