In this read, we modify the self-attention mechanism from the previous post into a causal self-attention mechanism.
If you haven’t read that one yet, I would highly recommend giving it a read 👇🏻:
Scaled Dot-Product Attention Explained!
Previously, I covered a high-level overview of the Simple Attention Mechanism without Trainable Weights. Missed it? (really?) Go check it out.
And while you’re at it, subscribe so you never miss any of these reads.
What is the Causal Self-Attention Mechanism?
Causal (or masked) self-attention ensures that the model's prediction for a given position in a sequence depends only on the known outputs at earlier positions, and not on future tokens.
How to Hide Future Tokens with Causal Attention?
In causal attention, the attention weights above the diagonal are masked.
This ensures that, for any given token, the LLM cannot make use of future tokens when calculating the context vectors from the attention weights.
To achieve this, for each given token, we mask out the future tokens (the ones that come after the current token in the input text).
The simplest way is to mask out the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function.
import torch

# Mask out everything above the diagonal (the "future" positions)
# attn_scores and context_length carry over from the previous example
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
"""
Output:
tensor([[0.0988, -inf, -inf, -inf, -inf],
[0.1345, 0.1951, -inf, -inf, -inf],
[0.1330, 0.1957, 0.2100, -inf, -inf],
[0.1985, 0.2716, 0.2985, 0.1392, -inf],
[0.0688, 0.1003, 0.1079, 0.0486, 0.1117]],
grad_fn=<MaskedFillBackward0>)
"""
Note: After applying the softmax to these masked scores, the attention weights in each row correctly sum to 1, and the masked positions become 0.
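To make that concrete, here is a minimal sketch of the softmax step, assuming the keys tensor from the scaled dot-product attention post is still available:
# Scaled softmax over the masked scores; the -inf entries become 0
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)
print(attn_weights.sum(dim=-1))  # each row sums to 1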
Masking Additional Attention Weights With Dropout
In addition, we apply dropout to reduce overfitting during the training of the LLM.
Dropout is a deep learning technique in which randomly selected hidden units are ignored during training, preventing overfitting and improving generalization.
In GPT Models, dropout can be applied in several places, such as:
after computing the attention weights
or, after multiplying the attention weights with the value vectors
However, it is recommended to apply the dropout mask after computing the attention weights, which is what we do here.
Furthermore, in this specific example, we use a dropout rate of 50%, which means randomly masking out half of the attention weights.
Later, when training the GPT model, we will use a lower dropout rate, such as 0.1 or 0.2.
If we apply a dropout rate of 0.5 (50%), the non-dropped values will be scaled accordingly by a factor of:
1 / (1 - dropout_rate) = 1 / 0.5 = 2
This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
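As a quick, self-contained illustration of this scaling (the 6×6 matrix of ones below is just a stand-in for a matrix of attention weights, not part of our running example):
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # 50% dropout rate
example = torch.ones(6, 6)       # stand-in for a matrix of attention weights
print(dropout(example))          # surviving entries are scaled by 1 / 0.5 = 2.0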
Implementing CausalAttention Class
Now, we are ready to implement a self-attention class, including the causal and dropout masks.
Step 1: Compared to the previous SelfAttention class, we added a dropout layer.
Step 2: The register_buffer call is also a new addition. Buffers are saved as part of the model's state and automatically moved to the appropriate device along with the model, but they are not treated as trainable parameters (see the short sketch after these steps).
Step 3: Transpose dimensions 1 and 2, keeping the batch dimension at the first position (i.e., 0).
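Before the full class, here is a tiny standalone sketch of what register_buffer gives us (the BufferDemo module is purely illustrative and not part of the CausalAttention implementation):
import torch
import torch.nn as nn

class BufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered as a buffer: saved with the model and moved by .to(device), but never trained
        self.register_buffer('mask', torch.triu(torch.ones(4, 4), diagonal=1))

demo = BufferDemo()
print(list(demo.parameters()))   # [] -> the mask is not a trainable parameter
print(demo.state_dict().keys())  # odict_keys(['mask']) -> but it travels with the model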
import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # Step 1: dropout layer
        # Step 2: register the causal mask as a (non-trainable) buffer
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Step 3: transpose dimensions 1 and 2, keeping the batch dimension at position 0
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(  # Note: `_` operations are in-place in PyTorch
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` handles batches shorter than the supported context_length
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
Note: In PyTorch, operations with a trailing underscore (_) are performed in-place, avoiding unnecessary memory copies.
Instantiating the CausalAttention Class:
print(d_in)   # input embedding size (defined in the earlier example)
print(d_out)  # output embedding size (defined in the earlier example)
torch.manual_seed(123)

# Instantiating the CausalAttention class
context_length = batch.shape[1]  # `batch` is the stacked input tensor from the earlier example
ca = CausalAttention(d_in, d_out, context_length, 0.0)  # dropout rate of 0.0 here
context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
"""
Output:
3
2
tensor([[[-0.3585, 0.0415],
[-0.4451, 0.0252],
[-0.4656, 0.0393],
[-0.5461, -0.0500],
[-0.5129, -0.0367]],
[[-0.3585, 0.0415],
[-0.4451, 0.0252],
[-0.4656, 0.0393],
[-0.5461, -0.0500],
[-0.5129, -0.0367]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 5, 2])
"""
As we can see, the resulting output is a three-dimensional tensor in which each token is now represented by a two-dimensional embedding (the shape is [batch_size, num_tokens, d_out] = [2, 5, 2]).
Note: Dropout is only applied during training, not during inference.
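A small sanity check of that behaviour, using the ca instance from above (since we created it with a dropout rate of 0.0, both modes give identical outputs here; with a non-zero rate, only training mode would drop attention weights):
ca.eval()              # inference mode: dropout layers act as the identity
out_eval = ca(batch)
ca.train()             # back to training mode: dropout is active again (if rate > 0)
print(out_eval.shape)  # torch.Size([2, 5, 2])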
If you’d like to explore the full implementation, including code and data, then checkout: Github Repository 👈🏻
And that’s a wrap!
Next, we will expand on this concept and implement a multi-head attention module that runs several such causal attention mechanisms in parallel.
If you’ve made it this far, thank you so much! Stay tuned with me so you won’t miss out on future updates.
Until next time, happy learning!