Bahdanau attention

In an earlier post, I wrote about seq2seq without attention as a way of introducing the idea. This time, we extend that setup by adding attention. In the regular seq2seq model, we encode the input sequence x = {x_1, x_2, ..., x_T} into a single context vector c, which is then used to make predictions. In the attention variant, the fixed context c is replaced by a customized context c_i for each decoder hidden vector s_{i-1}, computed as a weighted sum of contributions from all of the input hidden vectors. Attention matters for generalization: the model might learn to minimize the cost function on the training data regardless, but it is only when it learns a sensible alignment that we know it has figured out where to look in the input (and how to put that information into the context) so that it also does well on test data.

c_i = \sum_j \alpha_{ij} h_j
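As a concrete PyTorch sketch of this sum (shapes and names are purely illustrative; the weights \alpha_{ij} are defined further below), the whole batch can be handled with a single torch.bmm:

    import torch

    batch, src_len, hidden = 4, 10, 256                        # illustrative sizes
    h = torch.randn(batch, src_len, hidden)                    # encoder hidden states h_j
    alpha = torch.softmax(torch.randn(batch, src_len), dim=1)  # attention weights for one decoder step

    # c_i = sum_j alpha_ij h_j, batched as a (1 x src_len) @ (src_len x hidden) product
    c = torch.bmm(alpha.unsqueeze(1), h).squeeze(1)            # (batch, hidden)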

This operation weights the input hidden vectors h_j (the encoder could be bidirectional, in which case we concatenate the forward and backward hidden states). Naturally, if an input and an output hidden state are ‘aligned’, then \alpha_{ij} should be high for that pair. In practice, more than one input word can be aligned with a given output word. For example, some words have a direct one-to-one correspondence between French and English (de == of; le, la == the), but others map as groups (je ne suis pas == I am not), so the alignment should capture this. Here (je/I) form a pair and (suis/am) form another, but (ne suis pas/am not) belong together, so several input words will get non-zero alphas for the same output word, and the two sequences need not even have the same length.

The attention/alignment weights \alpha_{ij} are computed by first taking a non-linear function of the hidden states, which yields a score a_{ij}, and then applying a softmax over j so that the weights lie between 0 and 1 and sum to one.

a_{ij} = f(s_{i-1}, h_j) = v' \tanh(W_1 s_{i-1} + W_2 h_j + b)

\alpha_{ij} = softmax_j(a_{ij}) = \frac{\exp(a_{ij})}{\sum_k \exp(a_{ik})}
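Putting the score and the softmax together, here is a minimal PyTorch sketch of the additive form; the module and dimension names (AdditiveAttention, dec_dim, enc_dim, attn_dim) are my own, not from the paper:

    import torch
    import torch.nn as nn

    class AdditiveAttention(nn.Module):
        """Computes alpha_ij = softmax_j( v' tanh(W1 s_{i-1} + W2 h_j + b) )."""
        def __init__(self, dec_dim, enc_dim, attn_dim):
            super().__init__()
            self.W1 = nn.Linear(dec_dim, attn_dim, bias=False)
            self.W2 = nn.Linear(enc_dim, attn_dim, bias=True)  # this bias plays the role of b
            self.v = nn.Linear(attn_dim, 1, bias=False)

        def forward(self, s_prev, h):
            # s_prev: (batch, dec_dim), h: (batch, src_len, enc_dim)
            scores = self.v(torch.tanh(self.W1(s_prev).unsqueeze(1) + self.W2(h)))  # (batch, src_len, 1)
            return torch.softmax(scores.squeeze(-1), dim=1)                         # (batch, src_len)

The weights it returns plug directly into the weighted sum sketched earlier to give c_i.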

Once we have computed the context c_i, we can use it to compute the next decoder hidden state

s_i = F(s_{i-1}, c_i, y_i)

Alternative forms

Since attention/alignment is essentially a similarity measure between a decoder and an encoder hidden vector, we can use a dot product to compute it.

a_{ij} = f(s_{i-1}, h_j) = <h_j, s_{i-1}>

where <a,b> is the dot product between vectors a,b.
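A minimal sketch of the dot-product score in PyTorch, assuming the encoder and decoder hidden sizes match (sizes are placeholders):

    import torch

    batch, src_len, hidden = 4, 10, 256
    h = torch.randn(batch, src_len, hidden)    # encoder states h_j
    s_prev = torch.randn(batch, hidden)        # decoder state s_{i-1}

    # a_ij = <h_j, s_{i-1}> for all j at once, then softmax over j
    scores = torch.bmm(h, s_prev.unsqueeze(2)).squeeze(2)   # (batch, src_len)
    alpha = torch.softmax(scores, dim=1)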

A development of this idea (Luong’s multiplicative attention) is to transform the decoder vector before taking the dot product.

a_{ij} = <h_j, F s_{i-1}>

The form F s_{i-1} means we apply a linear transformation (without a bias term) to the decoder hidden state and then take the dot product (which in torch would go through torch.bmm() for batched quantities).
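A sketch of this multiplicative (‘general’) score along those lines; here F is just a bias-free nn.Linear, and the names and sizes are placeholders:

    import torch
    import torch.nn as nn

    batch, src_len, enc_dim, dec_dim = 4, 10, 256, 256
    h = torch.randn(batch, src_len, enc_dim)   # encoder states h_j
    s_prev = torch.randn(batch, dec_dim)       # decoder state s_{i-1}

    F_transform = nn.Linear(dec_dim, enc_dim, bias=False)   # the matrix F, no bias term

    # a_ij = <h_j, F s_{i-1}>, batched with torch.bmm
    scores = torch.bmm(h, F_transform(s_prev).unsqueeze(2)).squeeze(2)  # (batch, src_len)
    alpha = torch.softmax(scores, dim=1)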

[Figure: the score functions, from Luong’s paper.]

[Figure: multiplicative attention with components for location and content, from Luong’s paper.]

Computing the hidden decoder state

As we can make out from the equations above, we would like to formulate the decoder state as an RNN. One way of doing that is to concatenate the input with the context and feed that to the RNN to compute the next hidden state.

s_i = RNN(s_{i-1}, [c_i, y_i])
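A sketch of this step with a GRUCell, where y_i is the embedded input at step i and c_i is the context from the attention module; the dimensions are hypothetical:

    import torch
    import torch.nn as nn

    emb_dim, ctx_dim, dec_dim = 128, 256, 256
    attn_rnn = nn.GRUCell(ctx_dim + emb_dim, dec_dim)

    def attention_rnn_step(s_prev, c_i, y_i):
        # s_i = RNN(s_{i-1}, [c_i, y_i]): concatenate context and input, then take one RNN step
        return attn_rnn(torch.cat([c_i, y_i], dim=1), s_prev)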

For example, in the Tacotron code we have a multilayer decoder stack, with the first layer being the so-called attention RNN, which is exactly what we have above. But there are two more layers after this with residual connections.

s_i^1 = RNN^1(s_{i-1}^1, s_i) + s_i

s_i^2 = RNN^2(s_{i-1}^2, s_i^1) + s_i^1
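A rough sketch of such a stack (not the actual Tacotron code, just the two recurrences above with GRU cells of matching size so the residual additions are valid):

    import torch
    import torch.nn as nn

    dec_dim = 256
    rnn1 = nn.GRUCell(dec_dim, dec_dim)
    rnn2 = nn.GRUCell(dec_dim, dec_dim)

    def residual_decoder_step(s_i, s1_prev, s2_prev):
        s1 = rnn1(s_i, s1_prev) + s_i    # s_i^1 = RNN^1(s_{i-1}^1, s_i) + s_i
        s2 = rnn2(s1, s2_prev) + s1      # s_i^2 = RNN^2(s_{i-1}^2, s_i^1) + s_i^1
        return s1, s2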

Annotations and bidirectionality

The Bahdanau paper uses a bidirectional RNN for the encoder. This computes hidden units for the sequence with the normal ordering (left to right) and reversed ordering (right to left) (see CS224D notes by Richard Socher).

\overrightarrow{h}_i = f(\overrightarrow{W} x_i + \overrightarrow{V} \overrightarrow{h}_{i-1} + \overrightarrow{b}) \\ \overleftarrow{h}_i = f(\overleftarrow{W} x_i + \overleftarrow{V} \overleftarrow{h}_{i+1} + \overleftarrow{b})

The so-called ‘annotations’ h_j are the concatenation of the forward and backward hidden vectors, and these are what we use to compute the context vectors.

h_j = concat(\overrightarrow{h}_j, \overleftarrow{h}_j) = [\overrightarrow{h}_j, \overleftarrow{h}_j] \\ c_i = \sum_j \alpha_{ij} h_j
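In PyTorch, a bidirectional nn.GRU already returns the forward and backward states concatenated along the feature dimension, so the annotations come out directly; the sizes below are placeholders:

    import torch
    import torch.nn as nn

    vocab_size, emb_dim, enc_dim = 1000, 128, 256
    embedding = nn.Embedding(vocab_size, emb_dim)
    encoder = nn.GRU(emb_dim, enc_dim, batch_first=True, bidirectional=True)

    x = torch.randint(0, vocab_size, (4, 10))   # (batch, src_len) token ids
    h, _ = encoder(embedding(x))                # (batch, src_len, 2 * enc_dim)
    # h[:, j, :] is the annotation h_j = [forward h_j ; backward h_j]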


References

1. Cho et al. 2014 (https://arxiv.org/abs/1406.1078)
2. Bahdanau et al. 2014 (https://arxiv.org/abs/1409.0473)
3. Vinyals et al. 2014 (https://arxiv.org/abs/1412.7449)
4. Sutskever et al. 2014 (https://arxiv.org/abs/1409.3215)
5. Goodfellow et al. 2013 (https://arxiv.org/abs/1302.4389)
6. Tacotron: https://google.github.io/tacotron/
7. Luong et al. 2015 (https://arxiv.org/abs/1508.04025)
8. CS224D: https://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf

 
