Home 論文解説: Fast Inference from Transformers via Speculative Decoding — 投機的デコーディングの原論文
投稿
キャンセル

📄 論文解説: Fast Inference from Transformers via Speculative Decoding — 投機的デコーディングの原論文

本記事は Fast Inference from Transformers via Speculative Decoding の解説記事です。

論文概要(Abstract)

本論文は、投機的デコーディング(Speculative Decoding)の理論的基盤を確立した原論文の一つである。著者らは、CPUの投機的実行(speculative execution)の概念をLLMの自己回帰生成に応用し、小型のドラフトモデルが複数の候補トークンを高速に生成し、大型のターゲットモデルが1回のフォワードパスで一括検証するフレームワークを提案した。著者らは、修正サンプリング(modified rejection sampling)により出力分布がターゲットモデルと完全に一致する(ロスレス)ことを数学的に証明し、T5-XXL(11B)ターゲット + T5-Small(60M)ドラフトで約2〜3倍の推論高速化を報告している。

この記事は Zenn記事: vLLM投機的デコーディング+Medusa Headで推論レイテンシを半減させる の深掘りです。

情報源

  • arXiv ID: 2211.17192
  • URL: https://arxiv.org/abs/2211.17192
  • 著者: Yaniv Leviathan, Matan Kalman, Yossi Matias(Google DeepMind)
  • 発表年: 2023(初版2022年11月)
  • 分野: cs.CL, cs.LG

背景と動機(Background & Motivation)

Transformer型LLMの自己回帰生成は、1トークンずつ逐次的にフォワードパスを実行する。この処理はメモリバウンドであり、現代のGPU/TPUの計算能力は十分に活用されていない。著者らは以下の2つの観察に基づいてSpeculative Decodingを着想した。

観察1: トークン難易度の不均一性。テキスト生成において、一部のトークンは予測が容易であり(例: 定型句、文法的に確定する語)、小型モデルでも高い精度で予測できる。一方、創造的な内容や専門用語は大型モデルの表現力を必要とする。

観察2: メモリ帯域幅のボトルネック。現代のアクセラレータ(GPU/TPU)は「毎秒数百兆回の演算」が可能だが、メモリ帯域幅は「毎秒数兆バイト」に制限される。Transformerの推論では「読み込んだ1バイトあたり数回の演算」しか行わず、計算ユニットが大幅に遊んでいる状態である(Google Researchブログより引用)。

この2つの観察から、「小型モデルで簡単なトークンを高速に予測し、大型モデルの遊んでいる計算能力で検証する」というアイデアが生まれた。

主要な貢献(Key Contributions)

  • 貢献1: CPUの投機的実行を確率的設定に一般化した「投機的サンプリング(Speculative Sampling)」アルゴリズムを提案し、出力分布の同一性を数学的に証明
  • 貢献2: ドラフトモデル + ターゲットモデルの2段階フレームワークを定式化し、スピードアップの理論的上限を導出
  • 貢献3: T5-XXL(11Bパラメータ)ターゲット + T5-Small(60Mパラメータ)ドラフトの組み合わせで、WMT翻訳・CNN/DailyMail要約タスクにおいて約2〜3倍の高速化を実証

技術的詳細(Technical Details)

投機的サンプリングアルゴリズム

投機的デコーディングの核となるのは、修正棄却サンプリング(modified rejection sampling)である。アルゴリズムの各ステップは以下の通り:

  1. ドラフト生成: 小型のドラフトモデル$M_q$から$\gamma$個のトークン $x_1, x_2, \ldots, x_\gamma$ を自己回帰的に生成する
  2. 並列検証: ターゲットモデル$M_p$が、$\gamma$個のトークンを含む入力列に対して1回のフォワードパスを実行し、各位置の確率分布 $p(x_ix_{<i})$ を計算する
  3. 受理/棄却判定: 各ドラフトトークン$x_i$について、以下の確率で受理する
\[\text{accept}(x_i) = \min\left(1, \frac{p(x_i | x_{<i})}{q(x_i | x_{<i})}\right)\]

ここで:

  • $p(x_ix_{<i})$: ターゲットモデルがトークン$x_i$に割り当てる条件付き確率
  • $q(x_ix_{<i})$: ドラフトモデルがトークン$x_i$に割り当てる条件付き確率
  1. 棄却時の修正サンプリング: トークン$x_j$が棄却された場合、位置$j$で以下の修正分布からトークンをリサンプリングする
\[p'(x) = \frac{\max(0, p(x | x_{<j}) - q(x | x_{<j}))}{\sum_{x'} \max(0, p(x' | x_{<j}) - q(x' | x_{<j}))}\]

この修正分布は「ターゲットモデルの確率がドラフトモデルの確率を上回る部分」を正規化したものであり、受理/棄却の全体を通じて出力分布がターゲットモデルと完全に一致することが保証される。

出力分布の同一性の証明

著者らは、修正棄却サンプリングにより生成されるトークン列の分布が、ターゲットモデルから直接サンプリングした場合と同一であることを証明している。直感的には、以下のように理解できる:

  • ドラフトトークンが受理される確率は $\min(1, p/q)$ であり、ターゲットモデルの確率が高いトークンほど受理されやすい
  • 棄却された場合の修正分布 $p’$ は、ドラフトモデルが「過剰に」サンプリングした確率質量を補正する
  • 受理と棄却を合わせた全体の分布は、任意のトークン$x$に対して $p(xx_{<j})$ と一致する

この性質により、投機的デコーディングはロスレスな高速化手法となる。出力品質の劣化は理論的に生じない。

スピードアップの理論的上限

著者らは、投機的デコーディングのスピードアップ比の理論的上限を以下のように導出している。

ドラフトトークンの平均受理率を $\alpha$ とすると、1回の投機的デコーディングステップで期待される受理トークン数は:

\[\mathbb{E}[\text{accepted tokens}] = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha}\]

ドラフトモデルの1フォワードパスのコストをターゲットモデルの$c$倍($c < 1$)とすると、スピードアップ比は:

\[\text{Speedup} = \frac{\mathbb{E}[\text{accepted tokens}]}{c \cdot \gamma + 1}\]

$\alpha = 1$(理想的なケース)では $\text{Speedup} = (\gamma + 1) / (c \cdot \gamma + 1)$ となる。例えば $c = 0.05$(T5-Smallの計算コストがT5-XXLの5%)、$\gamma = 5$ の場合、理論上限は $6 / 1.25 = 4.8$ 倍となる。

ドラフト長$\gamma$の最適化

最適なドラフト長$\gamma^*$は受理率$\alpha$の関数として決まる。受理率が高いほど長いドラフトが有利であり、低いほど短いドラフトが最適となる。

\[\gamma^* = \arg\max_{\gamma} \frac{1 - \alpha^{\gamma+1}}{(1 - \alpha)(c \cdot \gamma + 1)}\]

実用的には、$\alpha$をオンラインで推定し、$\gamma$を動的に調整する方法が効果的とされている(後続のEAGLE-2等で実装)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from typing import Optional

def speculative_decode(
    target_model: torch.nn.Module,
    draft_model: torch.nn.Module,
    input_ids: torch.Tensor,
    gamma: int = 5,
    temperature: float = 1.0,
) -> torch.Tensor:
    """投機的デコーディングの1ステップ(簡略化)

    Args:
        target_model: ターゲットモデル(大型)
        draft_model: ドラフトモデル(小型)
        input_ids: 入力トークン列 (1, seq_len)
        gamma: ドラフトトークン数
        temperature: サンプリング温度

    Returns:
        受理されたトークン列 (1, num_accepted + 1)
    """
    # Step 1: ドラフト生成
    draft_tokens = []
    draft_probs = []
    current_input = input_ids

    for _ in range(gamma):
        with torch.no_grad():
            logits = draft_model(current_input).logits[:, -1, :]
            probs = torch.softmax(logits / temperature, dim=-1)
            token = torch.multinomial(probs, num_samples=1)
            draft_tokens.append(token)
            draft_probs.append(probs)
            current_input = torch.cat([current_input, token], dim=-1)

    # Step 2: ターゲットモデルで並列検証
    draft_sequence = torch.cat(draft_tokens, dim=-1)  # (1, gamma)
    full_input = torch.cat([input_ids, draft_sequence], dim=-1)

    with torch.no_grad():
        target_logits = target_model(full_input).logits
        # 検証位置のlogitsを取得
        verify_start = input_ids.shape[1] - 1
        target_probs_all = torch.softmax(
            target_logits[:, verify_start:verify_start + gamma, :] / temperature,
            dim=-1,
        )

    # Step 3: 受理/棄却判定
    accepted_tokens = []
    for i in range(gamma):
        token = draft_tokens[i].item()
        p = target_probs_all[0, i, token].item()  # ターゲット確率
        q = draft_probs[i][0, token].item()  # ドラフト確率

        # 受理確率
        accept_prob = min(1.0, p / (q + 1e-10))
        if torch.rand(1).item() < accept_prob:
            accepted_tokens.append(token)
        else:
            # 棄却: 修正分布からリサンプリング
            residual = torch.clamp(target_probs_all[0, i] - draft_probs[i][0], min=0)
            residual = residual / (residual.sum() + 1e-10)
            resampled = torch.multinomial(residual, num_samples=1)
            accepted_tokens.append(resampled.item())
            break  # 棄却以降のトークンは無効

    return torch.tensor([accepted_tokens], device=input_ids.device)

実装のポイント(Implementation)

ドラフトモデルの選択基準: 著者らは、ドラフトモデルとターゲットモデルが同じ語彙(トークナイザ)を共有する必要があることを強調している。同一ファミリのモデル(例: T5-Small + T5-XXL、Llama-3.2-1B + Llama-3.3-70B)を使用するのが安全である。

バッチサイズの制約: 投機的デコーディングはバッチサイズ1〜4で最も効果が高い。バッチサイズが大きいと、ターゲットモデルのフォワードパスがコンピュートバウンドになり、並列検証の余剰計算能力が減少する。

KVキャッシュの管理: ドラフトモデルとターゲットモデルの両方のKVキャッシュを管理する必要がある。棄却が発生した場合、棄却位置以降のKVキャッシュを巻き戻す(ロールバックする)必要がある。

Greedy decodingとの互換性: temperature=0(greedy decoding)では受理判定がargmaxの一致判定に単純化される。多くの実用ケースではgreedy decodingが使われるため、完全にロスレスな高速化が得られる。

実験結果(Results)

著者らはT5モデルファミリで評価を行っている。

ターゲットドラフトタスクスピードアップ
T5-XXL (11B)T5-Small (60M)WMT EN→DE2.05x
T5-XXL (11B)T5-Small (60M)WMT DE→EN2.27x
T5-XXL (11B)T5-Small (60M)CNN/DailyMail2.54x
T5-XXL (11B)T5-Base (220M)WMT EN→DE1.82x
T5-XXL (11B)T5-Base (220M)CNN/DailyMail2.36x

分析: 論文の実験結果より、ドラフトモデルのサイズはコスト比$c$と受理率$\alpha$のトレードオフを決める。T5-Smallは$c$が小さいが$\alpha$も低い。T5-Baseは$c$が大きいが$\alpha$が高い。最適な選択はタスクとハードウェアに依存する。要約タスク(CNN/DailyMail)では翻訳タスクより高いスピードアップが得られており、これはプロンプトと出力の重複(コピー操作)が多いためドラフトの受理率が高くなることが要因とされている。

Google Search AI Overviewsでの本番稼働: Google Researchブログの報告によると、投機的デコーディングはGoogle SearchのAI Overviews機能で本番稼働しており、応答品質を維持しながら生成速度を向上させているとされている。

実運用への応用(Practical Applications)

低レイテンシが求められるAPI: チャットボット、コード補完、検索エンジンのAI回答など、ユーザーの体感レイテンシが重要なサービスに適している。Google自身がSearchで採用していることが実用性を示している。

既存モデルの高速化: 新たなモデルの学習やファインチューニングが不要であり、同一ファミリの小型モデルをドラフトモデルとして指定するだけで導入できる。これは後続のEAGLEやMedusaと比較した場合の大きな利点である。

制約: 高QPSのバッチ処理環境では効果が限定的である。また、ドラフトモデルを別途GPUメモリに載せる必要があるため、メモリ制約のある環境ではMedusa方式が優れている。

関連研究(Related Work)

  • Chen et al. (2023): Google DeepMindの別チームによる同時期の独立した研究(arXiv: 2302.01318)。同様の投機的サンプリングアルゴリズムを提案し、理論的保証の補完的な証明を提供
  • EAGLE (Li et al., 2024): Feature-levelドラフトにより受理率を大幅改善。本論文のドラフトモデル方式を基盤としつつ、ドラフト精度を向上
  • Medusa (Cai et al., 2024): ドラフトモデル不要の追加ヘッド方式。本論文の外部ドラフトモデルの課題(メモリ、管理コスト)を解決
  • SpecInfer (Miao et al., 2024): ツリー構造の投機推論でバッチスループットを改善

まとめと今後の展望

本論文は、投機的デコーディングの理論的基盤を確立した原論文であり、ドラフトモデル + ターゲットモデルの2段階検証フレームワーク、修正棄却サンプリングによるロスレス保証、スピードアップの理論的上限の導出という3つの貢献を行った。T5-XXL + T5-Smallで2〜3倍の高速化を実証し、Google SearchのAI Overviewsで本番稼働している。後続のEAGLE、Medusa、EAGLE-3はすべて本論文のフレームワークを基盤として発展しており、投機的デコーディング研究の出発点として位置づけられている。

参考文献

この投稿は CC BY 4.0 でライセンスされています。