Transformer架构详解:从注意力机制到GPT模型实现

Transformer架构详解:从注意力机制到GPT模型实现
30款热门AI模型一站整合DeepSeek/GLM/Qwen 随心用限时 5 折。 点击领海量免费额度1. 背景与核心概念Transformer 架构自 2017 年由 Google 团队在论文《Attention Is All You Need》中提出以来已经彻底改变了自然语言处理乃至整个深度学习领域。它摒弃了传统的循环神经网络和卷积神经网络在处理序列数据时的固有缺陷通过一种全新的、完全基于注意力机制的架构实现了前所未有的并行化能力和对长距离依赖关系的强大建模能力。如今从 ChatGPT、GPT-4 这样的对话模型到 BERT、T5 这样的理解模型再到 Stable Diffusion、Sora 这样的多模态生成模型其核心都离不开 Transformer。对于许多开发者而言Transformer 的原理常常被其复杂的数学公式和架构图所掩盖感觉“高深莫测”。实际上其核心思想非常直观。本文将彻底拆解 Transformer 的每一个组件从最基础的注意力机制开始逐步构建起完整的编码器-解码器架构并结合代码示例和实际应用场景让你不仅能理解其原理更能掌握其实现细节。无论你是希望深入理解大模型背后的技术还是计划在自己的项目中应用 Transformer 架构这篇文章都将为你提供一个清晰、系统、可实践的指南。2. 从序列建模的困境到注意力机制在 Transformer 出现之前序列建模如机器翻译、文本生成的主流是循环神经网络及其变体 LSTM 和 GRU。这些模型按顺序处理输入序列将之前步骤的信息保存在一个“隐藏状态”中。然而这种顺序处理方式存在两个根本性瓶颈并行化困难由于每一步的计算都依赖于上一步的隐藏状态模型无法充分利用现代 GPU 的并行计算能力训练速度慢。长距离依赖遗忘对于较长的序列早期输入的信息在传递过程中会逐渐衰减或丢失模型难以捕捉序列开头和结尾之间的关联。注意力机制的引入是解决第二个问题的关键一步。它允许模型在生成输出序列的每一个词时直接“查看”输入序列中的所有词并动态地为每个输入词分配一个“关注度”权重。这就像人在翻译句子时会反复回看原文的不同部分一样。最初的注意力机制被用在基于 RNN 的编码器-解码器模型中但它仍然依赖于 RNN 来生成编码表示因此第一个瓶颈并行化困难依然存在。Transformer 的革命性在于它完全抛弃了循环结构仅使用注意力机制来构建整个模型。这使得模型可以一次性处理整个输入序列所有词之间的关联计算都可以并行进行极大地提升了训练效率。3. Transformer 核心组件详解一个标准的 Transformer 模型主要由编码器和解码器堆叠而成。我们先来深入理解构成它们的基础模块。3.1 输入表示词嵌入与位置编码Transformer 的输入是一系列词元。首先每个词元通过一个词嵌入层被映射为一个高维向量。这个向量捕获了词义的语义信息。然而自注意力机制本身是对顺序不敏感的。对于句子 “狗咬人” 和 “人咬狗”如果不提供位置信息模型会认为它们是相同的。因此我们必须注入位置信息。位置编码为序列中每个位置的词元嵌入向量添加一个独特的向量。原论文使用了一种基于正弦和余弦函数的固定编码方式对于位置pos和维度i其计算公式为PE(pos, 2i) sin(pos / 10000^(2i/d_model)) PE(pos, 2i1) cos(pos / 10000^(2i/d_model))其中d_model是嵌入向量的维度。这种编码方式的优点是它能产生一种有界、平滑的位置表示并且模型可以轻松学会关注相对位置因为PE(posk)可以表示为PE(pos)的线性函数。import numpy as np import torch import torch.nn as nn class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len5000): super(PositionalEncoding, self).__init__() # 创建一个足够长的位置编码矩阵 pe torch.zeros(max_len, d_model) position torch.arange(0, max_len, dtypetorch.float).unsqueeze(1) div_term torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] torch.sin(position * div_term) # 偶数维度用sin pe[:, 1::2] torch.cos(position * div_term) # 奇数维度用cos pe pe.unsqueeze(0) # 形状: (1, max_len, d_model) self.register_buffer(pe, pe) # 将其注册为缓冲区不参与梯度更新 def forward(self, x): # x 形状: (batch_size, seq_len, d_model) x x self.pe[:, :x.size(1)] return x # 示例 d_model 512 seq_len 10 batch_size 2 embedding torch.randn(batch_size, seq_len, d_model) pos_encoder PositionalEncoding(d_model) output_with_pos pos_encoder(embedding) print(f输入嵌入形状: {embedding.shape}) print(f加入位置编码后形状: {output_with_pos.shape})3.2 缩放点积注意力这是 Transformer 的灵魂。其核心思想是对于序列中的每一个元素查询 Query计算它与序列中所有元素键 Key的相关性然后用这个相关性权重对对应的值Value进行加权求和从而得到一个融合了全局上下文信息的表示。计算步骤线性变换将输入序列通过三个不同的权重矩阵W_Q,W_K,W_V投影得到查询矩阵Q、键矩阵K和值矩阵V。计算注意力分数计算Q和K的点积度量查询和键之间的相关性。分数越高表示相关性越强。缩放将点积结果除以sqrt(d_k)其中d_k是键向量的维度。这一步是为了防止点积结果过大导致经过 softmax 后梯度消失。归一化对缩放后的分数应用 softmax 函数将其转化为概率分布权重和为1。加权求和用 softmax 得到的权重对V进行加权求和得到最终的输出。公式表示Attention(Q, K, V) softmax(Q * K^T / sqrt(d_k)) * Vimport torch import torch.nn.functional as F def scaled_dot_product_attention(query, key, value, maskNone): query: 形状 (..., seq_len_q, d_k) key: 形状 (..., seq_len_k, d_k) value: 形状 (..., seq_len_v, d_v) 通常 seq_len_k seq_len_v mask: 形状 (..., seq_len_q, seq_len_k) d_k query.size(-1) # 计算点积并缩放 scores torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtypetorch.float32)) if mask is not None: # 将 mask 中为 True 的位置替换为一个非常大的负数这样 softmax 后权重接近 0 scores scores.masked_fill(mask 0, -1e9) # 计算注意力权重 attention_weights F.softmax(scores, dim-1) # 加权求和 output torch.matmul(attention_weights, value) return output, attention_weights # 示例自注意力 (seq_len_q seq_len_k seq_len_v) batch_size 2 seq_len 5 d_model 64 d_k d_v d_model query torch.randn(batch_size, seq_len, d_k) key torch.randn(batch_size, seq_len, d_k) value torch.randn(batch_size, seq_len, d_v) output, attn_weights scaled_dot_product_attention(query, key, value) print(f注意力输出形状: {output.shape}) # (2, 5, 64) print(f注意力权重形状: {attn_weights.shape}) # (2, 5, 5) # 对于第0个样本第0个词元对其他所有词元的注意力权重 print(f示例注意力权重第一个样本第一个词元: {attn_weights[0, 0]})3.3 多头注意力单一的注意力机制可能只关注到一种类型的语义关系。为了让模型能够同时关注来自不同表示子空间的信息Transformer 引入了多头注意力。其做法是将Q,K,V通过h个不同的线性投影矩阵分别投影到h个更小的空间d_k,d_v维度。在每个投影后的子空间上独立执行缩放点积注意力得到h个输出。将这h个输出拼接起来再通过一个最终的线性投影层W_O融合信息。公式表示MultiHead(Q, K, V) Concat(head_1, ..., head_h) * W_O where head_i Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() assert d_model % num_heads 0, d_model 必须能被 num_heads 整除 self.d_model d_model self.num_heads num_heads self.d_k d_model // num_heads # 定义投影矩阵 self.W_q nn.Linear(d_model, d_model) # 实际实现中通常先投影到 d_model self.W_k nn.Linear(d_model, d_model) self.W_v nn.Linear(d_model, d_model) self.W_o nn.Linear(d_model, d_model) def split_heads(self, x): 将输入张量从 (batch_size, seq_len, d_model) 重塑为 (batch_size, num_heads, seq_len, d_k) batch_size, seq_len, _ x.size() return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) def forward(self, query, key, value, maskNone): batch_size query.size(0) # 1. 线性投影 Q self.W_q(query) # (batch_size, seq_len_q, d_model) K self.W_k(key) # (batch_size, seq_len_k, d_model) V self.W_v(value) # (batch_size, seq_len_v, d_model) # 2. 分割成多个头 Q self.split_heads(Q) # (batch_size, num_heads, seq_len_q, d_k) K self.split_heads(K) # (batch_size, num_heads, seq_len_k, d_k) V self.split_heads(V) # (batch_size, num_heads, seq_len_v, d_k) # 3. 为每个头计算缩放点积注意力 # 我们需要调整 mask 的维度以匹配多头 if mask is not None: mask mask.unsqueeze(1) # (batch_size, 1, seq_len_q, seq_len_k) - 广播到每个头 # 计算注意力这里调用之前定义的函数但需要处理多头维度 # 简便起见我们重塑张量将 num_heads 视为 batch 维度的一部分 Q_reshaped Q.transpose(1, 2).contiguous().view(batch_size * self.num_heads, -1, self.d_k) K_reshaped K.transpose(1, 2).contiguous().view(batch_size * self.num_heads, -1, self.d_k) V_reshaped V.transpose(1, 2).contiguous().view(batch_size * self.num_heads, -1, self.d_k) if mask is not None: mask_reshaped mask.repeat(1, self.num_heads, 1, 1).view(batch_size * self.num_heads, -1, mask.size(-1)) attn_output, _ scaled_dot_product_attention(Q_reshaped, K_reshaped, V_reshaped, mask_reshaped if mask is not None else None) # 4. 合并多头输出 attn_output attn_output.view(batch_size, self.num_heads, -1, self.d_k).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) # 5. 最终线性投影 output self.W_o(attn_output) return output # 示例 d_model 512 num_heads 8 mha MultiHeadAttention(d_model, num_heads) seq_len 10 batch_size 4 x torch.randn(batch_size, seq_len, d_model) # 假设是自注意力QKVx output mha(x, x, x) print(f多头注意力输入形状: {x.shape}) print(f多头注意力输出形状: {output.shape}) # 应保持 (4, 10, 512)3.4 前馈网络每个编码器和解码器层中的注意力子层后面都跟着一个前馈网络。这是一个简单的两层全连接神经网络通常中间层的维度更大例如d_ff 4 * d_model并带有 ReLU 激活函数。它的作用是对每个位置的表示进行独立、相同的非线性变换增加模型的表达能力。FFN(x) max(0, x * W1 b1) * W2 b2class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout0.1): super(PositionwiseFeedForward, self).__init__() self.linear1 nn.Linear(d_model, d_ff) self.linear2 nn.Linear(d_ff, d_model) self.dropout nn.Dropout(dropout) self.activation nn.ReLU() def forward(self, x): return self.linear2(self.dropout(self.activation(self.linear1(x)))) # 示例 d_model 512 d_ff 2048 ffn PositionwiseFeedForward(d_model, d_ff) x torch.randn(4, 10, d_model) output ffn(x) print(f前馈网络输出形状: {output.shape}) # (4, 10, 512)3.5 残差连接与层归一化为了缓解深层网络中的梯度消失问题并稳定训练Transformer 在每个子层自注意力、前馈网络周围都使用了残差连接并在其后进行层归一化。残差连接将子层的输入直接加到其输出上即Output LayerNorm(x Sublayer(x))。这确保了信息可以更直接地向前传播。层归一化对单个样本的所有特征维度进行归一化使其均值为0方差为1。这有助于稳定训练过程加速收敛。现代实现更常用Pre-LN结构即先做层归一化再进入子层Output x Sublayer(LayerNorm(x))这被证明训练更稳定。class SublayerConnection(nn.Module): 一个残差连接后接层归一化。注意为了简化我们使用 Post-LN。 def __init__(self, size, dropout): super(SublayerConnection, self).__init__() self.norm nn.LayerNorm(size) self.dropout nn.Dropout(dropout) def forward(self, x, sublayer): 应用残差连接到任何与 x 相同形状的子层。 # Post-LN: LayerNorm(x Sublayer(x)) return x self.dropout(sublayer(self.norm(x))) # 如果是 Pre-LN则应为return x self.dropout(sublayer(self.norm(x))) # 注意Pre-LN 中 norm 在 sublayer 内部调用这里仅为示意。4. 编码器与解码器架构4.1 编码器层一个编码器层由两个主要子层构成多头自注意力层输入序列自己对自己做注意力让每个词元都能关注到序列中所有其他词元从而获得包含全局上下文的表示。前馈网络层对每个位置的表示进行独立变换。每个子层周围都有残差连接和层归一化。class EncoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout0.1): super(EncoderLayer, self).__init__() self.self_attn MultiHeadAttention(d_model, num_heads) self.feed_forward PositionwiseFeedForward(d_model, d_ff, dropout) self.sublayer nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(2)]) self.size d_model def forward(self, x, mask): x: 输入张量形状 (batch_size, seq_len, d_model) mask: 用于自注意力的 mask形状 (batch_size, 1, seq_len, seq_len) 或 (batch_size, seq_len, seq_len) # 第一个子层多头自注意力 (带残差和归一化) x self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) # 第二个子层前馈网络 (带残差和归一化) x self.sublayer[1](x, self.feed_forward) return x4.2 解码器层解码器层比编码器层多一个子层共三个掩码多头自注意力层防止解码器在预测当前位置时“偷看”未来的信息。这是通过一个因果掩码实现的该掩码将注意力权重矩阵右上三角部分未来位置设置为负无穷大。编码器-解码器注意力层交叉注意力让解码器能够关注编码器的最终输出。其中查询Q来自解码器的上一子层输出而键K和值V来自编码器的输出。前馈网络层与编码器相同。class DecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout0.1): super(DecoderLayer, self).__init__() self.self_attn MultiHeadAttention(d_model, num_heads) self.cross_attn MultiHeadAttention(d_model, num_heads) self.feed_forward PositionwiseFeedForward(d_model, d_ff, dropout) self.sublayer nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(3)]) self.size d_model def forward(self, x, encoder_output, src_mask, tgt_mask): x: 解码器输入 (或上一层的输出)形状 (batch_size, tgt_seq_len, d_model) encoder_output: 编码器输出形状 (batch_size, src_seq_len, d_model) src_mask: 源序列 mask (用于编码器-解码器注意力可选) tgt_mask: 目标序列 mask (用于解码器自注意力因果掩码) # 第一子层掩码自注意力 x self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) # 第二子层编码器-解码器注意力 # Q 来自解码器K, V 来自编码器 x self.sublayer[1](x, lambda x: self.cross_attn(x, encoder_output, encoder_output, src_mask)) # 第三子层前馈网络 x self.sublayer[2](x, self.feed_forward) return x4.3 构建完整的 Transformer现在我们可以将编码器层和解码器层堆叠起来并加上嵌入层和最后的线性输出层构建一个完整的 Transformer 模型。class Transformer(nn.Module): def __init__(self, src_vocab_size, tgt_vocab_size, d_model512, num_heads8, num_encoder_layers6, num_decoder_layers6, d_ff2048, max_seq_len5000, dropout0.1): super(Transformer, self).__init__() self.encoder_embedding nn.Embedding(src_vocab_size, d_model) self.decoder_embedding nn.Embedding(tgt_vocab_size, d_model) self.positional_encoding PositionalEncoding(d_model, max_seq_len) self.encoder_layers nn.ModuleList([ EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_encoder_layers) ]) self.decoder_layers nn.ModuleList([ DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_decoder_layers) ]) self.final_linear nn.Linear(d_model, tgt_vocab_size) self.dropout nn.Dropout(dropout) self.d_model d_model def generate_square_subsequent_mask(self, sz): 生成因果掩码 (下三角为 True上三角为 False) mask (torch.triu(torch.ones(sz, sz)) 1).transpose(0, 1) mask mask.float().masked_fill(mask 0, float(-inf)).masked_fill(mask 1, float(0.0)) return mask def forward(self, src, tgt, src_maskNone, tgt_maskNone): src: 源语言序列索引形状 (batch_size, src_len) tgt: 目标语言序列索引形状 (batch_size, tgt_len) src_mask: 源序列 padding mask (可选) tgt_mask: 目标序列因果掩码 padding mask # 1. 编码器 src_emb self.dropout(self.positional_encoding(self.encoder_embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtypetorch.float32)))) memory src_emb for layer in self.encoder_layers: memory layer(memory, src_mask) # 2. 解码器 if tgt_mask is None: tgt_mask self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device) tgt_emb self.dropout(self.positional_encoding(self.decoder_embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtypetorch.float32)))) output tgt_emb for layer in self.decoder_layers: output layer(output, memory, src_mask, tgt_mask) # 3. 输出投影 logits self.final_linear(output) return logits # 示例定义一个微型 Transformer src_vocab_size 10000 tgt_vocab_size 10000 model Transformer(src_vocab_size, tgt_vocab_size, d_model128, num_heads4, num_encoder_layers2, num_decoder_layers2, d_ff512) batch_size 4 src_len 10 tgt_len 12 src torch.randint(0, src_vocab_size, (batch_size, src_len)) tgt torch.randint(0, tgt_vocab_size, (batch_size, tgt_len)) logits model(src, tgt) print(f模型输出 logits 形状: {logits.shape}) # (batch_size, tgt_len, tgt_vocab_size) # 这代表了在目标序列每个位置对目标词汇表中所有词的概率预测5. 训练与推理流程5.1 训练任务掩码语言建模与自回归语言建模Transformer 的训练方式决定了其最终用途编码器-解码器架构如原始 Transformer、T5通常用于序列到序列任务如翻译、摘要。训练时编码器接收源序列解码器以自回归方式使用因果掩码预测目标序列。仅编码器架构如 BERT用于理解任务。采用掩码语言建模随机遮盖输入序列中的一些词元让模型根据上下文预测被遮盖的词。仅解码器架构如 GPT 系列用于生成任务。采用自回归语言建模给定前文预测下一个词。训练时整个序列作为输入但使用因果掩码确保预测位置i时只能看到位置 i的信息。5.2 推理流程以自回归生成为例仅解码器模型如 GPT的推理是一个循环过程给定一个起始标记如bos输入模型。模型输出下一个词的概率分布。根据某种策略如贪婪搜索、束搜索、采样从分布中选择一个词。将选中的词追加到输入序列末尾作为新的输入。重复步骤 2-4直到生成结束标记如eos或达到最大长度。def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol, device): 贪婪解码示例 model.eval() src src.to(device) src_mask src_mask.to(device) # 编码器前向传播 memory model.encode(src, src_mask) # 初始化解码器输入为起始符号 ys torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device) for i in range(max_len-1): # 为当前生成的序列生成因果掩码 tgt_mask model.generate_square_subsequent_mask(ys.size(1)).to(device) # 解码器前向传播 out model.decode(ys, memory, src_mask, tgt_mask) # 获取最后一个位置的 logits 并预测下一个词 prob model.generator(out[:, -1]) _, next_word torch.max(prob, dim1) next_word next_word.item() ys torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim1) if next_word end_symbol: break return ys # 注意此处的 model.encode, model.decode, model.generator 需要在 Transformer 类中实现相应方法。 # 实际中我们通常直接调用 model(src, tgt) 并手动管理推理循环。6. 关键变体与优化原始的 Transformer 架构是基石后续研究提出了许多重要的变体和优化。6.1 位置编码的演进相对位置编码原始正弦编码是绝对位置编码。相对位置编码如 Transformer-XL、T5 使用的让模型更关注词元之间的相对距离而非绝对位置在处理长文本时泛化能力更强。旋转位置编码RoPE 将位置信息通过旋转矩阵融入查询和键向量中在保持相对位置信息的同时被证明对长上下文扩展更友好被 LLaMA、GPT-NeoX 等模型广泛采用。6.2 注意力机制的优化稀疏注意力计算所有词元对之间的注意力复杂度是O(n^2)对于长序列开销巨大。稀疏注意力如 Longformer、BigBird只计算每个词元与局部窗口内或少数全局词元之间的注意力将复杂度降低到O(n)或O(n log n)。线性注意力通过核函数近似将 softmax 注意力转化为线性复杂度如 Linformer、Performer。FlashAttention一种 IO 感知的精确注意力算法通过分块计算和重计算显著减少 GPU 高带宽内存与片上 SRAM 之间的数据移动极大提升了长序列注意力计算的速度和内存效率已成为训练大模型的事实标准。6.3 模型架构变体仅编码器如 BERT适用于文本分类、命名实体识别等理解任务。仅解码器如 GPT 系列适用于文本生成、代码生成等任务。编码器-解码器如 T5、BART适用于翻译、摘要等 seq2seq 任务。前缀语言模型一种介于仅解码器和编码器-解码器之间的架构将输入作为前缀后续部分自回归生成。6.4 推理优化技术KV 缓存在自回归生成时键K和值V对于已经生成的 token 是固定不变的。KV 缓存将这些中间结果存储起来避免在生成每个新 token 时重复计算大幅提升推理速度。多查询注意力 / 分组查询注意力让多个注意力头共享同一套K和V的投影权重减少了推理时 KV 缓存的大小从而支持更长的上下文或更大的批次对推理速度有显著提升。推测解码使用一个更小、更快的“草稿模型”先生成多个候选 token然后用原始大模型一次性并行验证这些候选。如果验证通过则一次性接受多个 token从而减少大模型的调用次数加速生成。7. 实战使用 PyTorch 构建一个简化的 GPT 模型让我们动手实现一个极简版的 GPT仅解码器来巩固理解。这个模型将包含词嵌入、位置编码、多个解码器层带掩码自注意力和一个输出层。import torch import torch.nn as nn import torch.nn.functional as F import math class GPTDecoderLayer(nn.Module): 简化的 GPT 解码器层只有掩码多头自注意力和前馈网络 def __init__(self, d_model, num_heads, d_ff, dropout0.1): super().__init__() self.ln1 nn.LayerNorm(d_model) self.self_attn MultiHeadAttention(d_model, num_heads) # 复用之前定义的多头注意力 self.dropout1 nn.Dropout(dropout) self.ln2 nn.LayerNorm(d_model) self.ffn PositionwiseFeedForward(d_model, d_ff, dropout) self.dropout2 nn.Dropout(dropout) def forward(self, x, mask): # Pre-LN 结构 attn_output self.self_attn(self.ln1(x), self.ln1(x), self.ln1(x), mask) x x self.dropout1(attn_output) ffn_output self.ffn(self.ln2(x)) x x self.dropout2(ffn_output) return x class SimpleGPT(nn.Module): 一个极简的 GPT 模型 def __init__(self, vocab_size, d_model256, num_heads8, num_layers6, d_ff1024, max_seq_len512, dropout0.1): super().__init__() self.token_embedding nn.Embedding(vocab_size, d_model) self.position_embedding nn.Embedding(max_seq_len, d_model) # 使用可学习的位置嵌入 self.dropout nn.Dropout(dropout) self.layers nn.ModuleList([ GPTDecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers) ]) self.ln_f nn.LayerNorm(d_model) # 最后的层归一化 self.lm_head nn.Linear(d_model, vocab_size, biasFalse) # 语言模型头 # 权重绑定语言模型头的权重与词嵌入层共享常见做法可减少参数 self.lm_head.weight self.token_embedding.weight self.max_seq_len max_seq_len self.d_model d_model self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean0.0, std0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean0.0, std0.02) def forward(self, idx, targetsNone): idx: 输入 token 索引形状 (batch_size, seq_len) targets: 目标 token 索引用于计算损失形状同 idx device idx.device b, t idx.size() assert t self.max_seq_len, f序列长度 {t} 超过了最大长度 {self.max_seq_len} # 1. 词嵌入 位置嵌入 tok_emb self.token_embedding(idx) # (b, t, d_model) pos torch.arange(0, t, dtypetorch.long, devicedevice).unsqueeze(0) # (1, t) pos_emb self.position_embedding(pos) # (1, t, d_model) x self.dropout(tok_emb pos_emb) # 2. 生成因果掩码 causal_mask torch.tril(torch.ones(t, t, devicedevice)).view(1, 1, t, t) # (1, 1, t, t) # 3. 通过所有解码器层 for layer in self.layers: x layer(x, causal_mask) # 4. 最终层归一化和投影到词汇表 x self.ln_f(x) logits self.lm_head(x) # (b, t, vocab_size) loss None if targets is not None: # 计算交叉熵损失忽略 padding 等操作此处省略 loss F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index-1) return logits, loss def generate(self, idx, max_new_tokens, temperature1.0, top_kNone): 自回归生成文本 self.eval() for _ in range(max_new_tokens): # 如果上下文太长裁剪到最大长度一种简单的处理方式 idx_cond idx if idx.size(1) self.max_seq_len else idx[:, -self.max_seq_len:] # 前向传播获取最后一个时间步的 logits logits, _ self(idx_cond) logits logits[:, -1, :] / temperature # (batch_size, vocab_size) # 可选top-k 采样 if top_k is not None: v, _ torch.topk(logits, min(top_k, logits.size(-1))) logits[logits v[:, [-1]]] -float(Inf) # 从分布中采样 probs F.softmax(logits, dim-1) idx_next torch.multinomial(probs, num_samples1) # (batch_size, 1) # 将采样结果拼接到序列中 idx torch.cat((idx, idx_next), dim1) return idx # 示例初始化模型并尝试生成 vocab_size 1000 # 假设词汇表大小 model SimpleGPT(vocab_sizevocab_size, d_model128, num_heads4, num_layers4, d_ff512, max_seq_len256) print(f模型参数量: {sum(p.numel() for p in model.parameters())/1e6:.2f} M) # 模拟一个 batch 的数据 batch_size 2 seq_len 10 input_ids torch.randint(0, vocab_size, (batch_size, seq_len)) logits, loss model(input_ids) print(f输入形状: {input_ids.shape}) print(f输出 logits 形状: {logits.shape}) # (2, 10, 1000) # 尝试生成由于是随机初始化的模型输出无意义 start_tokens torch.randint(0, vocab_size, (batch_size, 1)) generated model.generate(start_tokens, max_new_tokens5, temperature1.0, top_k50) print(f生成结果形状: {generated.shape}) print(f示例生成序列: {generated[0]})8. 常见问题与排查思路在理解和实现 Transformer 时你可能会遇到以下问题问题现象可能原因解决思路训练不稳定损失 NaN学习率过高梯度爆炸层归一化或残差连接实现有误激活函数问题。使用学习率预热使用梯度裁剪检查 Pre-LN/Post-LN 实现是否正确尝试 GELU/SiLU 等更平滑的激活函数。模型无法收敛或收敛慢学习率不合适模型初始化不当数据预处理有问题任务过于复杂。进行学习率搜索使用 Xavier/Glorot 或 Kaiming/He 初始化检查数据标签和分词是否正确尝试更简单的任务或增加数据。推理时生成重复或无意义内容采样策略问题温度过低导致贪婪温度过高导致随机模型训练不足存在重复性惩罚未设置。调整temperature和top_p/top_k参数增加训练步数在生成时加入重复惩罚repetition_penalty。处理长序列时内存溢出 (OOM)注意力矩阵(seq_len, seq_len)过大消耗O(n^2)内存。使用稀疏注意力如 Longformer使用 FlashAttention如果框架支持增加梯度检查点减少批次大小或序列长度。位置编码外推性差使用绝对正弦位置编码的模型在推理时遇到比训练时更长的序列性能下降。使用相对位置编码如 RoPE、ALiBi在训练时使用更长的上下文进行微调。KV 缓存导致推理错误缓存未正确更新或重置在生成不同序列时缓存混用。确保在开始生成新序列时清空 KV 缓存检查缓存张量的形状与当前生成步数是否匹配。9. 最佳实践与工程建议从预训练模型开始除非有特定研究目的否则不要从头开始训练大型 Transformer。利用 Hugging Facetransformers库加载 BERT、GPT-2、T5 等预训练模型进行微调这是最高效的方式。注意计算资源Transformer 模型尤其是大模型对 GPU 显存要求很高。训练时注意使用混合精度训练、梯度累积、模型并行、数据并行等技术来优化资源使用。使用现代库和优化器优先使用 PyTorch 或 TensorFlow 等成熟框架并搭配优化器如 AdamW并配合学习率调度器如带热身的余弦衰减。监控训练过程密切关注训练损失和验证损失曲线使用 WandB 或 TensorBoard 等工具进行可视化。早停法可以防止过拟合。理解你的数据Tokenizer 的选择如 BPE、WordPiece、SentencePiece对模型性能影响巨大。确保分词方式与你的任务和语言匹配。生产环境部署优化推理时利用模型量化、动态批处理、持续批处理、FlashAttention、PagedAttentionvLLM等技术来降低延迟、提高吞吐量。安全与伦理Transformer 模型可能生成有偏见、有害或不准确的内容。在部署前必须进行全面的评估、红队测试并考虑加入内容过滤和安全层。Transformer 的原理虽然源于一篇学术论文但其影响早已遍及工业界的每一个角落。从理解它的核心——自注意力机制开始到掌握其完整的编码器-解码器架构再到熟悉各种高效的变体和优化技术这条学习路径将为你打开通往现代人工智能核心的大门。希望这篇近万字的详解能成为你探索 Transformer 世界的一块坚实基石。动手运行文中的代码修改参数观察输出变化是理解这一切的最佳方式。 30款热门AI模型一站整合DeepSeek/GLM/Qwen 随心用限时 5 折。 点击领海量免费额度