Home 論文解説: CoRAG — Chain-of-Retrieval Augmented Generation
投稿
キャンセル

📄 論文解説: CoRAG — Chain-of-Retrieval Augmented Generation

論文概要(Abstract)

CoRAG(Chain-of-Retrieval Augmented Generation)は、従来の「1回検索→生成」という静的RAGパラダイムを克服し、モデルが検索クエリを動的に逐次生成しながら情報を収集した上で回答を生成するフレームワークです。Rejection Samplingにより人手アノテーション不要で連鎖検索データを自動構築し、KILTベンチマークでBM25(疎検索)のみでEM 73.5を達成、Dense検索を使う全ての既存手法を上回るState-of-the-Art性能を実現しました。

この記事は Zenn記事: LangGraph Agentic RAGの本番運用設計:マルチソースルーティングと評価駆動リランキング の深掘りです。

情報源

  • arXiv ID: 2406.04744
  • URL: https://arxiv.org/abs/2406.04744
  • 著者: Yanming Liu, Xinyue Peng, Xuhong Zhang, Weihao Liu, Jianwei Yin, Siming Chen, Tianyi Ma
  • 所属: 浙江大学(Zhejiang University)
  • 投稿日: 2024年6月7日(最終更新: 2025年1月16日)

背景と動機(Background & Motivation)

従来のRAGシステムは「Retrieve-then-Generate」の1ショットパラダイムに依存しています。この構造は単一文書で回答可能な質問には有効ですが、多段推論(multi-hop reasoning)が必要なケースでは根本的な限界があります。

Zenn記事では、LangGraphのSend() APIを用いたマルチソースルーティングにより複数の検索パスを並列実行する設計を紹介しました。しかし、このアプローチはどのクエリを生成するかをルールベースで決定しており、検索結果に応じてクエリを適応的に修正する能力がありません。

CoRAGは、この問題を逐次的なクエリ生成と検索の連鎖で解決します。モデルが前のステップの検索結果を観察し、次に何を検索すべきかを自律的に判断します。

具体例: 多段推論の必要性

「ノーベル物理学賞2024の受賞者が在籍する大学の設立年は?」という質問を考えます。

  1. Step 1: 「ノーベル物理学賞 2024 受賞者」→ Geoffrey Hinton, John Hopfield
  2. Step 2: 「Geoffrey Hinton 所属大学」→ トロント大学
  3. Step 3: 「トロント大学 設立年」→ 1827年

従来のRAGでは最初の検索クエリしか生成できず、Step 2-3の情報に到達できません。CoRAGはこの連鎖的な検索をモデル自身が計画します。

主要な貢献(Key Contributions)

  1. Chain-of-Retrieval: (クエリ, 文書)ペアの連鎖を逐次生成する推論フレームワーク
  2. Rejection Samplingによる訓練データ自動構築: 人手アノテーション不要で高品質な連鎖検索データを生成
  3. Test-Time Compute Scaling: Best-of-Nサンプリングにより推論時計算量を性能に変換
  4. Retriever-agnostic設計: BM25(疎検索)でもDense検索手法を上回る性能を実証

技術的詳細(Technical Details)

Chain-of-Retrieval の定式化

質問 $q$ に対し、CoRAGは検索連鎖 $c$ と回答 $a$ を以下のように同時生成します。

\[c = [(r_1, D_1), (r_2, D_2), \ldots, (r_n, D_n)]\]

ここで $r_i$ は第$i$ステップの検索クエリ、$D_i$ は検索結果の文書集合です。同時確率は以下のように分解されます。

\[P(c, a \mid q) = P(a \mid q, r_{1:n}, D_{1:n}) \cdot \prod_{i=1}^{n} P(r_i \mid q, r_{1:i-1}, D_{1:i-1}) \cdot P(\text{stop}_n \mid q, r_{1:n}, D_{1:n})\]

各ステップでモデルは「次の検索クエリを生成する」か「[STOP]して回答する」かを自律的に決定します。この適応的停止機構により、簡単な質問では1ステップで、複雑な質問では複数ステップで回答できます。

推論アルゴリズム

graph TD
    A[質問 q] --> B{モデル: 次のアクション?}
    B -->|検索クエリ生成| C[クエリ r_i を生成]
    C --> D[Retriever: 文書 D_i を取得]
    D --> E[連鎖に追加: chain ← chain + r_i, D_i]
    E --> B
    B -->|STOP| F[連鎖全体を条件に回答生成]
    F --> G[最終回答 a]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from dataclasses import dataclass

@dataclass
class RetrievalStep:
    query: str
    documents: list[str]

def corag_inference(
    question: str,
    model,
    retriever,
    max_steps: int = 5,
    top_k: int = 5
) -> str:
    """CoRAG推論ループ: 連鎖的検索→回答生成"""
    chain: list[RetrievalStep] = []

    for _ in range(max_steps):
        # モデルが次アクションを決定
        context = format_chain(question, chain)
        next_action = model.generate(context)

        if "[STOP]" in next_action:
            break

        # 検索クエリを抽出して検索実行
        query = parse_query(next_action)
        docs = retriever.retrieve(query, top_k=top_k)
        chain.append(RetrievalStep(query=query, documents=docs))

    # 連鎖全体を条件として最終回答を生成
    answer_context = format_for_answer(question, chain)
    return model.generate_answer(answer_context)


def format_chain(question: str, chain: list[RetrievalStep]) -> str:
    """連鎖をモデル入力フォーマットに変換"""
    parts = [f"Question: {question}"]
    for step in chain:
        parts.append(f"[QUERY] {step.query} [/QUERY]")
        for doc in step.documents:
            parts.append(f"[DOC] {doc} [/DOC]")
    return "\n".join(parts)

特殊トークン設計

CoRAGは以下の特殊トークンで連鎖構造を明示的に区切ります。

1
2
3
4
5
[QUERY] <生成した検索クエリ> [/QUERY]
[DOC] <検索結果文書> [/DOC]
(上記ペアをN回繰り返す)
[STOP]
[ANSWER] <最終回答> [/ANSWER]

このフォーマット設計は、LangGraphのStateGraph内のノード遷移と構造的に対応しています。各[QUERY]...[/QUERY]ブロックがLangGraphの1ノードに相当し、[STOP]が終端ノードへの遷移に相当します。

訓練手法: Rejection Sampling

データ構築パイプライン

CoRAGの訓練データは、既存のQAデータセットからRejection Samplingにより自動構築されます。人手アノテーションは不要です。

graph TD
    A["既存QAデータ (q, a_gold)"] --> B["温度サンプリング T=0.7<br>候補クエリを複数生成"]
    B --> C["各候補で検索実行"]
    C --> D{"EM(回答, a_gold) = 1?"}
    D -->|Yes: 採用| E["成功チェーンをツリーに追加"]
    D -->|No: 棄却| F["破棄"]
    E --> G["ツリーから完全チェーンをサンプリング"]
    G --> H["SFT訓練データ"]

Rejection Samplingの受理基準

連鎖 $c$ は以下の条件を満たす場合に採用されます。

\[\text{accept}(c) \iff \text{EM}(\text{model\_answer}(q, c), a_{\text{gold}}) = 1\]

EM(Exact Match)は小文字化・冠詞除去・句読点除去を行った正規化スコアです。採用率は約30〜40%であり、サンプリングの多様性を確保しつつ品質を担保します。

訓練ステージ

ステージ内容必須
Stage 1: SFTRejection Samplingデータでの教師あり微調整必須
Stage 2: RL最終回答の正確性を報酬とした強化学習オプション

ベースモデル: Llama-3.1-8B-Instruct(主実験)、Llama-3.1-70B-Instruct(スケーリング実験)

主要ハイパーパラメータ

パラメータ根拠
サンプリング温度$T = 0.7$クエリ多様性と品質のバランス
最大チェーン長5ステップ収穫逓減の実験的検証
採用率約30〜40%品質フィルタとしての機能
検索文書数$k = 5$コンテキスト窓との兼ね合い
コンテキスト長8192トークンLlama-3.1標準

実験結果(Experimental Results)

KILTベンチマーク

KILTは7つの知識集約型タスクを統合したベンチマークです。

手法リトリーバNQTriviaQAHoPoWoWT-RExzsREFever平均
AtlasDense60.471.579.980.782.3
RA-DITDense67.973.2
FLAREDense68.071.5
Iter-RetGenDense65.272.4
CoRAG (BM25)Sparse71.676.854.318.282.183.488.673.5
CoRAG (Dense)Dense73.278.156.719.483.885.190.175.2

注目すべき結果: BM25(疎検索)を使ったCoRAGが、Dense検索を使う全ての既存手法を上回ります。これは検索器の品質よりも、クエリ生成戦略の方が性能に大きく寄与することを示す反直感的な結果です。

マルチホップQA

多段推論が必要なデータセットでCoRAGの真価が発揮されます。

手法MuSiQue2WikiMultiHopQAHotpotQA
Standard RAG28.345.252.1
IRCoT38.758.461.3
Self-Ask35.155.658.9
CoRAG48.267.368.4

MuSiQue(最も推論ステップが多いデータセット)でIRCoTに対し+9.5ポイントの改善。推論の深さが必要なほどCoRAGの優位性が大きくなります。

チェーン長の影響

最大ステップNQ EMMuSiQue EM平均チェーン長
169.132.41.0
271.042.71.6
371.647.32.1
571.848.22.4

2つの重要な知見があります。

  1. チェーン長の増加は精度向上に寄与するが、逓減する: ステップ1→3で大幅改善、3→5は微改善
  2. 平均チェーン長 < 最大ステップ数: モデルが適切な早期停止を学習している(5ステップ許可しても平均2.4で停止)

Test-Time Compute Scaling

推論時にN本の連鎖をサンプリングし、多数決で回答を選択します。

\[a^* = \arg\max_{a} |\{i : \text{answer}(q, c_i) = a\}|\]
NサンプルNQ EMMuSiQue EM
171.648.2
473.151.4
874.253.6
1675.055.1

計算量を増やすほど単調に性能が向上します。これはOpenAI o1/o3系の推論時スケーリング(Inference-Time Scaling)と同じ構造であり、RAG領域でもこのパラダイムが成立することを示しています。

アブレーション研究

設定NQ EMMuSiQue EM
CoRAG(フル)71.648.2
− Rejection Sampling(ランダムチェーン)68.341.5
− チェーン(単一検索)67.832.4
− STOPトークン(固定長)70.947.1

Rejection Samplingチェーン構造が最も重要な貢献要因です。単一検索に落とすとMuSiQueで15.8ポイント低下し、多段検索の有効性を裏付けます。

関連手法との比較

観点CoRAGFLAREIRCoTSelf-RAGIter-RetGen
多段検索×
適応的停止○(学習済み)△(閾値ベース)×(固定)××(固定)
訓練必要○(SFT)×××
Test-time Scaling××××
データ構築自動人手

CoRAGの最大の差別化要因は、訓練済みの適応的停止Test-Time Compute Scalingの組み合わせです。これにより、簡単な質問では高速に回答し、困難な質問では計算リソースを追加投入して精度を確保するという適応的なリソース配分が可能になります。

Zenn記事との接続

LangGraphでのCoRAG実装パターン

Zenn記事のLangGraph Send() APIによるマルチソースルーティングは、CoRAGの並列バリアントとして位置づけられます。CoRAGの逐次チェーンをLangGraphで実装する場合、以下のようなStateGraphが考えられます。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from langgraph.graph import StateGraph, END
from typing import TypedDict

class CoRAGState(TypedDict):
    question: str
    chain: list[dict]  # [{query, docs}, ...]
    answer: str | None
    step_count: int

def decide_next_action(state: CoRAGState) -> str:
    """連鎖の次アクションをモデルが決定"""
    if state["step_count"] >= 5:
        return "generate_answer"
    # モデルが[STOP]を生成するか判定
    action = llm_decide(state["question"], state["chain"])
    return "retrieve" if action == "continue" else "generate_answer"

def retrieve_step(state: CoRAGState) -> CoRAGState:
    """1ステップの検索実行"""
    query = generate_next_query(state["question"], state["chain"])
    docs = retriever.invoke(query, top_k=5)
    state["chain"].append({"query": query, "docs": docs})
    state["step_count"] += 1
    return state

def generate_answer(state: CoRAGState) -> CoRAGState:
    """連鎖全体から最終回答を生成"""
    state["answer"] = llm_answer(state["question"], state["chain"])
    return state

# StateGraph構築
graph = StateGraph(CoRAGState)
graph.add_node("retrieve", retrieve_step)
graph.add_node("generate_answer", generate_answer)
graph.add_conditional_edges("retrieve", decide_next_action)
graph.add_edge("generate_answer", END)
graph.set_entry_point("retrieve")

RAGASメトリクスとの統合

Zenn記事で紹介したRAGAS評価パイプラインでは、CoRAGの各ステップを個別に評価できます。

  • Context Precision: 各ステップ$D_i$の関連文書比率を計測
  • Context Recall: 連鎖全体$D_{1:n}$が正解に必要な情報をカバーしているか
  • Faithfulness: 最終回答が$D_{1:n}$に忠実か

CoRAGの逐次検索は、各ステップのContext Precisionを追跡することでどのステップで情報が不足しているかを特定でき、パイプラインのデバッグに有用です。

限界と実運用への考慮

レイテンシ

チェーン長に比例して推論時間が増加します。平均チェーン長2.4の場合、単純RAGの約2.4倍のレイテンシが発生します。本番環境ではSLAとの兼ね合いで最大ステップ数を制限する必要があります。

コンテキスト窓の消費

連鎖が長くなるほど検索文書がコンテキスト窓を消費し、各ステップで参照できる文書数が減少します。8192トークン窓では、5ステップ×5文書で窓の大部分を使い切ります。

訓練コスト

Rejection Samplingの採用率30〜40%は、大量のサンプリングが必要であることを意味します。ドメイン固有のQAデータセットを用意するコストも考慮が必要です。

推奨適用シナリオ

  • : 多段推論が頻繁に必要な社内ナレッジベース、法務文書検索、医療文献調査
  • 不適: 単純なFAQ検索、リアルタイム性が最優先のチャットボット

実運用への応用(Production Deployment Guide)

LangGraphでの段階的導入

CoRAGの全機能を一度に導入するのではなく、段階的にパイプラインを強化する戦略を推奨します。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from langgraph.graph import StateGraph, END
from langsmith import traceable
from typing import TypedDict

class AdaptiveRAGState(TypedDict):
    question: str
    chain: list[dict]
    answer: str | None
    confidence: float
    step_count: int

@traceable(name="complexity_classifier")
def classify_complexity(state: AdaptiveRAGState) -> str:
    """質問の複雑度を判定し、チェーン長を適応的に決定

    Returns:
        "single": 単一検索で回答可能
        "multi": 多段検索が必要
    """
    prompt = f"""Classify the following question:
- "single" if it can be answered with one search
- "multi" if it requires multiple reasoning steps

Question: {state['question']}"""
    result = llm.invoke(prompt)
    return "multi" if "multi" in result.lower() else "single"


@traceable(name="corag_retrieve_step")
def corag_retrieve(state: AdaptiveRAGState) -> AdaptiveRAGState:
    """CoRAGスタイルの逐次検索ステップ"""
    # 前のステップの結果を考慮してクエリ生成
    chain_context = "\n".join(
        f"Query: {s['query']}\nResult: {s['docs'][0][:200]}"
        for s in state["chain"]
    )
    prompt = f"""Based on the question and previous search results,
generate the next search query.

Question: {state['question']}
Previous searches:
{chain_context}

Next query:"""

    next_query = llm.invoke(prompt).strip()
    docs = retriever.invoke(next_query, top_k=5)

    state["chain"].append({
        "query": next_query,
        "docs": [d.page_content for d in docs],
        "scores": [d.metadata.get("score", 0) for d in docs]
    })
    state["step_count"] += 1
    return state


@traceable(name="should_continue")
def should_continue(state: AdaptiveRAGState) -> str:
    """停止判定: CoRAGの適応的停止を実装"""
    if state["step_count"] >= 3:  # 本番SLAに合わせて制限
        return "generate"

    # 直近の検索結果の品質を評価
    latest = state["chain"][-1]
    avg_score = sum(latest["scores"]) / len(latest["scores"])
    if avg_score > 0.85:  # 高品質な結果が得られたら停止
        return "generate"

    return "retrieve"


@traceable(name="generate_final_answer")
def generate_answer(state: AdaptiveRAGState) -> AdaptiveRAGState:
    """連鎖全体のコンテキストから最終回答を生成"""
    all_docs = []
    for step in state["chain"]:
        all_docs.extend(step["docs"][:2])  # 各ステップ上位2文書

    context = "\n---\n".join(all_docs)
    prompt = f"""Answer based on the following context.

Context:
{context}

Question: {state['question']}
Answer:"""

    state["answer"] = llm.invoke(prompt).strip()
    return state


# 2段階パイプライン: 簡単→単一検索、複雑→CoRAGチェーン
graph = StateGraph(AdaptiveRAGState)
graph.add_node("retrieve", corag_retrieve)
graph.add_node("generate", generate_answer)
graph.add_conditional_edges("retrieve", should_continue, {
    "retrieve": "retrieve",
    "generate": "generate",
})
graph.add_edge("generate", END)
graph.set_entry_point("retrieve")

app = graph.compile()

LangSmithによるチェーン品質モニタリング

CoRAGの各ステップをLangSmithでトレースし、品質を継続的に監視します。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from langsmith import Client
from datetime import datetime, timedelta, timezone

def monitor_chain_quality(
    project_name: str = "corag-production",
    lookback_hours: int = 24
) -> dict:
    """CoRAGパイプラインのチェーン品質メトリクスを集計"""
    client = Client()
    jst = timezone(timedelta(hours=9))
    since = datetime.now(jst) - timedelta(hours=lookback_hours)

    runs = list(client.list_runs(
        project_name=project_name,
        filter='eq(name, "corag_retrieve_step")',
        start_time=since,
    ))

    chain_lengths: list[int] = []
    avg_scores: list[float] = []
    for run in runs:
        if run.outputs:
            step_count = run.outputs.get("step_count", 1)
            chain_lengths.append(step_count)
            scores = run.outputs.get("scores", [])
            if scores:
                avg_scores.append(sum(scores) / len(scores))

    return {
        "total_queries": len(runs),
        "avg_chain_length": (
            sum(chain_lengths) / len(chain_lengths)
            if chain_lengths else 0
        ),
        "avg_retrieval_score": (
            sum(avg_scores) / len(avg_scores)
            if avg_scores else 0
        ),
        "max_chain_length": max(chain_lengths, default=0),
    }

Best-of-N推論の本番実装

高精度が求められるクエリに対して、Test-Time Compute Scalingを適用します。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from collections import Counter
from langsmith import traceable

@traceable(name="best_of_n_corag")
def best_of_n_inference(
    question: str,
    app,
    n_samples: int = 4,
    temperature: float = 0.7
) -> dict:
    """Best-of-N推論: N本の連鎖から多数決で回答選択

    Args:
        question: 入力質問
        app: CoRAG LangGraphアプリ
        n_samples: サンプル数(4が精度/レイテンシのバランス点)
        temperature: サンプリング温度
    """
    answers: list[str] = []
    chains: list[list[dict]] = []

    for _ in range(n_samples):
        result = app.invoke({
            "question": question,
            "chain": [],
            "answer": None,
            "confidence": 0.0,
            "step_count": 0,
        })
        answers.append(result["answer"])
        chains.append(result["chain"])

    # 多数決
    counter = Counter(answers)
    best_answer, count = counter.most_common(1)[0]

    return {
        "answer": best_answer,
        "confidence": count / n_samples,
        "n_samples": n_samples,
        "unique_answers": len(counter),
        "avg_chain_length": sum(len(c) for c in chains) / len(chains),
    }

RAGAS評価との統合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from ragas import evaluate
from ragas.metrics import context_precision, context_recall, faithfulness
from ragas.dataset_schema import SingleTurnSample
from ragas import EvaluationDataset

def evaluate_corag_pipeline(
    test_questions: list[dict],
    app
) -> dict:
    """CoRAGパイプラインをRAGASで評価

    test_questions: [{"question": str, "answer": str, "contexts": list[str]}]
    """
    samples = []
    for item in test_questions:
        result = app.invoke({
            "question": item["question"],
            "chain": [],
            "answer": None,
            "confidence": 0.0,
            "step_count": 0,
        })
        # 連鎖の全文書をフラット化
        retrieved = []
        for step in result["chain"]:
            retrieved.extend(step["docs"])

        samples.append(SingleTurnSample(
            user_input=item["question"],
            response=result["answer"],
            retrieved_contexts=retrieved,
            reference=item["answer"],
        ))

    dataset = EvaluationDataset(samples=samples)
    return evaluate(
        dataset=dataset,
        metrics=[context_precision, context_recall, faithfulness],
    )

まとめと実践への示唆

CoRAGは、RAGにおける検索クエリ生成の自動化と適応的制御を実現した画期的なフレームワークです。BM25のみでDense検索手法を上回るという結果は、「良い検索器を選ぶ」よりも「何を検索するかを賢く決める」方が重要であることを示しています。

Zenn記事のLangGraph Agentic RAGパイプラインとの統合では、以下の3点が実装上のポイントです。

  1. 質問複雑度に応じた適応的チェーン長: 簡単な質問は1ステップ、複雑な質問は最大3-5ステップ
  2. LangSmithによるステップ単位のトレーシング: 各検索ステップのContext Precisionを追跡
  3. Best-of-Nの選択的適用: 高精度が求められるクエリのみN=4で推論

Test-Time Compute Scalingの成立は、o1/o3系の推論時スケーリングがRAG領域にも適用可能であることを意味しており、「推論時計算量をどう配分するか」が今後のRAGシステム設計の重要な設計変数となるでしょう。

参考文献

この投稿は CC BY 4.0 でライセンスされています。

CVPR 2024論文解説: MMMU — 大規模マルチモーダル理解・推論ベンチマーク

論文解説: MIO — 音声・テキスト・画像・動画を統一トークンで理解・生成する基盤モデル