$Q_s^\top K_t$. Note that for this definition to make sense, both the query $Q_s$ and the key $K_t$ need to live in the same space, and $d$ is the dimension of this space.
Given $s$ inputs in $\mathbb{R}^{d_{in}}$ denoted by a matrix $X\in\mathbb{R}^{d_{in}\times s}$ and a database containing $t$ samples in $\mathbb{R}^{d'}$ denoted by a matrix $X'\in\mathbb{R}^{d'\times t}$, we define:
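A sketch of the general attention map being defined here (assuming, as in the self-attention case below, a key dimension $k$ and an output dimension $d_{out}$, and indexing the inputs by $i$ and the database by $j$):

$$W_Q\in\mathbb{R}^{k\times d_{in}},\quad W_K\in\mathbb{R}^{k\times d'},\quad W_V\in\mathbb{R}^{d_{out}\times d'},\qquad Y_i=\sum_{j=1}^{t}\mathrm{softmax}_j\!\left(\frac{X_i^\top W_Q^\top W_K X'_j}{\sqrt{k}}\right)W_V X'_j,$$

so that the output is $Y\in\mathbb{R}^{d_{out}\times s}$.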
Now self-attention is simply obtained with $X=X'$ (so that $d'=d_{in}$) and $d_{in}=d_{out}=d$. In summary, a self-attention layer takes as input any tensor of the form $X\in\mathbb{R}^{d\times T}$ (for any $T$), has parameters:
$$W_Q\in\mathbb{R}^{k\times d},\quad W_K\in\mathbb{R}^{k\times d},\quad W_V\in\mathbb{R}^{d\times d},$$
and produces $Y\in\mathbb{R}^{d\times T}$ (with the same $d$ and $T$ as the input). Here $d$ is the dimension of the input and $k$ is a hyper-parameter of the self-attention layer:
$$Y_s=\sum_{t=1}^{T}\mathrm{softmax}_t\!\left(\frac{X_s^\top W_Q^\top W_K X_t}{\sqrt{k}}\right)W_V X_t,$$
with the convention that $X_t\in\mathbb{R}^d$ (resp. $Y_s\in\mathbb{R}^d$) is the $t$-th column of $X$ (resp. the $s$-th column of $Y$). Note that the notation $\mathrm{softmax}_t(\cdot)$ might be a bit confusing. Recall that softmax always takes as input a vector and returns a (normalized) vector. In practice, most of the time we are dealing with batches, so that the softmax function takes as input a matrix (or tensor) and we need to normalize along the right axis! Named tensor notation (see below) deals with this notational issue. I also find the interpretation given below helpful:
Mental model for self-attention: self-attention can be interpreted as taking an expectation, where the mappings $q(\cdot)$, $k(\cdot)$ and $v(\cdot)$ represent the query, key and value.
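In symbols, one way to write this view (a sketch only, with $q(x)=W_Q x$, $k(x)=W_K x$, $v(x)=W_V x$ and the same $\sqrt{k}$ scaling as above):

$$Y_s=\mathbb{E}_{t\sim p(\cdot\mid s)}\big[v(X_t)\big],\qquad p(t\mid s)=\mathrm{softmax}_t\!\left(\frac{q(X_s)^\top k(X_t)}{\sqrt{k}}\right),$$

so the output at position $s$ is an average of the values $v(X_t)$, weighted by how well the query at $s$ matches each key.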
Multi-head attention combines several such operations in parallel; $Y$ is then the concatenation of the results along the feature dimension, to which one more linear transformation is applied.
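As an illustration, here is a minimal PyTorch sketch of multi-head self-attention following the formulas above (it uses batch-first `(B, T, d)` tensors instead of the $d\times T$ convention used here; the names `MultiHeadSelfAttention`, `n_heads` and `W_O` are mine, not from the original):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Minimal multi-head self-attention: per-head scaled dot-product attention,
    concatenation of the heads along the feature axis, then one more linear map."""
    def __init__(self, d, n_heads, k):
        super().__init__()
        self.n_heads, self.k = n_heads, k
        # per head: W_Q, W_K in R^{k x d} and W_V in R^{d x d},
        # implemented as a single linear map per projection
        self.W_Q = nn.Linear(d, n_heads * k, bias=False)
        self.W_K = nn.Linear(d, n_heads * k, bias=False)
        self.W_V = nn.Linear(d, n_heads * d, bias=False)
        self.W_O = nn.Linear(n_heads * d, d, bias=False)  # final linear transformation

    def forward(self, X):                                                 # X: (B, T, d)
        B, T, _ = X.shape
        Q = self.W_Q(X).view(B, T, self.n_heads, self.k).transpose(1, 2)  # (B, H, T, k)
        K = self.W_K(X).view(B, T, self.n_heads, self.k).transpose(1, 2)  # (B, H, T, k)
        V = self.W_V(X).view(B, T, self.n_heads, -1).transpose(1, 2)      # (B, H, T, d)
        scores = Q @ K.transpose(-2, -1) / self.k ** 0.5                  # (B, H, T, T)
        A = F.softmax(scores, dim=-1)   # softmax_t: normalize over the key axis t
        Y = A @ V                                                         # (B, H, T, d)
        Y = Y.transpose(1, 2).reshape(B, T, -1)                           # concatenate heads
        return self.W_O(Y)                                                # (B, T, d)
```

For instance, `MultiHeadSelfAttention(d=64, n_heads=4, k=16)(torch.randn(2, 10, 64))` returns a tensor of shape `(2, 10, 64)`; note the explicit `dim=-1` in the softmax, which is exactly the axis issue mentioned above.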
and then the Layer Norm has two parameters $\gamma,\beta\in\mathbb{R}^d$ and
$$\mathrm{LN}(x)=\gamma\cdot\frac{x-\mathrm{mean}(x)}{\mathrm{std}(x)}+\beta,$$
where we use the natural broadcasting rule for subtracting the mean and dividing by the std, and $\cdot$ is the component-wise multiplication.
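A minimal PyTorch sketch of this layer (the `eps` term, added for numerical stability, and the biased std are my assumptions; they roughly mirror `torch.nn.LayerNorm`):

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """LN(x) = gamma * (x - mean(x)) / std(x) + beta, over the last (feature) axis."""
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))   # gamma in R^d
        self.beta = nn.Parameter(torch.zeros(d))   # beta in R^d
        self.eps = eps

    def forward(self, x):                          # x: (..., d)
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
```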
A Feed Forward Network is an MLP acting on vectors: for $x\in\mathbb{R}^d$, we define
$$\mathrm{FFN}(x)=\max(0,\,xW_1+b_1)W_2+b_2,$$
where $W_1\in\mathbb{R}^{d\times h}$, $b_1\in\mathbb{R}^h$, $W_2\in\mathbb{R}^{h\times d}$, $b_2\in\mathbb{R}^d$.
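In PyTorch this is just two linear layers with a ReLU in between, applied independently at each position (a minimal sketch; `h` is the hidden size above):

```python
import torch.nn as nn

class FFN(nn.Module):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, acting on the last (feature) axis."""
    def __init__(self, d, h):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d, h), nn.ReLU(), nn.Linear(h, d))

    def forward(self, x):   # x: (..., d)
        return self.net(x)
```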
Each of these layers is applied to each of the inputs given to the transformer block, as depicted below:
Note that this block is equivariant: if we permute the inputs, then the outputs are permuted by the same permutation. As a result, the order of the inputs is irrelevant to the transformer block; in particular, the block cannot exploit this order. The important notion of positional encoding allows us to take order into account: it is a deterministic, unique encoding for each time step that is added to the input tokens.
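One standard choice of such a deterministic encoding (an assumption on my part; the notes may use a different scheme) is the sinusoidal encoding of the original Transformer paper, sketched below for an even dimension $d$:

```python
import math
import torch

def sinusoidal_positional_encoding(T, d):
    """Return a (T, d) tensor with PE[t, 2i] = sin(t / 10000^(2i/d)) and
    PE[t, 2i+1] = cos(t / 10000^(2i/d)); assumes d is even."""
    position = torch.arange(T, dtype=torch.float32).unsqueeze(1)          # (T, 1)
    div_term = torch.exp(torch.arange(0, d, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d))                      # (d/2,)
    pe = torch.zeros(T, d)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

# The encoding is simply added to the input tokens:
# X of shape (T, d) becomes X + sinusoidal_positional_encoding(T, d).
```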