Skip to content

22.2 Transformer 数学基础与程序演示

22.2.1 人工智能的数学基础

2017 年,Vaswani 等人提出了一种完全基于注意力机制、舍弃循环和卷积的序列转换模型——Transformer。本节严格依据该论文原文,重构其全部数学基础,并补充必要的背景知识。

许多核心机器学习模型从根本上依靠线性代数(Linear Algebra)原理来表达和求解。实践中,数据很少以简单的单一数值形式出现,通常表现为数据集,即大量数据点的集合。线性代数提供了有效整理、处理和分析这些数据的工具,使从业者得以通过向量(Vector)、矩阵(Matrix)和张量(Tensor)等对象来表示结构化数据(Structured Data),如表格数据,以及非结构化数据,如图像或视频。

不确定性量化(Uncertainty Quantification,UQ)旨在量化并降低物理系统建模与模拟中的不确定性;在系统某些因素未知的情况下,它试图给出研究结果的置信度。

统计学家乔治·博克斯(George Box)曾言:“所有的模型都是错误的,但有些模型是有用的。”

22.2.1.1 Attention 定义

Vaswani 等人在 2017 年 Transformer 论文中提出的缩放点积注意力(Scaled Dot-Product Attention)的数学定义,其核心是加权求和(weighted sum),通过查询(Query)、键(Key)和值(Value)矩阵实现。注意力标准定义如下:

其中:

22.2.1.2 向量的定义

在数学上,向量(vector)是一个 有序数组,表示一个具有大小和方向的量。例如,二维向量可以写作:

一般来说,长度为 d 的向量:

向量之间有以下基本运算:

  • 加法:逐元素相加
  • 数乘(标量乘法):向量每个分量乘以一个标量
  • 内积(dot product):衡量两个向量的相似度

22.2.1.3 权重加权求和

设有向量 v₁,…,vₘ 和对应权重 α₁,…,αₘ,权重非负且和为 1,则输出向量 o 为它们的加权求和。

这就是 从多组向量中选择性地聚合信息 的基本操作。Attention 正是执行这种加权求和,但权重由查询向量决定。

22.2.1.4 查询、键、值向量

22.2.1.5 相似度与 softmax

以内积衡量查询与键的相似度:

随后将这些相似度转换为概率:

由此,最相关的键所对应的值获得更大权重。

22.2.1.6 向量矩阵形式

将 m 个值向量堆成矩阵 V,键矩阵 K,查询矩阵 Q,可以得到矩阵形式的 Attention:

22.2.2 程序示例

以下示例在连续向量空间中通过注意力机制学习字符序列的动态关联,完成“你好世界”序列的自回归生成。

python
# 导入 PyTorch 核心库
import torch                               # 深度学习框架:提供张量运算与自动求导
import torch.nn as nn                      # 神经网络模块:Linear、Embedding、LayerNorm、ModuleList 等
import torch.nn.functional as F            # 函数式接口:提供 softmax 等无状态操作
import torch.optim as optim                # 优化器模块:提供 Adam 等参数更新算法
import math                                # 数学库:用于 sqrt 计算缩放因子

# ============================================================
# 0. vocab — 词表定义
# ============================================================
# 示例仅使用 4 个汉字作为最小词表,便于理解注意力机制每一步的计算
chars = ["你", "好", "世", "界"]            # 词表:4 个字符
vocab_size = len(chars)                     # 词表大小 = 4

# 正向映射:字符 → 整数索引,供模型输入使用
char2idx = {c: i for i, c in enumerate(chars)}
# 反向映射:整数索引 → 字符,供推理阶段还原汉字显示
idx2char = {i: c for c, i in char2idx.items()}

# ============================================================
# 1. Transformer 超参数(标准 Transformer 风格)
# ============================================================
d_model = 512                               # 模型隐藏维度(嵌入维度),Transformer 经典配置
n_heads = 8                                 # 多头注意力头数
n_layers = 6                                # Transformer 块堆叠层数
d_ff = 2048                                 # 前馈网络中间层维度(通常为 d_model 的 4 倍)

# 确保 d_model 可被 n_heads 整除,每头维度为整数
assert d_model % n_heads == 0
d_head = d_model // n_heads                 # 每个注意力头的维度 = 512 / 8 = 64

# ============================================================
# 2. 训练数据 — 自回归序列构造
# ============================================================
# Transformer 自动学习 token → embedding → attention → output logits 的完整流程,
# 无需手动指定 target_vectors。
#
# 输入序列:  [你, 好, 世, 界]   →  索引 [0, 1, 2, 3]
# 目标序列:  [好, 世, 界, 你]   →  索引 [1, 2, 3, 0]
# 这就是“自回归”: 给定前 t 个字符,预测第 t+1 个字符

data = torch.tensor([[0, 1, 2, 3]])         # 输入张量,形状 (batch=1, seq_len=4)
target = torch.tensor([[1, 2, 3, 0]])       # 目标张量,形状 (batch=1, seq_len=4)

# ============================================================
# 3. Multi-Head Attention — 多头注意力层
# ============================================================
class MultiHeadAttention(nn.Module):
    """
    标准缩放点积多头自注意力(Scaled Dot-Product Multi-Head Self-Attention)。

    数学定义:
        Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

    流程: 输入 x → 线性投影得到 Q/K/V → 拆分为多头 → 计算注意力得分 →
          缩放 → 因果掩码 → softmax → 加权求和 V → 合并多头 → 输出投影
    """

    def __init__(self, d_model, n_heads):
        """
        参数:
            d_model: 模型总维度,同时作为 Q/K/V 的投影维度
            n_heads: 注意力头数
        """
        super().__init__()

        self.d_model = d_model              # 模型维度,例如 512
        self.n_heads = n_heads              # 注意力头数,例如 8
        self.d_head = d_model // n_heads    # 每头维度 = 512 / 8 = 64

        # 四个无偏置线性投影矩阵(形状均为 d_model × d_model):
        self.W_Q = nn.Linear(d_model, d_model, bias=False)   # 查询投影: x → Q
        self.W_K = nn.Linear(d_model, d_model, bias=False)   # 键投影:   x → K
        self.W_V = nn.Linear(d_model, d_model, bias=False)   # 值投影:   x → V
        self.W_O = nn.Linear(d_model, d_model, bias=False)   # 输出投影: 拼接后融合

    def forward(self, x, return_attention=False):
        """
        前向传播。

        参数:
            x:               输入张量,形状 (B, T, D)
            return_attention: 是否返回中间计算结果(Q/K/V/scores/alpha)供可解释性分析

        返回:
            if return_attention=False: out, 形状 (B, T, D)
            if return_attention=True:  (out, Q, K, V, scores, scores_scaled, alpha)
        """
        B, T, D = x.shape                   # B=批次大小, T=序列长度, D=d_model

        # ---------- 第 1 步:线性投影 x → Q, K, V ----------
        Q = self.W_Q(x)                     # (B, T, D) → (B, T, D),每个 token 的查询向量
        K = self.W_K(x)                     # (B, T, D) → (B, T, D),每个 token 的键向量
        V = self.W_V(x)                     # (B, T, D) → (B, T, D),每个 token 的值向量

        # ---------- 第 2 步:拆分为多头 ----------
        # view:   (B, T, D) → (B, T, n_heads, d_head)
        # transpose: 交换维度 1 和 2 → (B, n_heads, T, d_head)
        # 此后每个头独立拥有 T × d_head 的 Q/K/V 子空间
        Q = Q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = K.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = V.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # ---------- 第 3 步:计算注意力得分 S = Q · Kᵀ ----------
        # Q: (B, n_heads, T, d_head), Kᵀ: (B, n_heads, d_head, T)
        # scores: (B, n_heads, T, T) — 位置 i 和位置 j 之间的原始相似度
        scores = Q @ K.transpose(-2, -1)

        # ---------- 第 4 步:缩放 S / √d_k ----------
        # 除以 √d_k 防止点积值过大,避免 softmax 进入梯度饱和区
        # 在标准配置下 d_head = 64,因此缩放因子 = 8
        scores_scaled = scores / math.sqrt(self.d_head)

        # ---------- 第 5 步:因果掩码(Causal Mask)----------
        # 自回归语言模型要求位置 i 只能看到 ≤ i 的 token(不可预先获取未来信息)
        # torch.triu(..., diagonal=1) 生成上三角矩阵(对角线以上为 True)
        # 例如 T=4 时:
        #   [[F, T, T, T],
        #    [F, F, T, T],
        #    [F, F, F, T],
        #    [F, F, F, F]]
        mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
        # 掩码位置填充 -inf,经 softmax 后权重≈0,即“禁止关注”
        scores_scaled = scores_scaled.masked_fill(mask, float("-inf"))

        # ---------- 第 6 步:Softmax 归一化 ----------
        # 沿最后一维(key 方向)做 softmax,得到注意力权重分布 α
        # 每行权重之和 = 1
        alpha = F.softmax(scores_scaled, dim=-1)

        # ---------- 第 7 步:加权求和 output = α · V ----------
        # α:  (B, n_heads, T, T)     — 注意力权重
        # V:  (B, n_heads, T, d_head) — 值向量
        # out: (B, n_heads, T, d_head) — 加权后的上下文表示
        out = alpha @ V

        # ---------- 第 8 步:合并多头 ----------
        # transpose: (B, n_heads, T, d_head) → (B, T, n_heads, d_head)
        # contiguous + view: 展平为 (B, T, n_heads * d_head) = (B, T, D)
        out = out.transpose(1, 2).contiguous()
        out = out.view(B, T, D)

        # ---------- 第 9 步:输出投影 ----------
        out = self.W_O(out)                # W_O 融合来自不同头的信息

        # 根据 return_attention 标志决定返回内容
        if return_attention:
            return out, Q, K, V, scores, scores_scaled, alpha

        return out


# ============================================================
# 4. FeedForward — 前馈网络
# ============================================================
class FeedForward(nn.Module):
    """
    位置式前馈网络(Position-wise Feed-Forward Network)。
    对每个位置的表示独立应用两层全连接 + ReLU 激活:
        FFN(x) = ReLU(x·W₁ + b₁)·W₂ + b₂
    中间维度 d_ff 通常为 d_model 的 4 倍 (512 → 2048),扩展后压缩回原维度。
    """

    def __init__(self, d_model, d_ff):
        """
        参数:
            d_model: 输入/输出维度
            d_ff:    中间隐藏层维度(通常是 d_model 的 4 倍)
        """
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),      # 升维:d_model → d_ff (512 → 2048)
            nn.ReLU(),                      # 非线性激活函数,引入非线性
            nn.Linear(d_ff, d_model)        # 降维:d_ff → d_model (2048 → 512)
        )

    def forward(self, x):
        return self.net(x)


# ============================================================
# 5. TransformerBlock — Transformer 块
# ============================================================
class TransformerBlock(nn.Module):
    """
    一个完整的 Transformer 块,采用 Pre-Norm 残差结构:

        x  →  LayerNorm  →  MultiHeadAttention  →  +  →  x'
        x' →  LayerNorm  →  FeedForward         →  +  →  x"

    Pre-Norm(先归一化再子层)相比 Post-Norm 训练更稳定。
    """

    def __init__(self, d_model, n_heads, d_ff):
        """
        参数:
            d_model: 模型隐藏维度
            n_heads: 注意力头数
            d_ff:    前馈网络中间层维度
        """
        super().__init__()

        # 第一个 Pre-Norm 子层: LayerNorm + MultiHeadAttention
        self.ln1 = nn.LayerNorm(d_model)     # 层归一化,沿最后一维标准化
        self.attn = MultiHeadAttention(d_model, n_heads)

        # 第二个 Pre-Norm 子层: LayerNorm + FeedForward
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)

    def forward(self, x, return_attention=False):
        """
        前向传播。

        参数:
            x:                输入张量 (B, T, D)
            return_attention: 是否传递至注意力层以捕获中间值

        返回:
            普通模式: 输出张量 (B, T, D)
            注意力模式: (输出, Q, K, V, scores, scores_scaled, alpha)
        """
        if return_attention:
            # 需要捕获注意力中间值:
            #   先 LayerNorm,再注意力(同时返回注意力中间值),最后残差连接
            attn_out, Q, K, V, scores, scores_scaled, alpha = \
                self.attn(self.ln1(x), return_attention=True)

            x = x + attn_out                # 残差连接 1: x + Attention(LayerNorm(x))
            x = x + self.ffn(self.ln2(x))  # 残差连接 2: x + FFN(LayerNorm(x))

            return x, Q, K, V, scores, scores_scaled, alpha

        else:
            # 普通前向: 不捕获中间值
            x = x + self.attn(self.ln1(x))  # 残差连接 1
            x = x + self.ffn(self.ln2(x))   # 残差连接 2

            return x


# ============================================================
# 6. GPT — 自回归语言模型
# ============================================================
class GPT(nn.Module):
    """
    微型 GPT(Generative Pre-trained Transformer)模型。

    架构组成:
        Token Embedding → 与 Position Embedding 相加
        → N 层 TransformerBlock
        → 最终 LayerNorm
        → 线性投影头(输出词表维度 logits)

    通过 return_attention=True 可以捕获第一层 TransformerBlock 的注意力中间值,
    用于可视化分析和教学演示。
    """

    def __init__(
        self,
        vocab_size,                         # 词表大小,此处为 4
        d_model=512,                        # 隐藏维度
        n_heads=8,                          # 注意力头数
        n_layers=6,                         # Transformer 块数
        d_ff=2048,                          # FFN 中间维度
        max_len=128                         # 支持的最大序列长度
    ):
        super().__init__()

        # Token Embedding: 将字符索引映射为 d_model 维稠密向量
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Position Embedding: 为位置 0~max_len-1 各分配一个可学习的嵌入向量
        # 使模型感知 token 的相对/绝对位置(Transformer 本身无序列感知能力)
        self.position_embedding = nn.Embedding(max_len, d_model)

        # 用 ModuleList 而非 Sequential,以便按索引遍历并取出特定层输出
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)        # 堆叠 n_layers=6 个 Transformer 块
        ])

        # 最终层归一化:在所有 Transformer 块之后、输出投影之前稳定分布
        self.ln_f = nn.LayerNorm(d_model)
        # 输出投影头:将 d_model 维隐藏状态映射回 vocab_size 维,得到每个 token 的 logits
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, return_attention=False):
        """
        前向传播。

        参数:
            idx:              输入 token 索引,形状 (B, T)
            return_attention: 是否返回第一层 TransformerBlock 的注意力中间值

        返回:
            if return_attention=False: logits, 形状 (B, T, vocab_size)
            if return_attention=True:  (logits, (Q, K, V, scores, scores_scaled, alpha))
        """
        B, T = idx.shape                    # B=批次大小, T=序列长度

        # 生成位置索引 [0, 1, 2, ..., T-1],形状 (1, T),移至与 idx 相同设备
        positions = torch.arange(T).unsqueeze(0).to(idx.device)

        # Token 嵌入 + 位置嵌入(逐元素相加)
        token_emb = self.token_embedding(idx)       # (B, T) → (B, T, d_model)
        pos_emb = self.position_embedding(positions)  # (1, T) → (1, T, d_model) → 广播
        x = token_emb + pos_emb                       # 嵌入融合: (B, T, d_model)

        saved = None                         # 用于保存第一层注意力的中间结果

        for i, block in enumerate(self.blocks):
            # 遍历每一层 TransformerBlock
            if return_attention and i == 0:
                # 仅在第一层 (i==0) 捕获注意力中间值供分析
                x, Q, K, V, scores, scores_scaled, alpha = \
                    block(x, return_attention=True)
                saved = (Q, K, V, scores, scores_scaled, alpha)
            else:
                x = block(x)                # 其余层正常前向

        x = self.ln_f(x)                    # 最终 LayerNorm

        logits = self.head(x)               # (B, T, d_model) → (B, T, vocab_size)

        if return_attention:
            return logits, saved            # 返回预测值 + 第一层注意力数据

        return logits


# ============================================================
# 7. 模型实例化与优化器/损失函数配置
# ============================================================
model = GPT(                                # 实例化 GPT 模型
    vocab_size=vocab_size,                  # 词表大小 = 4
    d_model=d_model,                        # 维度 = 512
    n_heads=n_heads,                        # 8 个注意力头
    n_layers=n_layers,                      # 6 层 Transformer
    d_ff=d_ff                               # FFN 中间维度 = 2048
)

optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Adam 优化器,学习率 0.0001
loss_fn = nn.CrossEntropyLoss()                      # 交叉熵损失:衡量 logits 与目标分布的差距

# ============================================================
# 8. 训练循环
# ============================================================
for epoch in range(1000):                  # 训练 1000 个 epoch

    optimizer.zero_grad()                   # 清空上一轮的梯度缓存

    logits = model(data)                    # 前向传播:输入 [0,1,2,3] → logits (1,4,4)

    # 计算交叉熵损失
    # logits.view(-1, vocab_size): (1*4, 4) = (4, 4)  — 展平为 4 个样本,每个 4 类
    # target.view(-1):            (1*4,)  = (4,)      — 展平为 4 个目标标签
    loss = loss_fn(
        logits.view(-1, vocab_size),
        target.view(-1)
    )

    loss.backward()                         # 反向传播:计算所有参数的梯度
    optimizer.step()                        # 参数更新:沿梯度方向优化

    if epoch % 100 == 0:                   # 每 100 轮输出一次当前损失
        print(f"\nEpoch {epoch}, loss={loss.item():.6f}")

# ============================================================
# 9. 推理与注意力可视化
# ============================================================
with torch.no_grad():                       # 推理阶段禁用梯度计算,节省显存

    # 推理时设置 return_attention=True,捕获第一层注意力的中间值
    logits, saved = model(data, return_attention=True)

    # argmax 取每个位置 logits 最大的索引,即模型预测的下一个字符
    pred = torch.argmax(logits, dim=-1)    # (1, 4, 4) → (1, 4)

    # 解包保存的注意力中间值
    Q, K, V, scores, scores_scaled, alpha = saved

    # ---------- 打印输入/输出 ----------
    print("\n==============================")
    print("输入")
    print("==============================")
    print([idx2char[i.item()] for i in data[0]])   # 索引还原为汉字

    print("\n==============================")
    print("预测")
    print("==============================")
    print([idx2char[i.item()] for i in pred[0]])

    # ---------- 可视化第一层注意力 ----------
    # 以下取 batch=0, head=0 的数据展示注意力计算全过程

    print("\n==============================")
    print("Q")
    print("==============================")
    print(Q[0, 0].detach().numpy().round(3))       # 第 0 个 batch 第 0 个头: (T, d_head)

    print("\n==============================")
    print("K")
    print("==============================")
    print(K[0, 0].detach().numpy().round(3))

    print("\n==============================")
    print("V")
    print("==============================")
    print(V[0, 0].detach().numpy().round(3))

    print("\n==============================")
    print("scores = QK^T")                          # S = Q · Kᵀ,原始相似度
    print("==============================")
    print(scores[0, 0].detach().numpy().round(3))   # (T, T) 矩阵,每行是查询对各键的得分

    print("\n==============================")
    print("scores_scaled")                          # S / √d_k,缩放后
    print("==============================")
    print(scores_scaled[0, 0].detach().numpy().round(3))

    print("\n==============================")
    print("softmax alpha")                          # α = softmax(S_scaled)
    print("==============================")
    print(alpha[0, 0].detach().numpy().round(3))    # 每行和≈1,因果掩码使 α[i][j>i]=0

程序输出结果示例:

python
Epoch 0, loss=2.033882

Epoch 100, loss=0.000004

Epoch 200, loss=0.000001

Epoch 300, loss=0.000001

Epoch 400, loss=0.000001

Epoch 500, loss=0.000001

Epoch 600, loss=0.000001

Epoch 700, loss=0.000001

Epoch 800, loss=0.000001

Epoch 900, loss=0.000001

==============================
输入
==============================
['你', '好', '世', '界']

==============================
预测
==============================
['好', '世', '界', '你']

==============================
Q
==============================
[[-0.35  -0.407 -0.483  0.202 -1.21   0.052 -0.176 -0.561 -0.954  0.093
  -0.656  0.312 -0.605  0.628 -0.073 -0.659 -0.658 -0.503  0.324  0.06
  -0.949  0.006 -0.826 -0.279 -0.249 -0.938  0.166 -1.03  -0.397 -0.074
  -0.224 -0.779  0.28  -0.599 -0.993 -1.21  -1.172  0.522  0.781  0.627
   0.02   0.964  0.134  0.466 -0.597 -0.354  1.265  0.147  0.153 -0.168
  -0.758 -1.047 -0.158 -0.468  0.3   -0.336 -0.627  0.59  -0.605  0.138
   0.853 -0.566  0.453  1.011]
 [ 0.27  -0.235 -1.544 -0.038 -0.234  0.129 -0.311 -0.97  -0.243  0.366
   0.974 -0.945  1.132 -0.288  0.058  0.435 -1.075  0.144 -0.33   0.04
   0.292  0.917  0.534  0.309  0.005  1.361 -0.053  1.036  0.357  0.03
  -1.367  1.261  0.283 -1.16   0.064  0.076  0.333 -0.472  0.25   0.676
   0.645 -0.466 -0.35   0.288 -0.622 -0.433  0.905 -0.78  -0.267  0.335
  -0.605  1.117  1.436 -0.976  0.721 -0.072 -0.166  0.609  0.454 -0.149
  -0.154 -0.668  0.161 -0.913]
 [ 0.581  0.029 -0.121  0.056 -0.817  0.408  0.55   0.592  0.344  0.392
  -0.727  0.332 -0.747  0.111  0.192 -0.743  0.444  0.83   0.271  0.468
  -0.137 -0.249  0.421 -0.437 -0.49   0.057  1.117  0.208 -0.024  0.32
  -0.607 -1.125 -0.129  0.433 -0.024  0.241 -0.815 -0.317 -0.91  -0.111
   1.066  0.623 -0.157 -0.078  0.192  0.671 -0.932  0.325 -0.445  0.267
  -0.302 -0.772 -0.207  0.422 -0.264  0.245 -0.51   0.083 -0.671  0.297
  -0.732 -0.329  0.484  0.396]
 [-0.164 -0.309  0.665  0.228  0.476  0.356  0.508 -0.426  0.738 -1.743
   0.765 -1.036  0.976  0.177  0.305 -0.756 -0.159 -0.67   0.28  -0.417
   0.269 -1.157  0.675  0.042 -0.181  0.131  0.475  0.959 -0.266 -0.379
  -0.019  0.514 -0.391  0.601 -0.263  0.054  0.207  0.211 -0.374  1.04
  -0.138  0.116 -0.39   0.163 -0.706 -0.45  -0.003 -0.221 -0.198 -0.205
   0.011 -0.605  0.984 -0.19   0.639 -0.172  0.391  0.23  -0.836  0.388
  -0.843  0.875  1.419  1.416]]

==============================
K
==============================
[[-0.758  0.836  1.233 -0.518  1.507 -0.171  0.372  1.24  -0.012  0.445
  -0.646  0.797 -1.395 -0.258 -0.864  1.3    0.911  0.249  0.716  0.063
  -0.287  0.377 -0.545 -0.265  0.266 -1.027 -0.283 -0.82  -1.036 -0.487
   0.696  0.463 -0.023 -0.2    0.003 -0.34   0.028 -0.417 -0.254  0.457
  -0.762  0.02   0.358 -0.576  0.728 -0.    -0.277  0.591  0.231 -0.004
   0.678  0.182 -0.355  0.832 -0.75  -0.295  0.176 -0.704  0.651 -0.327
  -0.286 -0.435 -0.667 -0.421]
 [-0.367  0.32  -0.186 -0.378 -0.099  0.194  0.173 -0.819 -0.245  0.213
   0.657  0.48  -0.556 -0.199  0.147  0.707 -0.712 -0.05  -0.426 -0.776
  -0.975 -0.009  1.464 -0.459  0.894 -0.531  0.331  0.945 -0.772 -0.308
  -1.053  0.235 -0.061 -0.689  0.808  1.045 -1.145 -0.563  0.35   0.508
   0.531 -0.083 -1.198 -0.17   0.603 -0.352  0.115 -0.765 -0.331 -0.01
  -0.155 -0.298  0.175 -0.224  0.238 -0.13  -0.169  0.69  -1.039 -0.211
  -0.962 -0.051  0.119  0.095]
 [-1.209 -0.033  0.01   0.999 -0.245 -0.039  0.07  -0.573 -0.168 -0.506
   0.262 -0.775  0.129  0.323 -0.973 -0.38  -0.042 -0.542 -0.859  0.274
   0.533 -0.437 -0.169 -0.279 -0.091 -0.663  0.257 -0.414 -0.418  0.152
   0.11   0.647 -0.328 -0.307  0.3   -0.767 -0.319  0.01  -0.304  0.964
   0.459 -0.167  0.375 -0.713 -0.56  -0.774  0.093  0.444  0.294 -0.894
  -0.624 -0.389  0.906 -0.127  0.104 -0.849  0.852 -0.373  0.257 -0.362
  -0.849 -0.226 -0.101  0.524]
 [-0.113  0.098 -1.399 -0.02  -0.138 -0.146 -0.012  0.238 -0.635  0.471
  -0.496  0.014  0.286 -0.013 -0.101  0.229 -1.211 -0.011  0.656  0.288
   0.015  1.026 -0.975 -0.086  0.849 -0.453 -0.008 -0.421  0.008  0.809
   0.159  0.371 -0.331 -0.861  0.488 -0.312  0.251 -0.362 -0.15   0.226
  -0.11  -1.087  0.258 -0.936  1.026 -0.263 -0.249 -0.646  0.053  0.126
   0.028  0.064 -0.741  0.314  0.544 -0.04  -0.123  0.305 -0.596 -0.908
  -0.176 -0.341 -0.923  0.163]]

==============================
V
==============================
[[-0.67  -0.117 -1.085 -0.536  1.016  0.047  0.154 -0.087  0.926 -0.14
   0.781 -0.443  0.172 -0.433 -0.567  0.666 -0.321  0.173 -0.428 -0.357
  -0.559 -0.221  0.138 -0.469  0.326 -0.691  0.647  0.108  0.062 -0.926
  -0.424 -0.87   0.053 -1.048  0.652 -0.798 -0.27   0.211 -0.281 -0.317
   0.637 -0.954 -0.274 -0.991  0.6    0.093  0.044 -1.527 -0.575 -0.087
   0.494 -0.896  0.691 -1.437 -0.675 -0.253 -0.193 -0.758 -0.29  -1.044
  -0.631 -0.194  0.17   0.962]
 [-0.193  0.785  0.281  1.047 -0.778 -0.042  0.493 -0.989 -0.119  0.309
  -0.724 -0.085 -0.068 -0.205  0.213  1.236  0.288  1.454 -0.181 -0.452
   0.334  0.37  -0.163  0.507  0.01   0.317  0.442  0.108  0.258 -0.605
   0.597  0.36  -0.311  0.552 -0.37   0.423  1.258 -0.014  0.607  0.109
   0.078 -0.015 -0.743  0.727  0.526  0.137  0.076  0.581 -0.797 -0.633
  -1.153  0.331 -0.329  0.178 -0.325 -1.605  0.141  0.207  1.123 -0.171
  -0.499  0.082 -0.06  -0.245]
 [-0.05   0.524 -0.457 -0.141 -0.514 -0.749  0.284 -0.071  0.072 -0.255
  -0.066 -0.178  0.629  0.028  0.027 -0.088 -0.594  0.894 -0.834  0.163
  -0.85  -0.203  0.152 -0.006 -0.053  0.338 -0.117  0.177 -0.059 -0.246
   0.522 -0.267  0.842  0.308 -0.153  0.036 -0.091  0.693  0.592  0.425
   0.679  0.323  0.123 -0.902  0.734  0.215 -0.396 -0.322  0.013  0.511
   0.352 -0.016 -0.481  0.078  0.021 -0.197  0.624 -0.753 -0.083  0.389
   0.226  0.331  0.727 -0.86 ]
 [-0.219  0.152 -0.382  0.197 -0.753 -0.36  -0.544 -0.78   0.7   -0.306
  -0.269  0.145 -0.222  0.287 -0.851 -0.078 -0.227  0.145 -0.272  0.687
   0.432  0.179 -1.144  0.775 -1.057 -0.776  0.556  0.788  0.059  1.002
  -1.004 -0.037 -0.441 -0.011  0.488 -0.348 -0.531  0.883  0.063 -0.765
  -0.091  0.13  -0.484 -1.145  0.331  0.787  0.875 -0.226  0.819 -0.523
  -1.217 -0.296 -0.068 -0.378  0.773  0.452 -0.011 -0.157  0.278 -0.372
  -0.036  0.179 -0.251  0.309]]

==============================
scores = QK^T
==============================
[[ -4.192   0.67    3.931   1.269]
 [-13.375   6.972   1.766   3.113]
 [ -2.067   2.363  -2.743  -2.388]
 [ -9.26    3.297   6.765  -7.769]]

==============================
scores_scaled
==============================
[[-0.524   -inf   -inf   -inf]
 [-1.672  0.871   -inf   -inf]
 [-0.258  0.295 -0.343   -inf]
 [-1.158  0.412  0.846 -0.971]]

==============================
softmax alpha
==============================
[[1.    0.    0.    0.   ]
 [0.073 0.927 0.    0.   ]
 [0.273 0.475 0.251 0.   ]
 [0.069 0.333 0.514 0.084]]

上述代码实现了微型 GPT(自回归语言模型):给定“你”预测“好”,给定“你好”预测“世”,以此类推。经过 1000 轮训练,交叉熵损失从 2.033882 降至 0.000001,模型精确习得了序列 ['你', '好', '世', '界']['好', '世', '界', '你'] 的移位关系。

模型严格符合标准 缩放点积注意力(Scaled Dot-Product Attention)的数学定义:

在代码中的对应关系如下。

Q、K、V 生成

源代码中 Q = self.W_Q(x)K = self.W_K(x)V = self.W_V(x) —— 三者从同一输入 x 经线性投影得到,即自注意力。由于:

投影结果被拆分为 8 头,每头维度

(代码 Q.view(B, T, self.n_heads, self.d_head).transpose(1, 2))。

输出仅展示了第一个注意力头(head=0)的 Q、K、V、scores 和 alpha,便于读者追溯完整的计算链路。

相似度与缩放

对应源代码 scores = Q @ K.transpose(-2, -1)scores_scaled = scores / math.sqrt(self.d_head)

因果掩码(Causal Mask)

python
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores_scaled = scores_scaled.masked_fill(mask, float("-inf"))

上三角置 -∞,迫使位置 i 只能关注位置 0,…,i:这是自回归模型的根本约束——预测下一个词元时不可预先获取未来信息。scaled 矩阵中的 -inf 正是这一约束的直观体现。

Softmax 归一化

对应 alpha = F.softmax(scores_scaled, dim=-1)。由于

mask 位置权重为零;其余位置非负且和为 1,构成合法概率分布:

加权求和输出

对应 out = alpha @ V,随后合并多头、经 W_O 投影输出。

前馈网络(FFN)

python
nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512)

与原始论文一致。

残差连接与层归一化

python
x = x + self.attn(self.ln1(x))   # Pre-LayerNorm
x = x + self.ffn(self.ln2(x))

等同于

的 Pre-LN 变体。

规模对比

本例仅 4 个词元

与原始 Transformer 论文的 base 配置完全一致,证明即使是极小规模的序列,只要数学结构正确,Transformer 同样能够完美收敛。

22.2.3 课后习题

  1. Transformer 架构中的自注意力机制解决了传统循环神经网络在处理长序列时的什么局限?请用你自己的话解释注意力机制的核心思想。