返回首页
笔记2026-02-2418 min

大模型字斟句酌的暗箱操作:Decoding 算法全景硬核拆解

Greedy、Beam、Top-K、Top-P、Min-P、Contrastive 与 Constrained Decoding 的底层机制与 PyTorch 实现全景拆解。

大模型字斟句酌的暗箱操作:Decoding 算法全景硬核拆解

大模型(LLM)的本质是什么?是一个无情的“下一个词汇概率预测机”。 给定前文 x<tx_{<t},模型会在词表(比如 10 万个 Token)上输出一个概率分布 P(xtx<t)P(x_t \mid x_{<t})

但问题来了:手里拿着这 10 万个词的概率,你究竟该挑哪一个输出到屏幕上? 这就到了 Decoding(解码策略) 登场的时刻。解码算法的选择,直接决定了模型是严谨死板,还是天马行空;是逻辑严密,还是废话连篇。

今天,我们将深入 PyTorch 代码底层,硬核拆解 Greedy, Beam Search, Top-K, Top-P, Min-P, Contrastive Search 以及用于格式控制的 Constrained Decoding (Trie)

贯穿全文的基础设定: 假设我们的词表里只有 5 个词:['猫', '狗', '猪', '黑洞', '量子'] 当前模型的输出概率(Probs)为:[0.50, 0.30, 0.15, 0.04, 0.01]


1. Greedy Search (贪心搜索):目光短浅的完美主义者

  • Motivation (核心动机):既然我们要找最连贯的句子,那每次挑概率最高的那一个词不就行了?
  • 原理:每一步都严格选择 P(xtx<t)P(x_t \mid x_{<t}) 最大的那个 Token。没有任何随机性。

xt=argmaxP(xtx<t)x_t = \arg\max P(x_t \mid x_{<t})

  • Python 代码实现 (核心逻辑)
PYTHON
import torch

def greedy_search(logits):
    # logits shape: (batch_size, vocab_size)
    # 直接取绝对最大值的索引
    next_token = torch.argmax(logits, dim=-1)
    return next_token
  • Pros & Cons (优缺点)
    • 优点:极快,计算开销最小,且结果 100% 可复现。适合代码生成、数学计算等要求绝对严谨的任务。
    • 缺点局部最优不等于全局最优。它极容易陷入死循环(比如一直输出“我我我我”),且生成的文本极其无聊,缺乏人类语言的多样性。

2. Beam Search (束搜索):平行宇宙探索者

  • Motivation (核心动机):贪心搜索太短视了!如果第一步选了概率第二的词,也许第二步能引出一个极其惊艳的绝世好词呢?我们需要保留多个“平行宇宙”的候选分支。

  • 原理:维护一个大小为 kk(Beam Size,束宽)的候选序列集合。每生成一步,都会展开这 kk 个序列的所有可能下一个词,计算累积的对数概率(Log Prob),然后再从中挑选出全局得分最高的 kk 个序列,淘汰其他的。

  • Python 代码实现 (概念伪代码)

PYTHON
def beam_search_step(current_beams, logits, beam_size=3):
    # current_beams: [(token_seq, log_prob_sum), ...]
    candidates = []
    for seq, score in current_beams:
        probs = torch.log_softmax(logits(seq), dim=-1)
        # 获取当前分支 top-k 的下一个词
        topk_probs, topk_indices = torch.topk(probs, beam_size)
        
        for i in range(beam_size):
            new_seq = seq + [topk_indices[i].item()]
            new_score = score + topk_probs[i].item()
            candidates.append((new_seq, new_score))
            
    # 全局再筛选出得分最高的 Top-K 个分支
    candidates.sort(key=lambda x: x[1], reverse=True)
    return candidates[:beam_size]
  • Pros & Cons
    • 优点:能找到全局概率更优的句子,极大地减少了语法错误和逻辑断层。机器翻译时代的绝对王者。
    • 缺点:非常吃显存和算力(计算量扩大了 kk 倍)。而且,在开放式文本生成中,最高概率的句子往往是最无聊的废话(比如“我不知道”)。

3. Top-K Sampling:VIP 俱乐部的掷骰子

  • Motivation (核心动机):为了让对话机器人显得“像个人”,我们必须引入随机采样(Sampling)。但如果直接按概率掷骰子,万一掷到了概率只有 0.0001 的极品生僻词(比如把前文的“量子”接在“我吃了一口”后面),句子直接就崩了。

  • 原理:设立一个 VIP 门槛 KK。每次只保留概率排名前 KK 的 Token,把剩下的所有词的概率强制设为 0。然后在这个小圈子里重新归一化概率,掷骰子。

  • Python 代码实现

PYTHON
def top_k_sampling(logits, k=50):
    # 找到第 k 大的 logit 值
    top_k_values, _ = torch.topk(logits, k)
    kth_value = top_k_values[:, -1].unsqueeze(-1)
    
    # 将所有小于第 k 大值的 logit 设为负无穷 (概率变为 0)
    indices_to_remove = logits < kth_value
    logits[indices_to_remove] = float('-inf')
    
    # 重新 softmax 并采样
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token
  • Pros & Cons
    • 优点:成功切断了长尾的“垃圾词汇”,保证了生成的底线质量,同时带有随机性。
    • 缺点KK 是一个死板的常数。遇到“薛定谔的[猫/狗]”这种只有 2 个合理词的语境,K=50K=50 会引入 48 个废话;遇到字典里有 100 个同义词的语境,K=50K=50 又限制了创造力。

4. Top-P (Nucleus) Sampling:动态预算的核采样

  • Motivation (核心动机):既然固定的 KK 不科学,我们能不能按**“累计概率预算”**来圈定候选词?

  • 原理:将词表按概率从高到低排序,依次累加概率。当累加值刚刚超过设定的阈值 PP(如 0.9)时,立刻停止。在这个“核心圈(Nucleus)”里的词,才能参与掷骰子。

  • Python 代码实现

PYTHON
def top_p_sampling(logits, p=0.9):
    # 排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    
    # 找到累加概率超过 p 的位置,并将其右侧的所有词剔除
    sorted_indices_to_remove = cumulative_probs > p
    # 必须保证至少保留一个词(位移操作)
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    
    # 散布回原位并掩码
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')
    
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
  • Pros & Cons
    • 优点:候选集合大小会根据当前语境的确定性动态收缩和扩张,是目前 ChatGPT 等主流大模型的默认策略。
    • 缺点:当大模型非常不自信(概率分布极其平缓)时,哪怕设了 P=0.9P=0.9,依然会把上千个低质量垃圾词圈进来,导致小概率崩坏。

5. Min-P Sampling:开源社区的 2024 新宠

  • Motivation (核心动机):Top-P 的致命弱点在于它看的是“绝对总和”。如果榜一大哥的概率是 80%,剩下的 20% 是由 1000 个垃圾词凑成的,Top-P(0.9) 依然会把几百个垃圾词放进来。能不能设计一种**“相对门槛”**?

  • 原理:定下一个相对下限 min_p(比如 0.1)。如果当前概率最高的词是 80%,那么门槛就是 80%×0.1=8%80\% \times 0.1 = 8\%,低于 8% 的统统滚蛋;如果当前概率最高的词只有 10%,门槛就自动降到 10%×0.1=1%10\% \times 0.1 = 1\%

  • Python 代码实现 (极其优雅)

PYTHON
def min_p_sampling(logits, min_p=0.05):
    probs = torch.softmax(logits, dim=-1)
    # 获取榜一大哥的概率
    max_probs, _ = probs.max(dim=-1, keepdim=True)
    
    # 相对门槛 = 榜一大哥概率 * min_p
    thresholds = max_probs * min_p
    
    # 剔除低于门槛的渣渣
    indices_to_remove = probs < thresholds
    logits[indices_to_remove] = float('-inf')
    
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
  • Pros & Cons
    • 优点:用极其简单的两行代码,完美解决了 Top-P 的长尾垃圾问题。当模型确信时,候选集极小;当模型发散时,依然保有创造力。Llama.cpp 和开源社区目前极度推崇。

6. Contrastive Search (对比搜索):查重雷达

  • Motivation (核心动机):无论怎么采样,模型都有可能因为内部 Attention 的崩塌,陷入“无脑复读”的死循环(Repetition Degradation)。如何优雅地打破复读?
  • 原理:在评估一个候选词时,不仅看模型给它的置信度(概率),还要看这个词的语义向量跟前面已经生成的上下文的语义向量像不像。如果太像(相似度高),就狠狠扣分!

Score(xt)=(1α)×P(xtx<t)α×maxxjx<t{Cosine(hxt,hxj)}\text{Score}(x_t) = (1 - \alpha) \times P(x_t \mid x_{<t}) - \alpha \times \max_{x_j \in x_{<t}} \{ \text{Cosine}(h_{x_t}, h_{x_j}) \}

  • Python 代码实现 (核心逻辑提取)
PYTHON
def contrastive_search(logits, hidden_states, context_hidden_states, alpha=0.6):
    # 假设先用 Top-K 圈定候选集
    top_k_probs, top_k_indices = torch.topk(torch.softmax(logits, dim=-1), k=5)
    
    best_score = float('-inf')
    best_token = None
    
    for i in range(5):
        token_id = top_k_indices[0, i]
        token_prob = top_k_probs[0, i]
        token_hidden = hidden_states[token_id] # 获取候选词的隐状态向量
        
        # 计算与历史上下文的最大余弦相似度 (查重率)
        similarity_scores = torch.cosine_similarity(token_hidden, context_hidden_states, dim=-1)
        max_similarity = torch.max(similarity_scores)
        
        # 惩罚项公式:自信度高,且查重率低,才是好词
        score = (1 - alpha) * token_prob - alpha * max_similarity
        
        if score > best_score:
            best_score = score
            best_token = token_id
            
    return best_token
  • Pros & Cons
    • 优点:极大地提升了长文本生成的连贯性,几乎彻底消灭了生成重复内容的痛点,比简单的“频率惩罚(Repetition Penalty)”要聪明得多。
    • 缺点:极其吃显存和算力。每走一步都要和前面的历史做向量相似度计算,计算复杂度随序列长度线性爆炸。

7. Constrained Decoding (Trie树):戴着镣铐跳舞的护卫

  • Motivation (核心动机):大模型在做 API 调用(Function Calling)、信息抽取输出 JSON 或写 SQL 语句时,如果括号少了一个,或者把 {"age": 18} 写成了 {"age": 十八},整个系统就会直接报错崩盘。如何在生成时进行100% 物理级别的格式阻断

  • 原理:构建一个基于正则表达式或语法规则的 前缀树 (Trie,比如 Marisa Trie)。在每一步解码前,拿着当前的上下文去遍历 Trie 树。凡是不符合 JSON 格式或不在指定实体库里的 Token,哪怕大模型给它 99.9% 的概率,我们也直接强行把它的概率抹杀为负无穷(Logit Masking)。

  • Python 代码实现 (前缀树掩码伪代码)

PYTHON
def constrained_decoding_step(logits, current_prefix, trie_automaton):
    # 1. 查询状态机:在这个前缀下,哪些下一个字符是合法的?
    # 比如当前生成了 `{"name": "`,下一个合法的只能是字符串
    allowed_token_ids = trie_automaton.get_allowed_next_tokens(current_prefix)
    
    # 2. 生成一个全负无穷的 mask
    mask = torch.full_like(logits, float('-inf'))
    
    # 3. 只有合法的 token 才允许原样保留
    mask[0, allowed_token_ids] = 0
    
    # 4. 把 mask 加到模型原始 logits 上,物理阻断非法字符
    constrained_logits = logits + mask
    
    # 5. 最后再用 greedy 或 sampling 挑词
    return greedy_search(constrained_logits)
  • Pros & Cons
    • 优点:将输出格式控制的稳定性提升到了绝对的 100%。是目前各种 Agent 框架(如 LangChain、Guidance、Outlines)在底层赖以生存的基石。
    • 缺点:构建状态机和每次查询匹配极其消耗 CPU 计算资源;而且,如果大模型的“内心极度抗拒”输出这种格式,强行扭转它的 Logits 会导致它生成极其诡异、语无伦次的内容(内部逻辑崩塌)。

Decoding 算法就像是拴在大模型这头猛兽脖子上的那根缰绳。 它不用去改变底座里动辄千亿的参数,仅仅是通过在 Softmax 输出端动一点巧妙的数学手脚,就能让模型在“严谨如机器”与“灵动如文豪”之间自由切换。懂了这些暗箱操作,你才能真正掌控大模型落笔的灵魂。

评论区

0 条评论

仅订阅用户可发表评论。使用订阅邮箱登录后可评论。

还没有评论,来抢沙发。