
Interpretable multi-head attention layer

Usage

layer_interpretable_mh_attention(
  object,
  state_size,
  num_heads,
  dropout_rate = 0,
  ...
)

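Arguments

object

What to compose the new layer with, typically a tensor such as one returned by layer_input(). If omitted, the layer instance is returned and can then be called on its inputs, as in the example below.

state_size

Size of the hidden state. In the example below it matches the last dimension of the query, key, and value inputs.

num_heads

Number of attention heads.

dropout_rate

Dropout rate.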

Examples

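library(keras)

lookback   <- 28                # past time steps
horizon    <- 14                # future time steps to forecast
all_steps  <- lookback + horizon
state_size <- 5

# Queries span the forecast horizon; keys and values span all time steps
queries <- layer_input(c(horizon, state_size))
keys    <- layer_input(c(all_steps, state_size))
values  <- layer_input(c(all_steps, state_size))

# Create the layer without `object`, then call it on (queries, keys, values)
imh_attention <-
  layer_interpretable_mh_attention(
    state_size = state_size, num_heads = 10
  )(queries, keys, values)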
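To make the computation concrete, here is a minimal base-R sketch of interpretable multi-head attention as defined for the Temporal Fusion Transformer: each head has its own query and key projections, all heads share one value projection, and the per-head attention weights are averaged before a final output projection. This is an assumption about what the layer computes, not this package's implementation. The function `interpretable_mha_sketch()` and its `d_attn` argument are hypothetical, the weights are random, and masking and dropout are omitted.

```r
# Hypothetical sketch (not this package's code): TFT-style interpretable
# multi-head attention on plain matrices, with random weights, no masking,
# and no dropout.

softmax_rows <- function(x) {
  # Subtract each row's maximum for numerical stability
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

interpretable_mha_sketch <- function(q, k, v, num_heads, d_attn) {
  d_model <- ncol(q)
  rand <- function(r, c) matrix(rnorm(r * c, sd = 0.1), r, c)
  w_v <- rand(d_model, d_attn)            # value projection shared by all heads
  w_h <- rand(d_attn, d_model)            # output projection
  v_proj <- v %*% w_v                     # (n_keys x d_attn)
  attn_sum <- matrix(0, nrow(q), nrow(k))
  for (h in seq_len(num_heads)) {
    w_q <- rand(d_model, d_attn)          # head-specific query projection
    w_k <- rand(d_model, d_attn)          # head-specific key projection
    scores <- (q %*% w_q) %*% t(k %*% w_k) / sqrt(d_attn)
    attn_sum <- attn_sum + softmax_rows(scores)
  }
  attn <- attn_sum / num_heads            # one averaged, interpretable weight matrix
  list(output = (attn %*% v_proj) %*% w_h, attention = attn)
}

# Shapes mirror the example above: 14 query steps, 42 key/value steps
horizon <- 14
all_steps <- 42
state_size <- 5
q  <- matrix(rnorm(horizon * state_size), horizon, state_size)
kv <- matrix(rnorm(all_steps * state_size), all_steps, state_size)
res <- interpretable_mha_sketch(q, kv, kv, num_heads = 10, d_attn = state_size)
dim(res$output)     # 14 x 5
dim(res$attention)  # 14 x 42
```

Because the heads share the value projection, the averaged `attention` matrix alone describes how each query step weights every time step, which is what makes the weights interpretable.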