
Transformers


Overview

In this notebook we will build an intuitive picture of how transformers work. If you would like a more visual and detailed companion, I strongly recommend the 3Blue1Brown video Attention in Transformers, Step-by-Step and the original paper Attention Is All You Need.

We will cover the basics of the following concepts:

  • what a sequence looks like to a machine,

  • why attention is useful,

  • what query, key, and value mean,

  • how self-attention is computed,

  • why positional encoding is needed,

  • how multi-head attention works,

  • what a transformer block does.

We are here for a good time, not a long time. So the goal is not to prove every theorem, but to build a strong physical intuition for transformers.

A quick motivation

Older sequence models read tokens one by one. That can work, but it makes long-range interactions difficult.

Transformers use attention.

Very crudely:

A transformer lets every token look at every other token and decide what matters.

That is the core reason transformers became so powerful in language, vision, scientific modeling, and beyond.

# Basic imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
from matplotlib.patches import FancyArrowPatch, Rectangle, Circle
from matplotlib import cm

%matplotlib inline

np.random.seed(7)
plt.rcParams["figure.figsize"] = (10, 5)
plt.rcParams["font.size"] = 12

Tokens: how a machine sees a sentence

Take the sentence

“the cat sat on the mat”

A human sees words with meaning. A machine first sees tokens. Then each token is turned into a vector called an embedding.

Let us animate that idea.

❓ Exercise

Q11: Write a tokenizer function that takes any sentence and converts it into tokens in a simple transformer-style format.

Your function should:

  • convert the sentence to lowercase,

  • separate punctuation into standalone tokens,

  • remove extra spaces,

  • split the sentence into tokens.

Test it on:

sentence = "The cat sat on the mat."

in the following code block.

Click to show answer

import re

def tokenizer(sentence):
    """
    Basic tokenizer for transformer ideas.
    NOTE: real tokenizers (e.g. for BERT-style models) also add special
    tokens like [CLS] and [SEP].

    What it does:
    1. lowercases the text
    2. separates punctuation into standalone tokens
    3. collapses repeated spaces
    4. splits into tokens
    """
    sentence = sentence.lower().strip()

    # Separate punctuation so punctuation becomes its own token
    sentence = re.sub(r'([.,!?;:()"\'])', r' \1 ', sentence)

    # Remove extra spaces
    sentence = re.sub(r'\s+', ' ', sentence).strip()

    tokens = sentence.split(" ")
    return tokens

sentence = "The cat sat on the mat"
tokens = tokenizer(sentence)
n_tokens = len(tokens)

print("Tokens:", tokens)
print("Number of tokens:", n_tokens)
Tokens: ['the', 'cat', 'sat', 'on', 'the', 'mat']
Number of tokens: 6

Positional encoding

A transformer does not only need to know what each token is, it also needs to know where that token appears in the sequence.

Below we will do two things:

  1. build helper functions for toy embeddings and toy positional offsets, much like our tokenizer helper,

  2. use them to visualize the difference between embedding only and embedding + position.

In the plots below:

  • the black points represent the token embeddings,

  • the red points represent embedding + positional information,

  • the motion shows how positional information shifts token representations.

NOTE: These are random 2D toy vectors for illustration. They are not real transformer embeddings.

Why position matters

Without position, the following two sequences would look too similar:

  • dog bites man

  • man bites dog

The words are the same, but the order changes the meaning completely.

That is why transformers add positional information to token embeddings.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle
from IPython.display import HTML, display, Markdown

We now create two helper functions, in the same spirit as the tokenizer:

  • make_toy_embeddings(tokens, ...) creates a toy embedding vector for each token,

  • make_toy_positions(n_tokens, ...) creates a position-dependent offset.

We then add them together to form a very simple picture of positional encoding.

These are toy vectors for visualization only.

def make_toy_embeddings(tokens, seed=7, scale=0.65):
    """
    Create a random 2D toy embedding for each token.
    This is only for visualization.
    """
    rng = np.random.default_rng(seed)
    return rng.normal(loc=0.0, scale=scale, size=(len(tokens), 2))


def make_toy_positions(n_tokens, x_step=0.22, amplitude=0.18, frequency=0.9):
    """
    Create a smooth 2D positional offset for each token position.
    """
    positions = np.arange(n_tokens, dtype=float)
    return np.column_stack([
        x_step * positions,
        amplitude * np.sin(frequency * positions)
    ])


def build_embedding_lookup(vocab, seed=12, scale=0.75):
    """
    Create one shared embedding lookup so that the same word gets the same
    base embedding even if it appears in different sentences.
    """
    rng = np.random.default_rng(seed)
    return {
        word: rng.normal(loc=0.0, scale=scale, size=2)
        for word in sorted(vocab)
    }


def tokens_to_embeddings(tokens, embedding_lookup):
    return np.array([embedding_lookup[tok] for tok in tokens], dtype=float)


# Use the current token list from the earlier tokenizer example
emb = make_toy_embeddings(tokens, seed=7)
pos = make_toy_positions(n_tokens)
combined = emb + pos

Embedding only vs embedding + position

fig_width = max(10, 1.2 * n_tokens)
fig, ax = plt.subplots(figsize=(fig_width, 5.5))

# Dynamic axis limits so this works for any token length
all_points = np.vstack([emb, combined])
xmin, ymin = all_points.min(axis=0) - 0.5
xmax, ymax = all_points.max(axis=0) + 0.5

ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_xlabel("feature 1")
ax.set_ylabel("feature 2")
ax.set_title("From token embeddings to embeddings + position")
ax.grid(alpha=0.3)

# Original embeddings
sc1 = ax.scatter(emb[:, 0], emb[:, 1], s=110, label="embedding only")

# Labels for original embeddings
base_texts = []
for i, tok in enumerate(tokens):
    base_texts.append(
        ax.text(
            emb[i, 0] + 0.03,
            emb[i, 1] + 0.03,
            tok,
            fontsize=10,
            alpha=0.7
        )
    )

# Moving points: embedding + position
sc2 = ax.scatter([], [], s=110, label="embedding + position")

# Labels for moving points
moving_texts = [ax.text(0, 0, "", fontsize=10, color="darkred") for _ in range(n_tokens)]

ax.legend(loc="upper left")

frames = 45

def update(frame):
    t = frame / (frames - 1)
    current = emb + t * (combined - emb)

    sc2.set_offsets(current)

    for i in range(n_tokens):
        moving_texts[i].set_position((current[i, 0] + 0.03, current[i, 1] - 0.08))
        moving_texts[i].set_text(tokens[i] if frame > 3 else "")

    ax.set_title(f"From token embeddings to embeddings + position  |  blend = {t:.2f}")
    return [sc2, *moving_texts]

anim_tokens = FuncAnimation(fig, update, frames=frames, interval=140, blit=False)
plt.close(fig)
display(HTML(anim_tokens.to_jshtml()))

Why position is needed

Below we use a shared embedding lookup so that the same word gets the same base embedding in both sentences. That means any difference we see afterward comes from position.

sentence_a = "dog bites man"
sentence_b = "man bites dog"

tokens_a = tokenizer(sentence_a)
tokens_b = tokenizer(sentence_b)

shared_lookup = build_embedding_lookup(set(tokens_a + tokens_b), seed=12)

emb_a = tokens_to_embeddings(tokens_a, shared_lookup)
emb_b = tokens_to_embeddings(tokens_b, shared_lookup)

pos_a = make_toy_positions(len(tokens_a), x_step=0.55, amplitude=0.18, frequency=1.2)
pos_b = make_toy_positions(len(tokens_b), x_step=0.55, amplitude=0.18, frequency=1.2)

combined_a = emb_a + pos_a
combined_b = emb_b + pos_b

fig, axes = plt.subplots(1, 2, figsize=(14, 5.5), sharex=True, sharey=True)

all_points = np.vstack([emb_a, emb_b, combined_a, combined_b])
xmin, ymin = all_points.min(axis=0) - 0.6
xmax, ymax = all_points.max(axis=0) + 0.6

for ax in axes:
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_xlabel("feature 1")
    ax.set_ylabel("feature 2")
    ax.grid(alpha=0.3)

axes[0].set_title(f"Sentence A: {' '.join(tokens_a)}")
axes[1].set_title(f"Sentence B: {' '.join(tokens_b)}")

# Left panel
axes[0].scatter(emb_a[:, 0], emb_a[:, 1], s=130, label="embedding only")
for i, tok in enumerate(tokens_a):
    axes[0].text(emb_a[i, 0] + 0.03, emb_a[i, 1] + 0.03, tok, fontsize=11, alpha=0.7)
sc2_a = axes[0].scatter([], [], s=130, label="embedding + position")
moving_texts_a = [axes[0].text(0, 0, "", fontsize=11, color="darkred") for _ in range(len(tokens_a))]

# Right panel
axes[1].scatter(emb_b[:, 0], emb_b[:, 1], s=130, label="embedding only")
for i, tok in enumerate(tokens_b):
    axes[1].text(emb_b[i, 0] + 0.03, emb_b[i, 1] + 0.03, tok, fontsize=11, alpha=0.7)
sc2_b = axes[1].scatter([], [], s=130, label="embedding + position")
moving_texts_b = [axes[1].text(0, 0, "", fontsize=11, color="darkred") for _ in range(len(tokens_b))]

axes[0].legend(loc="upper left")
axes[1].legend(loc="upper left")

frames = 45

def update_compare(frame):
    t = frame / (frames - 1)

    current_a = emb_a + t * (combined_a - emb_a)
    current_b = emb_b + t * (combined_b - emb_b)

    sc2_a.set_offsets(current_a)
    sc2_b.set_offsets(current_b)

    for i in range(len(tokens_a)):
        moving_texts_a[i].set_position((current_a[i, 0] + 0.03, current_a[i, 1] - 0.08))
        moving_texts_a[i].set_text(tokens_a[i] if frame > 3 else "")

    for i in range(len(tokens_b)):
        moving_texts_b[i].set_position((current_b[i, 0] + 0.03, current_b[i, 1] - 0.08))
        moving_texts_b[i].set_text(tokens_b[i] if frame > 3 else "")

    fig.suptitle(
        f"Same words, different order → different position-aware representations  |  blend = {t:.2f}",
        fontsize=14
    )
    return [sc2_a, sc2_b, *moving_texts_a, *moving_texts_b]

anim_compare = FuncAnimation(fig, update_compare, frames=frames, interval=140, blit=False)
plt.close(fig)
display(HTML(anim_compare.to_jshtml()))

The basic philosophy

Each position gets a vector.

So the model does not just see

$$\text{token embedding}$$

but rather

$$\text{token embedding} + \text{position encoding}.$$

That way, two identical words at different places can still be distinguished.
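For reference, the original paper uses fixed sinusoidal position encodings rather than the random toy offsets above. Below is a minimal sketch of that idea; the function name sinusoidal_positions and the sizes in the example are only for illustration.

def sinusoidal_positions(n_tokens, d_model):
    """
    Sinusoidal positional encoding in the style of "Attention Is All You Need".
    Each position gets a d_model-dimensional vector built from sines and
    cosines at different frequencies.
    """
    positions = np.arange(n_tokens)[:, None]     # shape (n_tokens, 1)
    dims = np.arange(d_model)[None, :]           # shape (1, d_model)
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / d_model)
    angles = positions * angle_rates             # shape (n_tokens, d_model)

    pe = np.zeros((n_tokens, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])        # even dimensions: sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])        # odd dimensions: cosine
    return pe

# Example: 6 tokens with 4-dimensional vectors, small enough to read directly
print(np.round(sinusoidal_positions(6, 4), 3))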

Self-attention: the central idea

Now we come to the main event.

For each token, the model builds three learned vectors:

  • Query (Q): what am I looking for?

  • Key (K): what kind of information do I offer?

  • Value (V): what information do I actually carry?

Then a token compares its query with every other token’s key.

Large match = large attention.

Mathematically, the attention scores are

$$\text{scores} = \frac{QK^T}{\sqrt{d_k}}$$

and after a softmax,

$$\text{attention weights} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right).$$

Finally,

$$\text{output} = \text{attention weights} \cdot V.$$

Below, the graphics are now interactive: use the button bar to choose which query token you want to inspect.

import ipywidgets as widgets


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


# A short sentence to keep the visuals readable
attention_sentence = "the cat sat on the mat"
tokens = tokenizer(attention_sentence)

X = np.array([
    [0.8, 0.1, 0.0, 0.2],
    [0.2, 0.9, 0.1, 0.3],
    [0.6, 0.7, 0.3, 0.2],
    [0.1, 0.4, 0.8, 0.4],
    [0.8, 0.1, 0.0, 0.2],
    [0.2, 0.5, 0.9, 0.1],
], dtype=float)

WQ = np.array([
    [0.7, 0.0, 0.2, 0.1],
    [0.1, 0.8, 0.0, 0.2],
    [0.0, 0.1, 0.9, 0.1],
    [0.2, 0.0, 0.1, 0.7],
])

WK = np.array([
    [0.6, 0.1, 0.1, 0.0],
    [0.0, 0.7, 0.1, 0.2],
    [0.1, 0.0, 0.8, 0.2],
    [0.2, 0.2, 0.0, 0.6],
])

WV = np.array([
    [1.0, 0.2, 0.0, 0.1],
    [0.0, 0.9, 0.2, 0.1],
    [0.1, 0.0, 1.0, 0.2],
    [0.2, 0.1, 0.0, 0.9],
])

Q = X @ WQ
K = X @ WK
V = X @ WV
raw_scores = Q @ K.T
scores = raw_scores / np.sqrt(K.shape[1])
A = softmax(scores, axis=1)
context = A @ V


def render_attention(query_idx=0):
    fig = plt.figure(figsize=(14, 8))
    gs = fig.add_gridspec(2, 2, width_ratios=[1.15, 1.0], height_ratios=[1.0, 1.0])

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, 0])
    ax4 = fig.add_subplot(gs[1, 1])

    im1 = ax1.imshow(scores, cmap="coolwarm")
    ax1.set_title(r"Scaled dot-product scores $QK^T/\sqrt{d_k}$")
    ax1.set_xticks(range(len(tokens)))
    ax1.set_yticks(range(len(tokens)))
    ax1.set_xticklabels(tokens, rotation=45, ha="right")
    ax1.set_yticklabels(tokens)
    ax1.set_xlabel("Key token")
    ax1.set_ylabel("Query token")
    ax1.add_patch(Rectangle((-0.5, query_idx - 0.5), len(tokens), 1, fill=False, edgecolor="black", linewidth=2))
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax1.text(j, i, f"{scores[i, j]:.2f}", ha="center", va="center", fontsize=9)
    fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)

    im2 = ax2.imshow(A, vmin=0, vmax=1, cmap="viridis")
    ax2.set_title("Attention weights after softmax")
    ax2.set_xticks(range(len(tokens)))
    ax2.set_yticks(range(len(tokens)))
    ax2.set_xticklabels(tokens, rotation=45, ha="right")
    ax2.set_yticklabels(tokens)
    ax2.set_xlabel("Key token")
    ax2.set_ylabel("Query token")
    ax2.add_patch(Rectangle((-0.5, query_idx - 0.5), len(tokens), 1, fill=False, edgecolor="white", linewidth=2))
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax2.text(j, i, f"{A[i, j]:.2f}", ha="center", va="center", fontsize=9, color="white")
    fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)

    weights = A[query_idx]
    bars = ax3.bar(range(len(tokens)), weights)
    ax3.set_xticks(range(len(tokens)))
    ax3.set_xticklabels(tokens, rotation=45, ha="right")
    ax3.set_ylim(0, 1)
    ax3.set_ylabel("weight")
    ax3.set_title(f"How much query token '{tokens[query_idx]}' attends to each token")
    for bar, value in zip(bars, weights):
        ax3.text(bar.get_x() + bar.get_width() / 2, value + 0.02, f"{value:.2f}", ha="center", fontsize=9)

    ranked = np.argsort(weights)[::-1]
    top_two = ", ".join([f"{tokens[j]} ({weights[j]:.2f})" for j in ranked[:2]])
    info = (
        f"Selected query token: {tokens[query_idx]}\n\n"
        f"Query vector q_i:\n{np.round(Q[query_idx], 3)}\n\n"
        f"Raw dot products q_i · k_j:\n{np.round(raw_scores[query_idx], 3)}\n\n"
        f"Scaled scores:\n{np.round(scores[query_idx], 3)}\n\n"
        f"Softmax weights:\n{np.round(weights, 3)}\n\n"
        f"Top attended tokens: {top_two}\n\n"
        f"Context vector (weighted sum of V):\n{np.round(context[query_idx], 3)}"
    )
    ax4.text(0.02, 0.98, info, va="top", ha="left", fontsize=11, family="monospace")
    ax4.set_title("Step-by-step view for the selected query")
    ax4.axis("off")

    fig.suptitle("Interactive self-attention explorer", fontsize=15)
    plt.tight_layout()
    plt.show()


widgets.interact(
    render_attention,
    query_idx=widgets.ToggleButtons(
        options=[(f"{i}: {tok}", i) for i, tok in enumerate(tokens)],
        description="Query token:",
    )
)

What did we just see?

Each row of the attention matrix belongs to one query token.

  • The row entries tell us how much that token attends to every token in the sequence.

  • The row sums to 1 after softmax.

  • The output for that token is a weighted average of value vectors.
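A one-line sanity check of the second point, using the attention matrix A computed for the interactive example above:

# After softmax, every row of A is a probability distribution over tokens,
# so each row should sum to 1 (up to floating-point error).
print("Row sums:", np.round(A.sum(axis=1), 6))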

Why a dot product?

The dot product measures how well two vectors point in the same direction.

  • large positive dot product → strong match,

  • dot product near zero → weak match,

  • negative dot product → mismatch.

So when we compute q_i · k_j, we are asking:

How relevant is token j to the current query token i?
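A tiny numeric illustration of that intuition; the three key vectors below are made up just to show the three cases.

q = np.array([1.0, 0.5, 0.0])               # a query vector (made up)
k_aligned    = np.array([0.9, 0.6, 0.1])    # points in a similar direction
k_orthogonal = np.array([-0.5, 1.0, 0.0])   # roughly perpendicular to q
k_opposite   = np.array([-1.0, -0.5, 0.0])  # points the other way

print("aligned   :", q @ k_aligned)      # large positive -> strong match
print("orthogonal:", q @ k_orthogonal)   # near zero      -> weak match
print("opposite  :", q @ k_opposite)     # negative       -> mismatch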

Why divide by $\sqrt{d_k}$?

If the key/query dimension $d_k$ is large, the dot products tend to grow in magnitude as well. The softmax then becomes very sharp, close to a one-hot distribution, which leaves tiny gradients and makes training unstable.

Dividing by $\sqrt{d_k}$ keeps the scores in a more reasonable range, so learning stays smoother.
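A quick demo of that effect. The queries and keys below are random stand-ins, but they show the mechanism: with unit-variance entries the raw dot products grow roughly like $\sqrt{d_k}$, so without scaling the softmax collapses onto a single token.

rng = np.random.default_rng(0)

for d_k in [4, 64, 512]:
    q_demo = rng.normal(size=d_k)          # random query with unit-variance entries
    k_demo = rng.normal(size=(6, d_k))     # six random keys

    raw = k_demo @ q_demo
    print(f"d_k = {d_k:3d}")
    print("  softmax(raw)           :", np.round(softmax(raw), 3))
    print("  softmax(raw / sqrt(dk)):", np.round(softmax(raw / np.sqrt(d_k)), 3))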

So attention is not magic. It is a learned weighted information-routing rule.

A minimal NumPy implementation of scaled dot-product attention

Let us now compute attention directly in code, and print the intermediate objects so that the formula turns into something concrete.

def scaled_dot_product_attention(X, WQ, WK, WV):
    Q = X @ WQ
    K = X @ WK
    V = X @ WV

    raw_scores = Q @ K.T
    scores = raw_scores / np.sqrt(K.shape[-1])
    weights = softmax(scores, axis=-1)
    output = weights @ V
    return Q, K, V, raw_scores, scores, weights, output

Q, K, V, raw_scores, scores, weights, output = scaled_dot_product_attention(X, WQ, WK, WV)

print("Input shape          :", X.shape)
print("Q, K, V shape        :", Q.shape, K.shape, V.shape)
print("Attention shape      :", weights.shape)
print("Output shape         :", output.shape)

query_idx = 1
print(f"\nInspecting query token: '{tokens[query_idx]}'")
print("Raw dot products     :", np.round(raw_scores[query_idx], 3))
print("Scaled scores        :", np.round(scores[query_idx], 3))
print("Attention weights    :", np.round(weights[query_idx], 3))
print("Context vector       :", np.round(output[query_idx], 3))

print("\nFull attention matrix:\n")
print(np.round(weights, 3))
Input shape          : (6, 4)
Q, K, V shape        : (6, 4) (6, 4) (6, 4)
Attention shape      : (6, 6)
Output shape         : (6, 4)

Inspecting query token: 'cat'
Raw dot products     : [0.363 0.763 0.749 0.646 0.363 0.621]
Scaled scores        : [0.181 0.382 0.374 0.323 0.181 0.31 ]
Attention weights    : [0.149 0.182 0.18  0.171 0.149 0.169]
Context vector       : [0.517 0.536 0.456 0.375]

Full attention matrix:

[[0.166 0.159 0.172 0.168 0.166 0.167]
 [0.149 0.182 0.18  0.171 0.149 0.169]
 [0.149 0.17  0.18  0.175 0.149 0.176]
 [0.139 0.164 0.173 0.191 0.139 0.193]
 [0.166 0.159 0.172 0.168 0.166 0.167]
 [0.138 0.164 0.175 0.19  0.138 0.196]]

A useful interpretation

The output sequence has the same number of tokens, but the representation of each token has changed.

Each token now carries not only its own information, but also a summary gathered from the rest of the sequence.

For one specific token, the computation goes like this:

  1. form the token’s query vector,

  2. compare it with all key vectors using dot products,

  3. divide the scores by $\sqrt{d_k}$,

  4. apply softmax to get nonnegative weights summing to 1,

  5. use those weights to take a weighted sum of the value vectors.

So the output token is literally a mixture of information pulled from the whole sentence.
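To make those five steps concrete, here is the same computation done by hand for one token, reusing X, WQ, WK, WV, softmax, and the batched result from the previous code block; the manual numbers should match the corresponding row of output.

i = 2  # pick one query token, e.g. 'sat'

q_i = X[i] @ WQ                           # 1. the token's query vector
raw_i = K @ q_i                           # 2. dot products with every key vector
scaled_i = raw_i / np.sqrt(K.shape[-1])   # 3. divide by sqrt(d_k)
w_i = softmax(scaled_i)                   # 4. nonnegative weights summing to 1
out_i = w_i @ V                           # 5. weighted sum of the value vectors

print("Weights sum to 1 :", np.round(w_i.sum(), 6))
print("Manual output    :", np.round(out_i, 3))
print("Batched output   :", np.round(output[i], 3))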

Multi-head attention

One head is useful. Many heads are better.

Why?

Because different heads can learn different relationships:

  • local grammar,

  • long-range dependencies,

  • subject-verb agreement,

  • semantic similarity,

  • special token interactions.

Below we make the visualization more elaborate:

  • three toy heads,

  • a heatmap for each head,

  • an average-attention view,

  • and a query-token selector so you can compare how different heads focus on different parts of the sentence.

# Define three toy heads with slightly different projections
WQ1, WK1, WV1 = WQ, WK, WV

WQ2 = np.array([
    [0.5, 0.2, 0.1, 0.1],
    [0.2, 0.4, 0.3, 0.0],
    [0.1, 0.1, 0.5, 0.3],
    [0.2, 0.0, 0.2, 0.6],
])
WK2 = np.array([
    [0.4, 0.2, 0.1, 0.1],
    [0.1, 0.6, 0.2, 0.0],
    [0.2, 0.1, 0.5, 0.2],
    [0.0, 0.2, 0.2, 0.5],
])
WV2 = np.array([
    [0.9, 0.1, 0.2, 0.0],
    [0.1, 0.8, 0.1, 0.2],
    [0.0, 0.2, 0.9, 0.1],
    [0.2, 0.0, 0.1, 0.8],
])

WQ3 = np.array([
    [0.6, 0.1, 0.0, 0.2],
    [0.0, 0.5, 0.2, 0.2],
    [0.2, 0.0, 0.6, 0.1],
    [0.1, 0.2, 0.1, 0.5],
])
WK3 = np.array([
    [0.5, 0.0, 0.2, 0.1],
    [0.1, 0.5, 0.1, 0.1],
    [0.0, 0.2, 0.6, 0.1],
    [0.2, 0.1, 0.0, 0.5],
])
WV3 = np.array([
    [0.8, 0.2, 0.1, 0.0],
    [0.0, 0.7, 0.2, 0.2],
    [0.1, 0.1, 0.8, 0.2],
    [0.2, 0.0, 0.2, 0.7],
])

heads = []
for idx, (WQh, WKh, WVh) in enumerate([(WQ1, WK1, WV1), (WQ2, WK2, WV2), (WQ3, WK3, WV3)], start=1):
    Qh, Kh, Vh, raw_h, score_h, A_h, O_h = scaled_dot_product_attention(X, WQh, WKh, WVh)
    heads.append({
        "name": f"Head {idx}",
        "Q": Qh, "K": Kh, "V": Vh,
        "raw": raw_h, "scores": score_h,
        "A": A_h, "O": O_h,
    })

mean_attention = np.mean([h["A"] for h in heads], axis=0)


def render_multihead(query_idx=0):
    fig = plt.figure(figsize=(17, 8))
    gs = fig.add_gridspec(2, 4, height_ratios=[1.0, 1.0], width_ratios=[1, 1, 1, 1.1])

    heatmap_axes = [fig.add_subplot(gs[0, i]) for i in range(4)]
    ax_bar = fig.add_subplot(gs[1, :3])
    ax_text = fig.add_subplot(gs[1, 3])

    cmap_list = ["magma", "viridis", "plasma"]
    for ax, head, cmap in zip(heatmap_axes[:3], heads, cmap_list):
        A_h = head["A"]
        im = ax.imshow(A_h, vmin=0, vmax=1, cmap=cmap)
        ax.set_title(head["name"])
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right")
        ax.set_yticklabels(tokens)
        ax.set_xlabel("Keys")
        ax.set_ylabel("Queries")
        ax.add_patch(Rectangle((-0.5, query_idx - 0.5), len(tokens), 1, fill=False, edgecolor="white", linewidth=2))
        for j in range(len(tokens)):
            ax.text(j, query_idx, f"{A_h[query_idx, j]:.2f}", ha="center", va="center", fontsize=8, color="white")

    ax_mean = heatmap_axes[3]
    ax_mean.imshow(mean_attention, vmin=0, vmax=1, cmap="cividis")
    ax_mean.set_title("Average across heads")
    ax_mean.set_xticks(range(len(tokens)))
    ax_mean.set_yticks(range(len(tokens)))
    ax_mean.set_xticklabels(tokens, rotation=45, ha="right")
    ax_mean.set_yticklabels(tokens)
    ax_mean.set_xlabel("Keys")
    ax_mean.set_ylabel("Queries")
    ax_mean.add_patch(Rectangle((-0.5, query_idx - 0.5), len(tokens), 1, fill=False, edgecolor="white", linewidth=2))

    x = np.arange(len(tokens))
    width = 0.22
    for offset, head in zip([-width, 0.0, width], heads):
        weights = head["A"][query_idx]
        ax_bar.bar(x + offset, weights, width=width, label=head["name"])
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(tokens, rotation=45, ha="right")
    ax_bar.set_ylim(0, 1)
    ax_bar.set_ylabel("attention weight")
    ax_bar.set_title(f"Selected query token: '{tokens[query_idx]}'")
    ax_bar.legend(loc="upper right")

    summaries = []
    concat_context = []
    for head in heads:
        weights = head["A"][query_idx]
        best = np.argmax(weights)
        context_vec = head["O"][query_idx]
        concat_context.extend(context_vec.tolist())
        summaries.append(
            f"{head['name']}: top focus = {tokens[best]} ({weights[best]:.2f})\ncontext = {np.round(context_vec, 3)}"
        )

    info = (
        f"Chosen query token: {tokens[query_idx]}\n\n"
        + "\n\n".join(summaries)
        + f"\n\nConcatenated multi-head vector length: {len(concat_context)}"
        + f"\nFirst few entries: {np.round(concat_context[:6], 3)}"
    )
    ax_text.text(0.02, 0.98, info, va="top", ha="left", fontsize=10.5, family="monospace")
    ax_text.set_title("How the heads differ")
    ax_text.axis("off")

    fig.suptitle("Interactive multi-head attention explorer", fontsize=15)
    plt.tight_layout()
    plt.show()


widgets.interact(
    render_multihead,
    query_idx=widgets.ToggleButtons(
        options=[(f"{i}: {tok}", i) for i, tok in enumerate(tokens)],
        description="Query token:",
    )
)

Why heads help

If one head tries to learn everything, it becomes too limited.

Multi-head attention allows the model to split the job:

  • one head may care about a nearby modifier,

  • another may care about a distant reference,

  • another may track a special pattern.

You can think of it as giving the model multiple parallel ways of asking:

What information matters for this token?

Then the head outputs are concatenated and linearly mixed, so the transformer gets a richer representation than any single head could provide on its own.
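In code, that final step is just a concatenation followed by one linear map. Here is a minimal sketch using the three toy heads above; the output projection W_O is a random matrix standing in for the learned projection of a real transformer.

d_model = X.shape[1]       # 4 in our toy example
n_heads = len(heads)

# Concatenate the per-head context vectors: shape (n_tokens, n_heads * d_model)
concat = np.concatenate([h["O"] for h in heads], axis=1)

# Mix the heads back down to d_model dimensions.
# In a real transformer W_O is learned; here it is random for illustration.
rng = np.random.default_rng(3)
W_O = rng.normal(scale=0.3, size=(n_heads * d_model, d_model))

multihead_output = concat @ W_O
print("Concatenated shape:", concat.shape)            # (6, 12)
print("Multi-head output :", multihead_output.shape)  # (6, 4)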

A transformer block

A standard transformer block usually contains:

  1. multi-head self-attention,

  2. add & norm,

  3. feed-forward network,

  4. add & norm again.

Here is a simple schematic animation.

fig, ax = plt.subplots(figsize=(12, 4))
ax.set_xlim(0, 12)
ax.set_ylim(0, 5)
ax.axis("off")

stages = [
    ("Input\nTokens", 1.0),
    ("Multi-Head\nAttention", 3.5),
    ("Add &\nNorm", 6.0),
    ("Feed\nForward", 8.5),
    ("Output", 11.0),
]

for label, x in stages:
    rect = Rectangle((x-0.9, 1.7), 1.8, 1.6, fill=False, linewidth=2)
    ax.add_patch(rect)
    ax.text(x, 2.5, label, ha="center", va="center", fontsize=12)

for i in range(len(stages)-1):
    x1 = stages[i][1] + 0.9
    x2 = stages[i+1][1] - 0.9
    arrow = FancyArrowPatch((x1, 2.5), (x2, 2.5), arrowstyle="->", mutation_scale=20, linewidth=2)
    ax.add_patch(arrow)

token_y = [1.2, 2.5, 3.8]
circles = [Circle((1.0, y), 0.12) for y in token_y]
for c in circles:
    ax.add_patch(c)

stage_positions = [1.0, 3.5, 6.0, 8.5, 11.0]
n_subframes = 20
total_frames = (len(stage_positions)-1) * n_subframes

caption = ax.text(6, 4.5, "", ha="center", va="center", fontsize=13)

def update(frame):
    seg = min(frame // n_subframes, len(stage_positions)-2)
    local = (frame % n_subframes) / (n_subframes - 1)
    x = (1 - local) * stage_positions[seg] + local * stage_positions[seg+1]

    for c, y in zip(circles, token_y):
        c.center = (x, y)

    stage_name = stages[seg+1][0].replace("\n", " ")
    caption.set_text(f"Tokens moving through the transformer block → approaching: {stage_name}")

    return [*circles, caption]

anim_block = FuncAnimation(fig, update, frames=total_frames, interval=120, blit=False, repeat=True)
plt.close(fig)

HTML(anim_block.to_jshtml())

The feed-forward network

After attention mixes information across tokens, the feed-forward network acts within each token.

So a rough mental model is:

  • attention = communication between tokens,

  • feed-forward = private processing inside each token representation.
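Putting the pieces together, here is a minimal single-head sketch of one full block in NumPy: self-attention, add & norm, a position-wise feed-forward network, and a second add & norm. The feed-forward weights W1, b1, W2, b2 are random stand-ins rather than trained parameters, and a real block would use multi-head attention plus learned layer-norm parameters.

def layer_norm(x, eps=1e-6):
    # Normalize each token vector to zero mean and unit variance
    # (no learned scale/shift in this toy version)
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def transformer_block(X, WQ, WK, WV, W1, b1, W2, b2):
    # 1. (single-head) self-attention, reusing the helper defined earlier
    *_, attn_out = scaled_dot_product_attention(X, WQ, WK, WV)

    # 2. add & norm: residual connection around attention
    h = layer_norm(X + attn_out)

    # 3. position-wise feed-forward: the same two-layer MLP applied to every token
    ff = np.maximum(0.0, h @ W1 + b1) @ W2 + b2   # ReLU nonlinearity

    # 4. add & norm again
    return layer_norm(h + ff)

# Random feed-forward parameters, just to run the block on our toy data
rng = np.random.default_rng(5)
d_model, d_ff = X.shape[1], 8
W1 = rng.normal(scale=0.5, size=(d_model, d_ff)); b1 = np.zeros(d_ff)
W2 = rng.normal(scale=0.5, size=(d_ff, d_model)); b2 = np.zeros(d_model)

block_out = transformer_block(X, WQ, WK, WV, W1, b1, W2, b2)
print("Block output shape:", block_out.shape)   # same shape as the input: (6, 4)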

Why transformers scale so well

Transformers are powerful because they make parallel computation easy.

In many old recurrent approaches, token processing is fundamentally sequential. In a transformer, the attention matrix for a whole sequence can be computed together.

That made them attractive not just for language, but also for:

  • images,

  • proteins,

  • time series,

  • scientific simulators,

  • particle physics event representations.

Quick exercise

Q12: Why can a transformer not rely on embeddings alone, without positional encoding?

Click to show answer

Because embeddings encode what the token is, but not where it appears. Without position, sequences with the same bag of words could look too similar. For example, dog bites man and man bites dog contain the same words but very different meaning.

Q13: Suppose the attention weight from token A to token B becomes very large. What does that mean physically in the computation?

Click to show answer

It means the output representation of token A will include a large contribution from the value vector of token B. So A is deciding that B carries useful information for updating A.

Summary

A transformer works through a few simple ideas:

  • Turn tokens into vectors.

  • Add positional information.

  • Build query, key, and value vectors.

  • Use query-key similarity to form attention weights.

  • Use those weights to combine values.

  • Repeat through many layers and heads.

The mathematics is compact, but the behavior is rich.

In one sentence: a transformer is a machine for learning who should pay attention to whom, and by how much.

For more...

Good things to try with this knowledge:

  • replace the toy sentence with your own tokens,

  • modify the projection matrices,

  • change the embedding dimension,

  • mask future tokens and turn this into causal attention,

  • compare self-attention with convolution or recurrence,

  • try the same logic on physics objects instead of words.

That is where the real fun begins.
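As a starting point for the causal-attention item above, here is a minimal sketch: mask every position to the right of the query before the softmax, so each token can only attend to itself and earlier tokens.

def causal_attention(X, WQ, WK, WV):
    Q, K, V = X @ WQ, X @ WK, X @ WV
    scores = (Q @ K.T) / np.sqrt(K.shape[-1])

    # Upper-triangular mask: position i may not attend to any j > i
    n = scores.shape[0]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)   # -inf becomes 0 after softmax

    weights = softmax(scores, axis=-1)
    return weights, weights @ V

causal_weights, causal_output = causal_attention(X, WQ, WK, WV)
print(np.round(causal_weights, 3))   # note the zeros above the diagonal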