It is especially useful for dealing with multiple categorical features.

nn_multi_embedding(
  num_embeddings,
  embedding_dim,
  padding_idx = NULL,
  max_norm = NULL,
  norm_type = 2,
  scale_grad_by_freq = FALSE,
  sparse = FALSE,
  .weight = NULL
)

Arguments

num_embeddings

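(integer) Size of the dictionary of embeddings. When embedding several categorical features, pass one size per feature (the example below passes the vector returned by `dict_size()`).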

embedding_dim

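(integer) The size of each embedding vector. When embedding several categorical features, pass one size per feature; the expected output size is then `sum(embedding_dim)` (see the example below).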

padding_idx

(integer, optional) If given, pads the output with the embedding vector at padding_idx (initialized to zeros) whenever it encounters the index.

max_norm

(numeric, optional) If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm.

norm_type

(numeric, optional) The p of the p-norm to compute for the max_norm option. Default 2.

scale_grad_by_freq

(logical, optional) If TRUE, gradients are scaled by the inverse of the frequency of the words in the mini-batch. Default FALSE.

sparse

(logical, optional) If TRUE, the gradient with respect to the weight matrix will be a sparse tensor.

.weight

(torch_tensor or list of torch_tensor) Embedding weights (in case you want to set them manually).

Examples

library(recipes)

data("gss_cat", package = "forcats")

# Integer-encode every column, drop rows containing NA, then store all
# numeric columns as integers (presumably because the embedding layer
# expects integer indices -- confirm against nn_multi_embedding()).
gss_cat_transformed <-
  recipe(gss_cat) %>%
  step_integer(everything()) %>%
  prep() %>%
  juice() %>%
  na.omit() %>%
  mutate(across(where(is.numeric), as.integer))

glimpse(gss_cat_transformed)
#> Rows: 11,299
#> Columns: 9
#> $ year    <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
#> $ marital <int> 2, 5, 2, 4, 2, 6, 6, 6, 6, 4, 6, 2, 6, 6, 6, 2, 5, 5, 5, 4, 5,…
#> $ age     <int> 9, 50, 22, 8, 19, 27, 30, 36, 35, 35, 23, 27, 23, 31, 32, 2, 3…
#> $ race    <int> 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,…
#> $ rincome <int> 8, 16, 16, 16, 4, 4, 4, 4, 4, 4, 4, 4, 7, 4, 3, 16, 4, 16, 16,…
#> $ partyid <int> 6, 7, 6, 9, 5, 9, 4, 9, 10, 8, 10, 7, 9, 8, 4, 7, 6, 9, 9, 10,…
#> $ relig   <int> 15, 15, 6, 12, 5, 15, 15, 15, 15, 12, 15, 12, 14, 14, 15, 12, …
#> $ denom   <int> 25, 3, 30, 30, 30, 4, 25, 4, 25, 30, 23, 30, 30, 30, 20, 30, 3…
#> $ tvhours <int> 13, 3, 5, 2, 4, 1, 4, 3, 2, 2, 8, 4, 4, 2, 3, 3, 2, 4, 5, 8, 4…

# Convert the encoded data to a tensor and compute one dictionary size
# per column (printed below).
gss_cat_tensor <- as_tensor(gss_cat_transformed)
.dict_size <- dict_size(gss_cat_transformed)
.dict_size
#>    year marital     age    race rincome partyid   relig   denom tvhours 
#>       8       6      72       3      16      10      15      30      24 

# Derive an embedding size for each column from its dictionary size.
.embedding_size <- embedding_size_google(.dict_size)

# NOTE(review): the recorded output below shows `nn_multi_embedding` could
# not be found when these examples were rendered, although `dict_size()` and
# `embedding_size_google()` ran above -- confirm it is exported (@export).
embedding_module <-
  nn_multi_embedding(.dict_size, .embedding_size)
#> Error in nn_multi_embedding(.dict_size, .embedding_size): could not find function "nn_multi_embedding"

# Expected output size: the sum of the per-column embedding sizes
sum(.embedding_size)
#> [1] 21

embedding_module(gss_cat_tensor)
#> Error in embedding_module(gss_cat_tensor): could not find function "embedding_module"