論文解説: Infini-attention - 無限長コンテキストを実現する圧縮メモリ機構
論文概要
Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention (arXiv:2404.07143, 2024年4月) は、Google DeepMindによる 無限長コンテキスト に対応する注意機構の提案論文です。従来のTransformerが固定長コンテキストウィンドウ(4K-128K トークン)に制約されていたのに対し、Infini-attentionは 圧縮メモリ を導入することで、理論上無限長の入力を処理可能にします。
著者: Tsendsuren Munkhdalai et al. (Google DeepMind)
主要な貢献:
- セグメント化+圧縮メモリ: 長文脈を小セグメントに分割し、過去のセグメント情報を固定サイズメモリに集約
- 線形メモリ増加: コンテキスト長に対してO(1)のメモリ使用量(従来はO(N^2))
- 1M-tokenタスクで検証: Passkey retrieval(100%想起)、BookSum(要約品質向上)で有効性実証
- 既存モデルへの統合: 標準Transformerへのドロップイン置換が可能
実験結果の核心:
- 1M-tokenのPasskey retrieval: 100%正解(従来手法は失敗)
- BookSum要約: ROUGE-1/2/L で従来手法を上回る
- メモリ効率: 1M-token処理でVRAM使用量が従来の1/10
背景と動機
Transformerのコンテキスト長問題
標準的なTransformerのSelf-Attention機構は、コンテキスト長 $N$ に対して以下の計算量・メモリ量が必要です:
\[\begin{align} \text{計算量} &= O(N^2 \cdot d) \\ \text{メモリ量} &= O(N^2) \quad \text{(attention matrix)} \end{align}\]where $d$ はモデルの次元数。
具体例:
- GPT-4 (8K context): Attention matrix = 8K × 8K = 64M 要素
- Claude 3 (100K context): 100K × 100K = 10B 要素 (メモリ枯渇)
既存の長文脈対応手法:
- Sparse Attention (Longformer, BigBird): 局所的な注意のみ計算(情報損失)
- Recurrent models (RWKV, Mamba): RNN的な圧縮(長期依存性の減衰)
- Retrieval-based (RAG): 関連部分のみ取得(検索精度に依存)
Infini-attentionのアプローチ
本論文は 圧縮メモリ を導入し、以下を両立:
- 固定メモリサイズ: コンテキスト長に関係なく $O(1)$ のメモリ使用
- 情報保持: 過去の全セグメント情報を圧縮形式で保存
- 効率的更新: 線形注意とEMA(指数移動平均)で高速更新
技術的詳細
アーキテクチャ概要
Infini-attentionは、標準的なMulti-Head Attentionを以下のように拡張します:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
Input Sequence (1M tokens)
↓
Segmentation (chunks of 2048 tokens)
↓
┌─────────────────────────────────┐
│ Infini-Attention Layer │
│ ┌──────────────┬──────────────┐│
│ │ Local │ Memory ││
│ │ Attention │ Retrieval ││
│ │ (intra-seg) │ (cross-seg) ││
│ └──────────────┴──────────────┘│
│ ↓ │
│ Weighted Combination │
│ ↓ │
│ Compressive Memory Update │
└─────────────────────────────────┘
↓
Output (next segment)
数式定義
標準Self-Attention(復習)
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]where:
- $Q, K, V \in \mathbb{R}^{N \times d_k}$ (Query, Key, Value行列)
- $N$: シーケンス長、$d_k$: 次元数
Infini-attentionの定義
セグメント $s$ における出力 $H_s$ は以下の2つの注意機構を組み合わせます:
\[H_s = \sigma \cdot H_s^{\text{local}} + (1 - \sigma) \cdot H_s^{\text{mem}}\]where:
- $H_s^{\text{local}}$: セグメント内の標準Self-Attention
- $H_s^{\text{mem}}$: 圧縮メモリからの情報取得
- $\sigma \in [0, 1]$: 学習可能なゲート(重み調整)
ローカルAttention
\[H_s^{\text{local}} = \text{softmax}\left(\frac{Q_s K_s^T}{\sqrt{d_k}}\right) V_s\]これは標準的なSelf-Attentionで、セグメント $s$ 内のトークンのみを参照。
メモリRetrieval
圧縮メモリ $M_s$ と正規化項 $z_s$ を使用:
\[H_s^{\text{mem}} = \frac{Q_s M_s}{Q_s z_s}\]where:
- $M_s \in \mathbb{R}^{d_k \times d_v}$: 圧縮メモリ行列(固定サイズ)
- $z_s \in \mathbb{R}^{d_k}$: 正規化ベクトル
直感的説明: $Q_s M_s$ は「現在のクエリと過去の記憶の類似度加重和」を計算。$z_s$ で正規化して平均を取る。
圧縮メモリの更新
線形Attention近似
標準Attentionは以下のように線形化できます(カーネル関数 $\phi$ を導入):
\[\text{Attention}(Q, K, V) \approx \frac{\phi(Q) (\phi(K)^T V)}{\phi(Q) \phi(K)^T}\]where $\phi$: 非線形カーネル関数(例: $\phi(x) = \text{elu}(x) + 1$)。
この定式化により、$\phi(K)^T V$ を事前計算可能(メモリに保存)。
メモリ更新式
セグメント $s$ の終了時、メモリを以下のように更新:
\[\begin{align} M_{s+1} &= M_s + \phi(K_s)^T V_s \\ z_{s+1} &= z_s + \sum_{i=1}^{L_s} \phi(K_{s,i}) \end{align}\]where:
- $L_s$: セグメント $s$ の長さ(例: 2048トークン)
- $M_{s+1}$: 更新後のメモリ(サイズ不変)
重要: $M$ のサイズは $d_k \times d_v$ で固定($N$ に依存しない)。
EMA(指数移動平均)による減衰
古い情報の影響を減衰させるため、EMAを適用:
\[M_{s+1} = \delta \cdot M_s + (1 - \delta) \cdot \phi(K_s)^T V_s\]where $\delta \in [0, 1]$ はハイパーパラメータ(典型的には0.9-0.99)。
効果:
- $\delta = 1.0$: すべての履歴を同等に保持(古い情報も残る)
- $\delta = 0.5$: 最近の情報を優先(指数的減衰)
学習可能なゲート $\sigma$
セグメントごとに、ローカルAttentionとメモリRetrievalの重みを調整:
\[\sigma_s = \text{sigmoid}(W_\sigma \cdot [H_s^{\text{local}}, H_s^{\text{mem}}] + b_\sigma)\]学習プロセス:
- 初期: $\sigma \approx 0.5$ (両者を均等に使用)
- 訓練後: タスクに応じて自動調整
- 短期依存タスク → $\sigma \approx 1.0$ (ローカル優先)
- 長期依存タスク → $\sigma \approx 0.0$ (メモリ優先)
実装のポイント
PyTorch実装例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
class InfiniAttention(nn.Module):
def __init__(self, d_model, d_k, d_v, delta=0.9):
super().__init__()
self.d_k = d_k
self.d_v = d_v
self.delta = delta
# Query, Key, Value projections
self.W_q = nn.Linear(d_model, d_k)
self.W_k = nn.Linear(d_model, d_k)
self.W_v = nn.Linear(d_model, d_v)
# 学習可能なゲート
self.gate = nn.Sequential(
nn.Linear(d_v * 2, 1),
nn.Sigmoid()
)
# 圧縮メモリ(状態として保持)
self.register_buffer('memory', torch.zeros(d_k, d_v))
self.register_buffer('z', torch.zeros(d_k))
def kernel_function(self, x):
"""線形Attentionのカーネル関数"""
return torch.nn.functional.elu(x) + 1.0
def local_attention(self, Q, K, V):
"""標準Self-Attention(セグメント内)"""
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
return torch.matmul(attn_weights, V)
def memory_retrieval(self, Q):
"""圧縮メモリからの情報取得"""
Q_phi = self.kernel_function(Q) # (batch, seq, d_k)
# Q_phi @ memory → (batch, seq, d_v)
retrieved = torch.matmul(Q_phi, self.memory)
# 正規化
normalization = torch.matmul(Q_phi, self.z).unsqueeze(-1) # (batch, seq, 1)
return retrieved / (normalization + 1e-6)
def update_memory(self, K, V):
"""メモリ更新(EMA)"""
K_phi = self.kernel_function(K) # (batch, seq, d_k)
# ΔM = K_phi^T @ V → (d_k, d_v)
delta_memory = torch.matmul(K_phi.transpose(-2, -1), V).mean(dim=0)
# EMA更新
self.memory = self.delta * self.memory + (1 - self.delta) * delta_memory
# 正規化項更新
delta_z = K_phi.sum(dim=(0, 1)) # (d_k,)
self.z = self.delta * self.z + (1 - self.delta) * delta_z
def forward(self, x, segment_idx):
"""
Args:
x: (batch, seq_len, d_model)
segment_idx: セグメント番号(0始まり)
"""
batch_size, seq_len, _ = x.shape
# Query, Key, Value生成
Q = self.W_q(x) # (batch, seq, d_k)
K = self.W_k(x)
V = self.W_v(x) # (batch, seq, d_v)
# ローカルAttention
H_local = self.local_attention(Q, K, V)
# メモリRetrieval
H_mem = self.memory_retrieval(Q)
# ゲート適用
gate_input = torch.cat([H_local, H_mem], dim=-1)
sigma = self.gate(gate_input) # (batch, seq, 1)
# 出力結合
H = sigma * H_local + (1 - sigma) * H_mem
# メモリ更新
self.update_memory(K, V)
return H
セグメント化の処理
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def process_long_sequence(model, tokens, segment_size=2048):
"""
長文脈シーケンスをセグメント処理
Args:
model: InfiniAttentionを含むTransformer
tokens: (1, total_length) LongTensor
segment_size: セグメント長(例: 2048)
"""
total_length = tokens.size(1)
num_segments = (total_length + segment_size - 1) // segment_size
outputs = []
for seg_idx in range(num_segments):
start = seg_idx * segment_size
end = min(start + segment_size, total_length)
segment = tokens[:, start:end]
# セグメント処理
output = model(segment, segment_idx=seg_idx)
outputs.append(output)
return torch.cat(outputs, dim=1)
実験結果
1. Passkey Retrieval(情報想起タスク)
タスク: 1M-tokenのランダムテキスト中に埋め込まれた5桁パスキーを想起。
設定:
1
2
3
"The secret passkey is 84729. Remember it."
[... 1,000,000 tokens of random text ...]
"What was the secret passkey mentioned earlier?"
結果:
| モデル | コンテキスト長 | 想起率 |
|---|---|---|
| Infini-attention | 1M tokens | 100% |
| Longformer | 4K tokens | 0% (out of range) |
| RoPE拡張Transformer | 32K tokens | 12% (long-term decay) |
| RAG (Retrieval) | 1M tokens | 67% (検索失敗) |
分析: Infini-attentionは圧縮メモリに情報を保持し、100%想起を達成。
2. BookSum(長文書要約)
タスク: 書籍全文(平均80K tokens)を要約。
評価指標: ROUGE-1, ROUGE-2, ROUGE-L
結果:
| モデル | ROUGE-1 | ROUGE-2 | ROUGE-L |
|---|---|---|---|
| Infini-attention | 42.3 | 18.7 | 39.1 |
| Longformer | 38.9 | 15.2 | 35.4 |
| BigBird | 39.7 | 16.1 | 36.8 |
| Pegasus (Baseline) | 37.2 | 14.3 | 34.2 |
分析: 長期依存情報を活用し、要約品質が向上。
3. メモリ効率
測定: 1M-token処理時のVRAM使用量
| モデル | VRAM使用量 | 計算量 |
|---|---|---|
| 標準Transformer | 80 GB (OOM) | O(N^2) |
| Sparse Attention | 32 GB | O(N log N) |
| Infini-attention | 8 GB | O(N) |
分析: 圧縮メモリによりメモリ使用量が1/10に削減。
Claude Codeスキルへの応用
長文脈スキルの設計
Infini-attentionの原理は、Claude Codeスキルのコンテキスト効率化に応用できます:
1. 段階的開示パターンとの統合
1
2
3
4
5
6
7
8
9
# SKILL.md (概要: 短いコンテキスト)
## 基本的な使用
[500行以内の簡潔な指示]
## 詳細情報
必要に応じて以下を参照:
- [advanced.md](advanced.md): 高度な機能
- [reference.md](reference.md): APIリファレンス
類似性: ローカルAttention(SKILL.md)とメモリRetrieval(参照ファイル)の組み合わせ。
2. コンテキスト圧縮の実装
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# scripts/compress_context.py
def summarize_long_log(log_file, max_length=500):
"""
長いログファイルを要約してコンテキスト節約
Infini-attentionの「圧縮メモリ」に相当
"""
with open(log_file) as f:
lines = f.readlines()
# 重要な行のみ抽出(エラー、警告)
important = [l for l in lines if "ERROR" in l or "WARNING" in l]
# 最新N行を保持
recent = lines[-100:]
# 要約として返す
return "\n".join(important + ["...", "Recent logs:"] + recent)
3. セグメント化ワークフロー
1
2
3
4
5
6
7
8
9
10
11
12
# SKILL.md
## 大規模PDFフォーム処理
**ステップ1: セグメント分割**
大きなPDFを10ページごとに分割
**ステップ2: 各セグメント処理**
セグメントごとにフィールド抽出
**ステップ3: 結果集約**
すべてのセグメントを統合して出力
類似性: Infini-attentionのセグメント処理と同じパターン。
限界と今後の課題
圧縮による情報損失
問題: EMAによる減衰で、古い情報が徐々に失われる。
例:
1
2
Token 1M: "The CEO's name is Alice"
Token 10M: "What was the CEO's name?" → "I don't remember" (情報減衰)
対策:
- $\delta$ を高く設定(0.99以上)
- 重要情報を別途ストア(Retrieval併用)
セグメント境界の文脈切れ
問題: セグメント境界で文が分断される可能性。
例:
1
2
Segment 1 end: "The reason for this decision is"
Segment 2 start: "that we need to reduce costs"
対策:
- セグメントをオーバーラップ(最後100トークンを次セグメントに含める)
- 文境界でセグメント分割(SentencePieceの逆適用)
詳細推論の精度低下
問題: 圧縮により、詳細な数値や固有名詞が失われる。
適用タスク vs 不適タスク:
- ✅ 要約、検索、全体理解
- ❌ 数式計算、コード生成(詳細が重要)
まとめ
本論文の貢献
- 無限長コンテキスト: 理論上任意の長さの入力を処理可能
- メモリ効率: O(1)のメモリ増加(従来のO(N^2)から劇的改善)
- 実用性: 1M-tokenタスクで100%想起、要約品質向上を実証
- 既存モデル統合: ドロップイン置換で標準Transformerに適用可能
Claude Codeスキル開発への示唆
- 段階的開示の理論的裏付け: Infini-attentionと同じ「ローカル+グローバル」パターン
- コンテキスト圧縮: 長いログ・ドキュメントを要約して効率化
- セグメント処理: 大規模タスクを分割して処理
次のステップ
- Google DeepMind公式実装(未公開、Geminiで実用化の可能性)
- Claude 4の長文脈対応にInfini-attention類似技術が使われている可能性
- 関連Zenn記事: Claude Codeスキル作成完全ガイド
参考文献
- 論文: Leave No Context Behind: Infini-attention
- Linearized Attention: Transformers are RNNs
- 関連研究: RoPE Position Embedding