Module 12 - Attention and Transformers

Attention with RNNs

The first attention mechanism was proposed in Neural Machine Translation by Jointly Learning to Align and Translate by Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio (presented at ICLR 2015).

The task considered is English-to-French translation. The attention mechanism extends a seq2seq architecture by adding a context vector $c_i$ to the RNN decoder, so that the hidden states of the decoder are computed recursively as $s_i = f(s_{i-1}, y_{i-1}, c_i)$, where $y_{i-1}$ is the previously predicted token. Predictions are made in a probabilistic manner as $y_i \sim g(y_{i-1}, s_i, c_i)$, where $s_i$ and $c_i$ are the current hidden state and context of the decoder.

The main novelty is the introduction of the context $c_i$, which is a weighted average of all the hidden states of the encoder: $c_i = \sum_{j=1}^T \alpha_{i,j} h_j$, where $T$ is the length of the input sequence, $h_1, \dots, h_T$ are the corresponding hidden states of the encoder, and $\sum_j \alpha_{i,j} = 1$. Hence the context passes information directly from the 'relevant' part of the input to the decoder. The coefficients $(\alpha_{i,j})_{j=1}^T$ are computed from the previous hidden state of the decoder $s_{i-1}$ and all the hidden states of the encoder $(h_1, \dots, h_T)$, as explained below (taken from the original paper):

PyTorch implementation

In Attention for seq2seq, you can play with a simple model and code the attention mechanism proposed in the paper. For the alignment network $a$ (used to define the coefficients $\alpha_{i,j} = \text{softmax}_j(a(s_{i-1}, h_j))$), we take an MLP with $\tanh$ activations.

You will learn about seq2seq and teacher forcing for RNNs, and build the attention mechanism. To simplify things, we do not deal with batches (see Batches with sequences in Pytorch for more on that). The solution for this practical is provided in Attention for seq2seq- solution.
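As a sketch of the mechanism described above, the NumPy code below computes the weights $\alpha_{i,j}$ and the context $c_i$ for a single decoder step. It assumes the additive alignment network of the original paper, $a(s_{i-1}, h_j) = v^T \tanh(W s_{i-1} + U h_j)$, i.e. an MLP with one $\tanh$ hidden layer and a scalar output. The names `W`, `U`, `v`, `alignment_scores`, `context`, the sizes and the random weights are illustrative assumptions; the practical's PyTorch code may organize the MLP differently.

```python
import numpy as np

def alignment_scores(s_prev, H, W, U, v):
    """Additive alignment network a(s_{i-1}, h_j) = v^T tanh(W s_{i-1} + U h_j), for all j at once.
    s_prev: (n_dec,), H: (T, n_enc) with one encoder state h_j per row. Returns (T,)."""
    hidden = np.tanh(W @ s_prev + H @ U.T)   # (T, n_att); W @ s_prev is broadcast over j
    return hidden @ v

def context(s_prev, H, W, U, v):
    e = alignment_scores(s_prev, H, W, U, v)
    alpha = np.exp(e - e.max())              # softmax over j, shifted for numerical stability
    alpha /= alpha.sum()
    return alpha @ H, alpha                  # c_i = sum_j alpha_{i,j} h_j

rng = np.random.default_rng(0)
n_dec, n_enc, n_att, T = 4, 6, 5, 7          # decoder/encoder state sizes, MLP hidden size, input length
W = rng.normal(size=(n_att, n_dec))          # hypothetical parameters of the alignment MLP
U = rng.normal(size=(n_att, n_enc))
v = rng.normal(size=n_att)

s_prev = rng.normal(size=n_dec)              # previous decoder state s_{i-1}
H = rng.normal(size=(T, n_enc))              # encoder hidden states h_1, ..., h_T
c, alpha = context(s_prev, H, W, U, v)
print(alpha.round(3), alpha.sum())           # a probability vector over the T input positions
```

Since $U h_j$ does not depend on the decoder step $i$, it can be computed once per input sequence and reused at every step.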

Note that each $\alpha_{i,j}$ is a real number, so we can display the matrix of $\alpha_{i,j}$'s, where $j$ ranges over the input tokens and $i$ over the output tokens, see below (taken from the paper):

(Self-)Attention in Transformers

We now describe the attention mechanism proposed in Attention Is All You Need by Vaswani et al. First, we recall the query/key/value terminology of retrieval systems, illustrated by an example: searching for videos on YouTube. In this example, the query is the text in the search bar, the keys are the metadata associated with the videos, and the videos themselves are the values. A score is computed between the query and each key, and the video whose key has the highest score is returned.

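We can formalize a soft version of this process, in which the single best match is replaced by an average of all the values, weighted by a softmax of the scores: if $Q_s$ is the current query and $K_t$ and $V_t$ are all the keys and values in the database, we return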

$$Y_s = \sum_{t=1}^T \text{softmax}_t(\text{score}(Q_s, K_t))\, V_t,$$

where $\sum_{t=1}^T \text{softmax}_t(\text{score}(Q_s, K_t)) = 1$.
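To make the contrast with classical retrieval concrete, here is a small NumPy sketch of the formula for $Y_s$ next to a hard "return the best match" lookup. The score function is a parameter, so any alignment network could be plugged in; the function names `soft_retrieve`, `hard_retrieve`, `dot` and the toy database are hypothetical. When one score dominates, the softmax puts almost all the weight on that key and the soft lookup approaches the hard one.

```python
import numpy as np

def soft_retrieve(q, keys, values, score):
    """Y = sum_t softmax_t(score(q, K_t)) V_t for a single query q.
    keys: sequence of T keys; values: array of shape (T, d_v); score: any function (q, k) -> float."""
    s = np.array([score(q, key) for key in keys])
    w = np.exp(s - s.max())          # softmax over t, shifted for numerical stability
    w /= w.sum()
    return w @ values, w

def hard_retrieve(q, keys, values, score):
    # Classical retrieval: return the value whose key has the highest score.
    return values[int(np.argmax([score(q, key) for key in keys]))]

def dot(q, key):
    return float(q @ key)

# Toy database with 2-d keys and 3-d values (illustrative numbers only).
keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
values = np.eye(3)
q = np.array([10.0, 0.0])            # strongly matches the first key

y, w = soft_retrieve(q, keys, values, dot)
print(w.round(4))                    # nearly all the weight sits on the first key
print(hard_retrieve(q, keys, values, dot))
```

With a query whose scores are all equal, the soft lookup instead returns the plain average of the values; unlike the hard lookup, it is differentiable with respect to the query and the keys.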

Note that this formalism recovers the way contexts were computed above (where the score function was called the alignment network). Now, we change the score function and consider (scaled) dot-product attention: $\text{score}(Q_s, K_t) = \frac{Q_s^T K_t}{\sqrt{d}}$. For this definition to make sense, the query $Q_s$ and the key $K_t$ must live in the same space, and $d$ is the dimension of this space.
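Why divide by $\sqrt{d}$? If the entries of $Q_s$ and $K_t$ are independent with mean 0 and variance 1, then $Q_s^T K_t$ is a sum of $d$ such products and has variance $d$. For large $d$ the scores become large, pushing the softmax into regions where its gradients are very small, which is the motivation given by Vaswani et al.; dividing by $\sqrt{d}$ brings the variance back to 1. The quick simulation below checks this under the i.i.d. Gaussian assumption (an assumption made for illustration, not a property of trained models).

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 256, 5000

# n random (query, key) pairs with i.i.d. standard normal entries.
Q = rng.normal(size=(n, d))
K = rng.normal(size=(n, d))

raw = np.einsum("nd,nd->n", Q, K)     # Q_s^T K_t for each pair
scaled = raw / np.sqrt(d)             # score(Q_s, K_t) = Q_s^T K_t / sqrt(d)

print(raw.var() / d, scaled.var())    # both should be close to 1, i.e. Var(raw) is about d
```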

Given $s$ inputs in $\mathbb{R}^{d_{\text{in}}}$ denoted by a matrix $X \in \mathbb{R}^{d_{\text{in}} \times s}$ and a database containing $t$ samples in $\mathbb{R}^{d'}$ denoted by a matrix $X' \in \mathbb{R}^{d' \times t}$, we define:

$$\begin{aligned}
\text{the queries: } & Q = W_Q X, \text{ with } W_Q \in \mathbb{R}^{k \times d_{\text{in}}},\\
\text{the keys: } & K = W_K X', \text{ with } W_K \in \mathbb{R}^{k \times d'},\\
\text{the values: } & V = W_V X', \text{ with } W_V \in \mathbb{R}^{d_{\text{out}} \times d'}.
\end{aligned}$$
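A NumPy sketch of this general (cross-)attention with the column convention of the text, combining these definitions with the formula for $Y_s$ and the scaled dot-product score; since queries and keys live in $\mathbb{R}^k$, the scaling is $\sqrt{k}$. Entry $[t, s]$ of `K.T @ Q` is $Q_s^T K_t$, so the softmax over $t$ is taken along axis 0. The function names, sizes and random weights are illustrative assumptions.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(X, X_prime, W_Q, W_K, W_V):
    """Columns are tokens, as in the text.
    X: (d_in, s) inputs, X_prime: (d', t) database.
    W_Q: (k, d_in), W_K: (k, d'), W_V: (d_out, d'). Returns Y: (d_out, s)."""
    Q = W_Q @ X                       # (k, s)
    K = W_K @ X_prime                 # (k, t)
    V = W_V @ X_prime                 # (d_out, t)
    k = Q.shape[0]
    scores = K.T @ Q / np.sqrt(k)     # (t, s): entry [t, s] is Q_s^T K_t / sqrt(k)
    A = softmax(scores, axis=0)       # normalize over the database index t
    return V @ A                      # column s is sum_t A[t, s] V_t

rng = np.random.default_rng(0)
d_in, d_prime, d_out, k, s, t = 3, 5, 4, 2, 6, 8   # arbitrary sizes for the demo
X, X_prime = rng.normal(size=(d_in, s)), rng.normal(size=(d_prime, t))
W_Q = rng.normal(size=(k, d_in))
W_K = rng.normal(size=(k, d_prime))
W_V = rng.normal(size=(d_out, d_prime))
Y = attention(X, X_prime, W_Q, W_K, W_V)
print(Y.shape)   # (4, 6)
```

The output has one column per query and $d_{\text{out}}$ rows, whatever the size $t$ of the database.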

Self-attention is then simply obtained with $X = X'$ (so that $d' = d_{\text{in}}$) and $d_{\text{in}} = d_{\text{out}} = d$. In summary, a self-attention layer takes as input any matrix of the form $X \in \mathbb{R}^{d \times T}$ (for any $T$), has parameters

$$W_Q \in \mathbb{R}^{k \times d}, \quad W_K \in \mathbb{R}^{k \times d}, \quad W_V \in \mathbb{R}^{d \times d},$$

and produces $Y \in \mathbb{R}^{d \times T}$ (with the same $d$ and $T$ as the input). Here $d$ is the dimension of the input and $k$ is a hyperparameter of the self-attention layer:

$$Y_s = \sum_{t=1}^T \text{softmax}_t\left(\frac{X_s^T W_Q^T W_K X_t}{\sqrt{k}}\right) W_V X_t,$$

with the convention that $X_t \in \mathbb{R}^d$ (resp. $Y_s \in \mathbb{R}^d$) is the $t$-th column of $X$ (resp. the $s$-th column of $Y$). The notation $\text{softmax}_t(\cdot)$ might be a bit confusing. Recall that $\text{softmax}$ always takes a vector as input and returns a (normalized) vector; here the normalization runs over the index $t$. In practice, we are most of the time dealing with batches, so the $\text{softmax}$ function takes a matrix (or tensor) as input, and we need to normalize along the right axis! Named tensor notation (see below) deals with this notational issue. I also find the interpretation given below helpful:

Mental model for self-attention: self-attention interpreted as taking an expectation

$$\begin{aligned}
y_s &= \sum_{t=1}^T p(x_t \mid x_s)\, v(x_t) = \mathbb{E}[v(x) \mid x_s],\\
\text{with } p(x_t \mid x_s) &= \frac{\exp(q(x_s) k(x_t))}{\sum_{r} \exp(q(x_s) k(x_r))},
\end{aligned}$$

where the mappings $q(\cdot)$, $k(\cdot)$ and $v(\cdot)$ represent the query, key and value.
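The expectation reading can be checked numerically. The sketch below assumes linear maps $q$, $k$, $v$ with random weights (illustrative only, with $q(x_s) k(x_t)$ taken as a dot product and no scaling, as in the formula above), stores tokens as rows for readability, builds $p(\cdot \mid x_s)$ for one position $s$, and compares the exact weighted sum $y_s$ with a Monte Carlo estimate obtained by sampling $t \sim p(\cdot \mid x_s)$ and averaging $v(x_t)$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 3
x = rng.normal(size=(T, d))                # tokens x_1, ..., x_T, one per row

# Hypothetical linear query/key/value maps with random weights.
W_q = rng.normal(size=(d, d))
W_k = rng.normal(size=(d, d))
W_v = rng.normal(size=(d, d))
Qx, Kx, Vx = x @ W_q, x @ W_k, x @ W_v     # q(x_t), k(x_t), v(x_t) for every t (rows)

s = 0                                      # position of the query token
logits = Kx @ Qx[s]                        # q(x_s) . k(x_t) for t = 1..T
p = np.exp(logits - logits.max())
p /= p.sum()                               # p(x_t | x_s): a probability distribution over t

y_exact = p @ Vx                           # y_s = sum_t p(x_t | x_s) v(x_t)

# Monte Carlo estimate of E[v(x) | x_s]: sample indices t ~ p(. | x_s), average v(x_t).
samples = rng.choice(T, size=200_000, p=p)
y_mc = Vx[samples].mean(axis=0)
print(y_exact, y_mc)                       # the two vectors should be close
```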

Multi-head attention runs several such attention operations (heads) in parallel, each with its own parameters; the results are concatenated along the feature dimension, and one more linear transformation is applied to this concatenation to obtain $Y$.
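A minimal NumPy sketch of multi-head self-attention with tokens as columns, as in the formulas above. The per-head sizes $k = d_v = d/h$ follow the choice made in Attention Is All You Need; the output matrix `W_O`, the function names, and all sizes and random weights are illustrative assumptions. Concatenation happens along axis 0, which is the feature dimension in this column convention.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)     # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def head(X, W_Q, W_K, W_V):
    # One self-attention head, tokens as columns: X (d, T) -> (d_v, T).
    Q, K, V = W_Q @ X, W_K @ X, W_V @ X
    A = softmax(K.T @ Q / np.sqrt(W_Q.shape[0]), axis=0)   # normalize over t (axis 0)
    return V @ A

def multi_head(X, heads, W_O):
    """heads: list of (W_Q, W_K, W_V) triples, one per head.
    Head outputs are concatenated along the feature axis (axis 0), then mapped back by W_O."""
    Z = np.concatenate([head(X, *p) for p in heads], axis=0)   # (n_heads * d_v, T)
    return W_O @ Z                                              # (d, T)

rng = np.random.default_rng(0)
d, T, n_heads = 8, 5, 2
k = d_v = d // n_heads                 # per-head sizes, following d_k = d_v = d_model / h
heads = []
for _ in range(n_heads):
    heads.append((rng.normal(size=(k, d)), rng.normal(size=(k, d)), rng.normal(size=(d_v, d))))
W_O = rng.normal(size=(d, n_heads * d_v))   # final linear transformation (illustrative)
X = rng.normal(size=(d, T))
Y = multi_head(X, heads, W_O)
print(Y.shape)                         # (8, 5)
```

Equivalently, since `W_O` acts on the concatenation, $Y$ is the sum over heads of the matching column block of `W_O` applied to each head's output.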

Transformer block

To finish the description of a transformer block, we need to define the last two layers: Layer Norm and Feed Forward Network.

The Layer Norm used in the transformer block is particularly simple: it acts on each vector separately and standardizes it as follows. For $x \in \mathbb{R}^d$, we define

$$\begin{aligned}
\text{mean}(x) &= \frac{1}{d} \sum_{i=1}^d x_i \in \mathbb{R},\\
\text{std}(x)^2 &= \frac{1}{d} \sum_{i=1}^d (x_i - \text{mean}(x))^2 \in \mathbb{R},
\end{aligned}$$

and then the Layer Norm has two parameters $\gamma, \beta \in \mathbb{R}^d$ and

$$LN(x) = \gamma \cdot \frac{x - \text{mean}(x)}{\text{std}(x)} + \beta,$$

where we used the natural broadcasting rule for subtracting the mean and dividing by the std, and $\cdot$ is component-wise multiplication.
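A direct NumPy transcription of these formulas, applied along the last axis so that a $(T, d)$ matrix is normalized token by token. The small constant `eps` added to the variance is not part of the formula above; it is a common practical safeguard against dividing by zero for a constant vector (PyTorch's `nn.LayerNorm`, for instance, uses `eps=1e-5` by default). The function name `layer_norm` is illustrative.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """LN(x) = gamma * (x - mean(x)) / std(x) + beta, computed over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)   # 1/d normalization, as in std(x)^2 above
    # eps is not in the formula above: it guards against division by zero for constant x.
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

d = 4
gamma, beta = np.ones(d), np.zeros(d)
x = np.array([1.0, 2.0, 3.0, 6.0])
y = layer_norm(x, gamma, beta)
print(y.mean(), y.std())                    # close to 0 and 1

X = np.array([[1.0, 2.0, 3.0, 6.0],
              [10.0, 10.0, 10.0, 14.0]])    # two tokens with very different scales
print(layer_norm(X, gamma, beta).round(3))  # each row is standardized separately
```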

A Feed Forward Network is an MLP acting on vectors: for $x \in \mathbb{R}^d$, we define

$$FFN(x) = \max(0, x W_1 + b_1) W_2 + b_2,$$

where $x$ is treated as a row vector, $W_1 \in \mathbb{R}^{d \times h}$, $b_1 \in \mathbb{R}^h$, $W_2 \in \mathbb{R}^{h \times d}$ and $b_2 \in \mathbb{R}^d$.
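A sketch of the FFN under the row-vector convention of the formula ($x W_1$): a stack of $T$ tokens is a $(T, d)$ matrix with one token per row, the opposite of the column convention used for attention above. The hidden size `h = 4 * d` mirrors the ratio of the original Transformer ($d = 512$, $h = 2048$) but is only an illustrative choice, as are the random weights. The final check confirms that applying the FFN to the whole matrix is the same as applying it to each token separately.

```python
import numpy as np

def ffn(x, W1, b1, W2, b2):
    # FFN(x) = max(0, x W1 + b1) W2 + b2; x is a row vector or a stack of rows (one token per row).
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
d, T = 8, 5
h = 4 * d                                   # illustrative hidden size
W1, b1 = rng.normal(size=(d, h)), rng.normal(size=h)
W2, b2 = rng.normal(size=(h, d)), rng.normal(size=d)

X = rng.normal(size=(T, d))                 # T tokens, one per row
Y = ffn(X, W1, b1, W2, b2)

# Applying the FFN token by token gives the same result: there is no mixing across tokens.
Y_loop = np.zeros_like(Y)
for t in range(T):
    Y_loop[t] = ffn(X[t], W1, b1, W2, b2)
print(np.allclose(Y, Y_loop))               # True
```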

Each of these layers is applied separately to each of the inputs of the transformer block, as depicted below:

Note that this block is permutation equivariant: if we permute the inputs, the outputs are permuted with the same permutation. As a result, the transformer block is blind to the order of its inputs and cannot make use of it. The important notion of positional encoding allows us to take order into account: a deterministic encoding, unique to each time step, is added to the input tokens.
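The sketch below assembles a single-head transformer block in NumPy and checks the equivariance property numerically. It assumes the post-norm arrangement with residual connections of the original Transformer, $Z = LN(X + \text{SelfAttention}(X))$ followed by $LN(Z + FFN(Z))$, which may differ from the figure (pre-norm variants are also common); $\gamma = 1$, $\beta = 0$ and the FFN biases are dropped for brevity, and tokens are stored as rows, matching the $x W_1$ convention of the FFN. For positional encoding it uses the sinusoidal encoding of Vaswani et al.; learned position embeddings are another common choice. All sizes and random weights are illustrative.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(X, eps=1e-5):
    # gamma = 1, beta = 0 for brevity; normalizes each row (token) over its d features.
    return (X - X.mean(axis=-1, keepdims=True)) / np.sqrt(X.var(axis=-1, keepdims=True) + eps)

def block(X, W_Q, W_K, W_V, W1, W2):
    # Tokens as rows: X is (T, d). Row s of A @ V is sum_t softmax_t(Q_s . K_t / sqrt(k)) V_t.
    Q, K, V = X @ W_Q.T, X @ W_K.T, X @ W_V.T
    A = softmax(Q @ K.T / np.sqrt(W_Q.shape[0]), axis=1)
    Z = layer_norm(X + A @ V)                             # attention sub-layer + residual + LN
    return layer_norm(Z + np.maximum(0.0, Z @ W1) @ W2)   # FFN sub-layer + residual + LN

def sinusoidal_encoding(T, d):
    # PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d)); d assumed even.
    pos = np.arange(T)[:, None]
    freq = 1.0 / 10000 ** (np.arange(0, d, 2) / d)
    PE = np.zeros((T, d))
    PE[:, 0::2] = np.sin(pos * freq)
    PE[:, 1::2] = np.cos(pos * freq)
    return PE

rng = np.random.default_rng(0)
T, d, k, h = 6, 8, 4, 32
params = (rng.normal(size=(k, d)), rng.normal(size=(k, d)), rng.normal(size=(d, d)),
          rng.normal(size=(d, h)), rng.normal(size=(h, d)))

X = rng.normal(size=(T, d))
perm = np.arange(T)[::-1]                 # reverse the order of the tokens

# Equivariance: permuting the input rows permutes the output rows in the same way.
print(np.allclose(block(X[perm], *params), block(X, *params)[perm]))   # True

# Once positional encodings are added, reordering the tokens changes more than the output order.
PE = sinusoidal_encoding(T, d)
print(np.allclose(block(X[perm] + PE, *params), block(X + PE, *params)[perm]))
```

The second comparison is expected to fail for generic inputs: each position now carries its own encoding, so the block can tell the two orderings apart.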

LLM Visualization

Have a look at Brendan Bycroft’s beautifully crafted interactive explanation of the transformer architecture.


Transformers using Named Tensor Notation

In Transformers using Named Tensor Notation, we derive the formal equations for the Transformer block using named tensor notation.
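NumPy has no named tensors, but a lightweight way to get some of the same benefit is to spell out every axis in `np.einsum` subscripts and to pass the normalization axis explicitly. The batched self-attention sketch below uses the axis letters `b` (batch), `d`/`e` (features), `k` (query/key features), `s` (queries) and `t` (keys), with tokens as columns as in the formulas above, and normalizes over `t`. The function name, the `norm_axis` argument and all sizes are illustrative assumptions; the last lines show that normalizing over `s` instead silently produces a different result.

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def batched_self_attention(X, W_Q, W_K, W_V, norm_axis=1):
    """X: (b, d, t) batch of sequences, tokens as columns.
    W_Q, W_K: (k, d); W_V: (d, d). Returns Y: (b, d, s) with s = t, and the weights A: (b, t, s)."""
    Q = np.einsum("kd,bds->bks", W_Q, X)     # queries, indexed by s
    K = np.einsum("kd,bdt->bkt", W_K, X)     # keys, indexed by t
    V = np.einsum("ed,bdt->bet", W_V, X)     # values, indexed by t
    scores = np.einsum("bks,bkt->bts", Q, K) / np.sqrt(W_Q.shape[0])   # axes (b, t, s)
    A = softmax(scores, axis=norm_axis)      # axis 1 is t: the softmax_t of the text
    return np.einsum("bet,bts->bes", V, A), A

rng = np.random.default_rng(0)
B, d, k, T = 2, 4, 3, 5
X = rng.normal(size=(B, d, T))
W_Q, W_K, W_V = rng.normal(size=(k, d)), rng.normal(size=(k, d)), rng.normal(size=(d, d))

Y, A = batched_self_attention(X, W_Q, W_K, W_V)
print(np.allclose(A.sum(axis=1), 1.0))       # True: weights sum to 1 over t
Y_wrong, _ = batched_self_attention(X, W_Q, W_K, W_V, norm_axis=2)   # normalizes over s instead
print(np.allclose(Y, Y_wrong))
```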

Hacking a simple Transformer block

Now is the time to have fun building a simple transformer block and to think like transformers (open in Colab).