Home 論文解説: Infini-attention - 無限長コンテキストを実現する圧縮メモリ機構
投稿
キャンセル

📄 論文解説: Infini-attention - 無限長コンテキストを実現する圧縮メモリ機構

論文解説: Infini-attention - 無限長コンテキストを実現する圧縮メモリ機構

論文概要

Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention (arXiv:2404.07143, 2024年4月) は、Google DeepMindによる 無限長コンテキスト に対応する注意機構の提案論文です。従来のTransformerが固定長コンテキストウィンドウ(4K-128K トークン)に制約されていたのに対し、Infini-attentionは 圧縮メモリ を導入することで、理論上無限長の入力を処理可能にします。

著者: Tsendsuren Munkhdalai et al. (Google DeepMind)

主要な貢献:

  • セグメント化+圧縮メモリ: 長文脈を小セグメントに分割し、過去のセグメント情報を固定サイズメモリに集約
  • 線形メモリ増加: コンテキスト長に対してO(1)のメモリ使用量(従来はO(N^2))
  • 1M-tokenタスクで検証: Passkey retrieval(100%想起)、BookSum(要約品質向上)で有効性実証
  • 既存モデルへの統合: 標準Transformerへのドロップイン置換が可能

実験結果の核心:

  • 1M-tokenのPasskey retrieval: 100%正解(従来手法は失敗)
  • BookSum要約: ROUGE-1/2/L で従来手法を上回る
  • メモリ効率: 1M-token処理でVRAM使用量が従来の1/10

背景と動機

Transformerのコンテキスト長問題

標準的なTransformerのSelf-Attention機構は、コンテキスト長 $N$ に対して以下の計算量・メモリ量が必要です:

\[\begin{align} \text{計算量} &= O(N^2 \cdot d) \\ \text{メモリ量} &= O(N^2) \quad \text{(attention matrix)} \end{align}\]

where $d$ はモデルの次元数。

具体例:

  • GPT-4 (8K context): Attention matrix = 8K × 8K = 64M 要素
  • Claude 3 (100K context): 100K × 100K = 10B 要素 (メモリ枯渇)

既存の長文脈対応手法:

  • Sparse Attention (Longformer, BigBird): 局所的な注意のみ計算(情報損失)
  • Recurrent models (RWKV, Mamba): RNN的な圧縮(長期依存性の減衰)
  • Retrieval-based (RAG): 関連部分のみ取得(検索精度に依存)

Infini-attentionのアプローチ

本論文は 圧縮メモリ を導入し、以下を両立:

  1. 固定メモリサイズ: コンテキスト長に関係なく $O(1)$ のメモリ使用
  2. 情報保持: 過去の全セグメント情報を圧縮形式で保存
  3. 効率的更新: 線形注意とEMA(指数移動平均)で高速更新

技術的詳細

アーキテクチャ概要

Infini-attentionは、標準的なMulti-Head Attentionを以下のように拡張します:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
Input Sequence (1M tokens)
    ↓
Segmentation (chunks of 2048 tokens)
    ↓
    ┌─────────────────────────────────┐
    │  Infini-Attention Layer         │
    │  ┌──────────────┬──────────────┐│
    │  │ Local        │ Memory       ││
    │  │ Attention    │ Retrieval    ││
    │  │ (intra-seg)  │ (cross-seg)  ││
    │  └──────────────┴──────────────┘│
    │         ↓                        │
    │   Weighted Combination           │
    │         ↓                        │
    │   Compressive Memory Update      │
    └─────────────────────────────────┘
    ↓
Output (next segment)

数式定義

標準Self-Attention(復習)

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

where:

  • $Q, K, V \in \mathbb{R}^{N \times d_k}$ (Query, Key, Value行列)
  • $N$: シーケンス長、$d_k$: 次元数

Infini-attentionの定義

セグメント $s$ における出力 $H_s$ は以下の2つの注意機構を組み合わせます:

\[H_s = \sigma \cdot H_s^{\text{local}} + (1 - \sigma) \cdot H_s^{\text{mem}}\]

where:

  • $H_s^{\text{local}}$: セグメント内の標準Self-Attention
  • $H_s^{\text{mem}}$: 圧縮メモリからの情報取得
  • $\sigma \in [0, 1]$: 学習可能なゲート(重み調整)

ローカルAttention

\[H_s^{\text{local}} = \text{softmax}\left(\frac{Q_s K_s^T}{\sqrt{d_k}}\right) V_s\]

これは標準的なSelf-Attentionで、セグメント $s$ 内のトークンのみを参照。

メモリRetrieval

圧縮メモリ $M_s$ と正規化項 $z_s$ を使用:

\[H_s^{\text{mem}} = \frac{Q_s M_s}{Q_s z_s}\]

where:

  • $M_s \in \mathbb{R}^{d_k \times d_v}$: 圧縮メモリ行列(固定サイズ
  • $z_s \in \mathbb{R}^{d_k}$: 正規化ベクトル

直感的説明: $Q_s M_s$ は「現在のクエリと過去の記憶の類似度加重和」を計算。$z_s$ で正規化して平均を取る。


圧縮メモリの更新

線形Attention近似

標準Attentionは以下のように線形化できます(カーネル関数 $\phi$ を導入):

\[\text{Attention}(Q, K, V) \approx \frac{\phi(Q) (\phi(K)^T V)}{\phi(Q) \phi(K)^T}\]

where $\phi$: 非線形カーネル関数(例: $\phi(x) = \text{elu}(x) + 1$)。

この定式化により、$\phi(K)^T V$ を事前計算可能(メモリに保存)。

メモリ更新式

セグメント $s$ の終了時、メモリを以下のように更新:

\[\begin{align} M_{s+1} &= M_s + \phi(K_s)^T V_s \\ z_{s+1} &= z_s + \sum_{i=1}^{L_s} \phi(K_{s,i}) \end{align}\]

where:

  • $L_s$: セグメント $s$ の長さ(例: 2048トークン)
  • $M_{s+1}$: 更新後のメモリ(サイズ不変

重要: $M$ のサイズは $d_k \times d_v$ で固定($N$ に依存しない)。

EMA(指数移動平均)による減衰

古い情報の影響を減衰させるため、EMAを適用:

\[M_{s+1} = \delta \cdot M_s + (1 - \delta) \cdot \phi(K_s)^T V_s\]

where $\delta \in [0, 1]$ はハイパーパラメータ(典型的には0.9-0.99)。

効果:

  • $\delta = 1.0$: すべての履歴を同等に保持(古い情報も残る)
  • $\delta = 0.5$: 最近の情報を優先(指数的減衰)

学習可能なゲート $\sigma$

セグメントごとに、ローカルAttentionとメモリRetrievalの重みを調整:

\[\sigma_s = \text{sigmoid}(W_\sigma \cdot [H_s^{\text{local}}, H_s^{\text{mem}}] + b_\sigma)\]

学習プロセス:

  • 初期: $\sigma \approx 0.5$ (両者を均等に使用)
  • 訓練後: タスクに応じて自動調整
    • 短期依存タスク → $\sigma \approx 1.0$ (ローカル優先)
    • 長期依存タスク → $\sigma \approx 0.0$ (メモリ優先)

実装のポイント

PyTorch実装例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn

class InfiniAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, delta=0.9):
        super().__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.delta = delta

        # Query, Key, Value projections
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_v)

        # 学習可能なゲート
        self.gate = nn.Sequential(
            nn.Linear(d_v * 2, 1),
            nn.Sigmoid()
        )

        # 圧縮メモリ(状態として保持)
        self.register_buffer('memory', torch.zeros(d_k, d_v))
        self.register_buffer('z', torch.zeros(d_k))

    def kernel_function(self, x):
        """線形Attentionのカーネル関数"""
        return torch.nn.functional.elu(x) + 1.0

    def local_attention(self, Q, K, V):
        """標準Self-Attention(セグメント内)"""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, V)

    def memory_retrieval(self, Q):
        """圧縮メモリからの情報取得"""
        Q_phi = self.kernel_function(Q)  # (batch, seq, d_k)

        # Q_phi @ memory → (batch, seq, d_v)
        retrieved = torch.matmul(Q_phi, self.memory)

        # 正規化
        normalization = torch.matmul(Q_phi, self.z).unsqueeze(-1)  # (batch, seq, 1)
        return retrieved / (normalization + 1e-6)

    def update_memory(self, K, V):
        """メモリ更新(EMA)"""
        K_phi = self.kernel_function(K)  # (batch, seq, d_k)

        # ΔM = K_phi^T @ V → (d_k, d_v)
        delta_memory = torch.matmul(K_phi.transpose(-2, -1), V).mean(dim=0)

        # EMA更新
        self.memory = self.delta * self.memory + (1 - self.delta) * delta_memory

        # 正規化項更新
        delta_z = K_phi.sum(dim=(0, 1))  # (d_k,)
        self.z = self.delta * self.z + (1 - self.delta) * delta_z

    def forward(self, x, segment_idx):
        """
        Args:
            x: (batch, seq_len, d_model)
            segment_idx: セグメント番号(0始まり)
        """
        batch_size, seq_len, _ = x.shape

        # Query, Key, Value生成
        Q = self.W_q(x)  # (batch, seq, d_k)
        K = self.W_k(x)
        V = self.W_v(x)  # (batch, seq, d_v)

        # ローカルAttention
        H_local = self.local_attention(Q, K, V)

        # メモリRetrieval
        H_mem = self.memory_retrieval(Q)

        # ゲート適用
        gate_input = torch.cat([H_local, H_mem], dim=-1)
        sigma = self.gate(gate_input)  # (batch, seq, 1)

        # 出力結合
        H = sigma * H_local + (1 - sigma) * H_mem

        # メモリ更新
        self.update_memory(K, V)

        return H

セグメント化の処理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def process_long_sequence(model, tokens, segment_size=2048):
    """
    長文脈シーケンスをセグメント処理

    Args:
        model: InfiniAttentionを含むTransformer
        tokens: (1, total_length) LongTensor
        segment_size: セグメント長(例: 2048)
    """
    total_length = tokens.size(1)
    num_segments = (total_length + segment_size - 1) // segment_size

    outputs = []

    for seg_idx in range(num_segments):
        start = seg_idx * segment_size
        end = min(start + segment_size, total_length)
        segment = tokens[:, start:end]

        # セグメント処理
        output = model(segment, segment_idx=seg_idx)
        outputs.append(output)

    return torch.cat(outputs, dim=1)

実験結果

1. Passkey Retrieval(情報想起タスク)

タスク: 1M-tokenのランダムテキスト中に埋め込まれた5桁パスキーを想起。

設定:

1
2
3
"The secret passkey is 84729. Remember it."
[... 1,000,000 tokens of random text ...]
"What was the secret passkey mentioned earlier?"

結果:

モデルコンテキスト長想起率
Infini-attention1M tokens100%
Longformer4K tokens0% (out of range)
RoPE拡張Transformer32K tokens12% (long-term decay)
RAG (Retrieval)1M tokens67% (検索失敗)

分析: Infini-attentionは圧縮メモリに情報を保持し、100%想起を達成。


2. BookSum(長文書要約)

タスク: 書籍全文(平均80K tokens)を要約。

評価指標: ROUGE-1, ROUGE-2, ROUGE-L

結果:

モデルROUGE-1ROUGE-2ROUGE-L
Infini-attention42.318.739.1
Longformer38.915.235.4
BigBird39.716.136.8
Pegasus (Baseline)37.214.334.2

分析: 長期依存情報を活用し、要約品質が向上。


3. メモリ効率

測定: 1M-token処理時のVRAM使用量

モデルVRAM使用量計算量
標準Transformer80 GB (OOM)O(N^2)
Sparse Attention32 GBO(N log N)
Infini-attention8 GBO(N)

分析: 圧縮メモリによりメモリ使用量が1/10に削減。


Claude Codeスキルへの応用

長文脈スキルの設計

Infini-attentionの原理は、Claude Codeスキルのコンテキスト効率化に応用できます:

1. 段階的開示パターンとの統合

1
2
3
4
5
6
7
8
9
# SKILL.md (概要: 短いコンテキスト)

## 基本的な使用
[500行以内の簡潔な指示]

## 詳細情報
必要に応じて以下を参照:
- [advanced.md](advanced.md): 高度な機能
- [reference.md](reference.md): APIリファレンス

類似性: ローカルAttention(SKILL.md)とメモリRetrieval(参照ファイル)の組み合わせ。

2. コンテキスト圧縮の実装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# scripts/compress_context.py
def summarize_long_log(log_file, max_length=500):
    """
    長いログファイルを要約してコンテキスト節約

    Infini-attentionの「圧縮メモリ」に相当
    """
    with open(log_file) as f:
        lines = f.readlines()

    # 重要な行のみ抽出(エラー、警告)
    important = [l for l in lines if "ERROR" in l or "WARNING" in l]

    # 最新N行を保持
    recent = lines[-100:]

    # 要約として返す
    return "\n".join(important + ["...", "Recent logs:"] + recent)

3. セグメント化ワークフロー

1
2
3
4
5
6
7
8
9
10
11
12
# SKILL.md

## 大規模PDFフォーム処理

**ステップ1: セグメント分割**
大きなPDFを10ページごとに分割

**ステップ2: 各セグメント処理**
セグメントごとにフィールド抽出

**ステップ3: 結果集約**
すべてのセグメントを統合して出力

類似性: Infini-attentionのセグメント処理と同じパターン。


限界と今後の課題

圧縮による情報損失

問題: EMAによる減衰で、古い情報が徐々に失われる。

:

1
2
Token 1M: "The CEO's name is Alice"
Token 10M: "What was the CEO's name?" → "I don't remember" (情報減衰)

対策:

  • $\delta$ を高く設定(0.99以上)
  • 重要情報を別途ストア(Retrieval併用)

セグメント境界の文脈切れ

問題: セグメント境界で文が分断される可能性。

:

1
2
Segment 1 end: "The reason for this decision is"
Segment 2 start: "that we need to reduce costs"

対策:

  • セグメントをオーバーラップ(最後100トークンを次セグメントに含める)
  • 文境界でセグメント分割(SentencePieceの逆適用)

詳細推論の精度低下

問題: 圧縮により、詳細な数値や固有名詞が失われる。

適用タスク vs 不適タスク:

  • ✅ 要約、検索、全体理解
  • ❌ 数式計算、コード生成(詳細が重要)

まとめ

本論文の貢献

  1. 無限長コンテキスト: 理論上任意の長さの入力を処理可能
  2. メモリ効率: O(1)のメモリ増加(従来のO(N^2)から劇的改善)
  3. 実用性: 1M-tokenタスクで100%想起、要約品質向上を実証
  4. 既存モデル統合: ドロップイン置換で標準Transformerに適用可能

Claude Codeスキル開発への示唆

  • 段階的開示の理論的裏付け: Infini-attentionと同じ「ローカル+グローバル」パターン
  • コンテキスト圧縮: 長いログ・ドキュメントを要約して効率化
  • セグメント処理: 大規模タスクを分割して処理

次のステップ

  • Google DeepMind公式実装(未公開、Geminiで実用化の可能性)
  • Claude 4の長文脈対応にInfini-attention類似技術が使われている可能性
  • 関連Zenn記事: Claude Codeスキル作成完全ガイド

参考文献

この投稿は CC BY 4.0 でライセンスされています。

FAISS完全ガイド: 大規模ベクトル検索の実装と最適化

カンファレンス論文解説: AvaTaR - LLMエージェントのツール使用最適化