Chasing Copy Heads in a Residual Streams Avenue

WIP

I’ve been looking into attention mechanics recently, and this is my list of questions I want to answer in relation to copy heads to document my understanding:

1. Definitions

copy heads, the OV circuit, frobenius distance, identity matrix, singular value decomposition, etc.

2. Background of all this

3. Importance of copy heads across different layers

4. Critical questions

All answers to be expanded with concrete examples (python, torch, visual diagrams).

Got nerd sniped by copy heads; I couldn’t stop thinking of this meme in the process & thought this applied to them lol

do-nothing-win

I found copy heads pretty interesting as I was going through some mechanical interpretability basics; I had jotted down some questions on a piece of paper for testing myself at the end of my learning to get a better intuition behind this concept, and I write this log answering all those questions as a reinforcement of my understanding.

Definitions

Before diving into the mechanics behind copy heads across different transformer layers and their importance, some definitions are necessary to establish the foundations. Some key definitions coming up.

Copy heads

Copy heads are attention heads that are identity matrices or are close to being identity matrices. Inside a multi-head attention phase, we split the embedded dimensions into specific number of heads and do a parallel computation: in each head, we calculate the attention weights, do a matmul between attention weights and the V (here V = x @ W_V) to calculate attention output and then when we do another matmul between the attention output matrix and a projection matrix W_O, we get what we call an output projection.

Sometimes, an attention head barely transforms any information that it receives from the residual stream; what it outputs could simply be a copy (or close to a copy) of what it receives from the residual stream; in other words, a weighted average of the input and nothing more, nothing less. This sort of an attention head is called a copy head.

In order to get a clear picture behind a copy head, I think it’s a good idea to get a quick refresher behind how we reach the point of an attention head.

Refresher: How did we reach this rabbit hole?

An attention head, in a transformer, exists inside a layer (or a block), and is a factor of the transformer’s embedding dimension (d_model in the original transformer paper) that stores rich representations & semantic features during the training.

I need to add an illustration of a transformer block here
and also an illustration of attention heads

Let’s assume that I have an input_text called: “I like hot”.

This is an input word, at the very start of the process, before any embeddings, before any forward passes.

Parameters that I’d like to assume for the sake of this example:

batch = 1 (since we have but a single sentence)
seqlen = 3 (sequence length the size of input text)
d_model = 32 (this will be the embedding dimension for every word, i.e., input token id)
n_layers = 8 (total number of transformer blocks)
n_heads = 4 (number of heads in each layer, factor of d_model, will expand more soon) vocab_size = 100

In a transformer, the input_text gets tokenized, which is simply a method to assign unique numeric representation to words or characters. Tokenization can be done via various methods, and not all of them break down a sequence of words (i.e., the input_text here) 1:1 into tokens but for the sake of this worklog, let me assume that each word from the input_text above corresponds to 1 token, kind of like:

tokenizer = {
    "I": 1,
    "like": 2,
    "hot": 3,
} # a simple lookup table, basically a hash map/dict

This is still, to some extent, an oversimplification, but let’s think of tokenization as a step where we create a lookup table (like a dictionary) of words (i.e., tokens) and their unique IDs.

At this point, if we think of this input sequence of unique token IDs as a tensor, it would be:

tokenized_input = torch.tensor([1, 2, 3])
tokenized_input.shape # torch.Size([3])

Transformers do computation parallelly in batches, so generally, inputs are tokenized in batches as well. Basically, a single ‘batch’ represents a collection of tokens, and in a normal LLM, during training, there are countless number of batches, and each batch holds an enormous size of text (tons of sentences). In our case, since we have but just input_text, we are assuming that we have 1 batch, and that means our sequence length will be the length of our tokens, i.e., 3.

So I unsqueeze the tokenized_input to better represent it as [batch, seqlen]:

tokenized_input = tokens.unsequeeze(0)
tokenized_input # torch.tensor([[1, 2, 3]])
tokenized_input.shape # torch.Size([1, 3])

So this new tokenized_input represents the tokenized text that has 1 batch, and the batch’s sequence length is 3.

Moving on, we initialize random embedding layers/dimensions to each token. In the original tranformer paper, there were 512 embedding dimentions (d_model) assigned to each token, but for demonstration, let me simply start off with 32 dimensions per token.

In PyTorch, this is usually done via the torch.nn.Embedding module, where it could be something like:

embedding = torch.nn.Embedding(num_embeddings=37000, embedding_dim=512)
embedding.weight.shape # torch.Size([37000, 512])

token_embedding = embedding(tokenized_input)
token_embedding.shape # torch.Size([1, 3, 512])

Here, 37000 was the vocab size in the original transformer, and each vocab token had a dimension of 512. torch.nn.Embedding creates a lookup table (similar to the tokenizer above) of a certain vocabulary size (that we determine based on the kind of tokenizer we have created) where each token ID has its own unique d_model-dimensional vector. Nothing fancy going on here.

In order to assign random embedding weights for my example, let me think of a vocabulary number. From what I’ve read, it seems that vocab number is generally determined first, and then the tokenizer creates those amounts of tokens and their token ID representations. For a model, any token ID that does not fall inside the tokenizer’s vocab hash map will lead to the model not recognizing it. For me, if I look at the tokenizer has map from above, my vocab_size will be just 3, but let me extend it a little more and create a vocab_size of 100:

tokenizer = {
    "nice": 0,
    "I": 1,
    "like": 2,
    "hot": 3,
    ...
    "?": 98,
    "coffee": 99
}

Right, so this means I can simply create my token embedding for the tokenized_input as follows:

embedding = torch.nn.Embedding(100, 32)
token_embedding = embedding(tokenized_input)
token_embedding.shape # [1, 3, 32]

So each token now also has its vector weight representation of 32 dimensions.

We also add positional encoding from here, which I won’t get into details, but the shape of token_embedding remains the same ([1, 3, 32]) as a result, and finally, what we have is ready to be fed as the first residual stream to the first multihead attention of the first layer of the transformer.

These calculations (token ID calculation, tokenization + vocab prep, embedding, positional encoding) are not done more than once. After the resulting embedded + positionally encoded input IDs get fed into the first sublayer, it is, as the paper[1] I’m reading claims, a residual stream that flows sequentially across each layer of the transformer (layer = transformer block).

Making some progress here. Now, this token_embedding goes into the first MHA of the first layer, and something like the following takes place:

import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, dtype, n_heads: int = 4, d_model: int = 32) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "n_heads should be a factor of d_model"
        self.dtype = dtype
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_heads = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, dtype=self.dtype)
        self.W_k = nn.Linear(d_model, d_model, dtype=self.dtype)
        self.W_v = nn.Linear(d_model, d_model, dtype=self.dtype)
        self.W_o = nn.Linear(d_model, d_model, dtype=self.dtype)

    def forward(self, x): # x == token_embedding
        if x.dtype != self.dtype:
            x.dtype = x.to(self.dtype)
        batch, seqlen, d_model = x.shape
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        q = q.view(batch, seqlen, self.n_heads, self.d_heads).transpose(1, 2)
        k = k.view(batch, seqlen, self.n_heads, self.d_heads).transpose(1, 2)
        v = v.view(batch, seqlen, self.n_heads, self.d_heads).transpose(1, 2)

        attn_scores = q @ k.transpose(-1, -2)
        attn_weights = torch.softmax(attn_scores / self.d_heads**0.5, dim=-1)

        attn_output = attn_weights @ v
        attn_output_concat = (
            attn_output.transpose(1, 2).contiguous().view(batch, seqlen, d_model)
        )

        proj_output = self.W_o(attn_output_concat)
        return (
            attn_weights,
            proj_output,
        )

This is a very simple multi-head attention (MHA) mechanism, and this sub-layer exists in each transformer block. Each MHA mechanism involves splitting the provided x (token_embedding) into n_heads number of heads to do parallel attention computation (attn_output in the code block above) using d_heads that are of size d_model // n_heads (followed immediately by concatenating all these outputs to rejoin all the split heads — basically, in a way, doing n_heads * d_heads to get back the original d_model — and do a matmul with the projection weight W_O to get the final proj_output). So, in my demo, d_model is 32, and if I choose n_heads as 4, my d_heads (the dimension that each head gets) becomes 8.

We’re now with 4 heads per transformer layer, and each head is doing 8 computations in parallel. But is each head that has been computed and concatenated with the other heads doing meaningful transformation with new learned features and semantics, or could it just be calculating a weighted average of the provided input x and nothing more?

contd…


[1]A Mathematical Framework for Transformer Circuits, Anthropic, 2021. https://transformer-circuits.pub/2021/framework/index.html