It is especially useful for dealing with multiple categorical features.
# Constructor signature with its default argument values
nn_multi_embedding(
  num_embeddings,
  embedding_dim,
  padding_idx = NULL,
  max_norm = NULL,
  norm_type = 2,
  scale_grad_by_freq = FALSE,
  sparse = FALSE,
  .weight = NULL
)
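num_embeddings: (integer) Size of the dictionary of embeddings; with several categorical features, one size per feature (e.g. the output of dict_size() in the example below).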
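embedding_dim: (integer) The size of each embedding vector; with several categorical features, one size per feature (e.g. the output of embedding_size_google() in the example below).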
padding_idx: (integer, optional) If given, pads the output with the embedding vector at padding_idx (initialized to zeros) whenever it encounters the index.
max_norm: (numeric, optional) If given, each embedding vector with a norm larger than max_norm is renormalized to have norm max_norm.
norm_type: (numeric, optional) The p of the p-norm to compute for the max_norm option. Default 2.
scale_grad_by_freq: (logical, optional) If TRUE, gradients are scaled by the inverse of the frequency of the words in the mini-batch. Default FALSE.
sparse: (logical, optional) If TRUE, the gradient w.r.t. the weight matrix will be a sparse tensor. Default FALSE.
.weight: (torch_tensor or list of torch_tensor) Embedding weights (in case you want to set them manually).
library(recipes)
# mutate(), across(), where() and glimpse() come from dplyr; attach it
# explicitly instead of relying on another package to attach it.
library(dplyr)

# Integer-encode every column of forcats' gss_cat data set, drop incomplete
# rows, and store every numeric column as integer.
data("gss_cat", package = "forcats")
gss_cat_transformed <-
  recipe(gss_cat) %>%
  step_integer(everything()) %>%
  prep() %>%
  juice() %>%
  na.omit() %>%
  mutate(across(where(is.numeric), as.integer))
glimpse(gss_cat_transformed)
#> Rows: 11,299
#> Columns: 9
#> $ year <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
#> $ marital <int> 2, 5, 2, 4, 2, 6, 6, 6, 6, 4, 6, 2, 6, 6, 6, 2, 5, 5, 5, 4, 5,…
#> $ age <int> 9, 50, 22, 8, 19, 27, 30, 36, 35, 35, 23, 27, 23, 31, 32, 2, 3…
#> $ race <int> 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,…
#> $ rincome <int> 8, 16, 16, 16, 4, 4, 4, 4, 4, 4, 4, 4, 7, 4, 3, 16, 4, 16, 16,…
#> $ partyid <int> 6, 7, 6, 9, 5, 9, 4, 9, 10, 8, 10, 7, 9, 8, 4, 7, 6, 9, 9, 10,…
#> $ relig <int> 15, 15, 6, 12, 5, 15, 15, 15, 15, 12, 15, 12, 14, 14, 15, 12, …
#> $ denom <int> 25, 3, 30, 30, 30, 4, 25, 4, 25, 30, 23, 30, 30, 30, 20, 30, 3…
#> $ tvhours <int> 13, 3, 5, 2, 4, 1, 4, 3, 2, 2, 8, 4, 4, 2, 3, 3, 2, 4, 5, 8, 4…

# Convert the encoded data to a tensor, then size the embeddings:
# dict_size() reports the dictionary size of each column, and
# embedding_size_google() derives the embedding widths from those sizes.
gss_cat_tensor <- as_tensor(gss_cat_transformed)
dict_sizes <- dict_size(gss_cat_transformed)
dict_sizes
#> year marital age race rincome partyid relig denom tvhours
#> 8 6 72 3 16 10 15 30 24
embedding_sizes <- embedding_size_google(dict_sizes)

embedding_module <-
  nn_multi_embedding(dict_sizes, embedding_sizes)

# Expected number of output columns: the sum of all embedding widths
sum(embedding_sizes)
#> [1] 21

# NOTE(review): presumably returns one row per observation with
# sum(embedding_sizes) columns; confirm by re-rendering this example.
embedding_module(gss_cat_tensor)