Home 論文解説: MCTS-RAG — モンテカルロ木探索で小規模LMの検索拡張推論を飛躍的に強化
投稿
キャンセル

📄 論文解説: MCTS-RAG — モンテカルロ木探索で小規模LMの検索拡張推論を飛躍的に強化

論文概要(Abstract)

MCTS-RAG(Hu et al., 2025)は、Monte Carlo Tree Search(MCTS)とRetrieval-Augmented Generation(RAG)を統合するフレームワークです。RAG単体の「推論深度不足」とMCTS単体の「事実的接地の欠如」を同時に解決し、7Bパラメータの小規模言語モデルでGPT-4oと競合する性能を実現します。PopQAで63.9%(GPT-4o: 56.3%を上回る)、GPQAで52.5%(GPT-4o: 50.0%を上回る)を達成しています。

この記事は Zenn記事: LangGraph×Claude Sonnet 4.6で実装する階層的Agentic RAG検索パイプライン の深掘りです。

情報源

背景と動機(Background & Motivation)

LLMは多様な推論タスクで高い能力を示しますが、2つの根本的な制約があります。(1)複雑な推論に必要な体系的探索戦略の欠如、(2)学習データに存在しない事実のハルシネーション。RAGは(2)に対応しますが、「1回の検索→1回の生成」という固定パイプラインでは推論パスの探索ができません。MCTSは(1)に対応しますが、外部知識にアクセスできないため知識集約型タスクには不向きです。

MCTS-RAGは、QAタスクを探索問題として定式化し、探索木の各ノードで「検索」「推論」「終了」の3アクションを選択できるようにすることで、両者の制約を同時に克服します。

Zenn記事の階層的検索パイプラインが「エージェントが3つのツールを選択する」アプローチであるのに対し、MCTS-RAGは「MCTSの探索木で最適なツール使用順序を発見する」アプローチです。MCTS-RAGの探索的アプローチは、エージェントが最適な検索戦略を事前に知らない場合に特に有効です。

主要な貢献(Key Contributions)

  • 貢献1: 検索と推論をMCTSフレームワーク内で統合し、検索パスと推論パスの同時最適化を実現
  • 貢献2: 7Bモデル(Llama-3.1-8B-Instruct)でフロンティアLLM(GPT-4o)と競合する性能を達成
  • 貢献3: Retrieve / Generate / Terminateの3アクション空間により、検索と推論の最適な組み合わせを自動発見

技術的詳細(Technical Details)

問題設定

質問$q$が与えられたとき、MCTS-RAGは推論パスの空間を探索して回答$a$を見つけます。各パスは、Retrieve(R)、Generate(G)、Terminate(T)のアクション列で構成されます。

状態$s_t$は以下で定義されます。

\[s_t = (q, a_1, \ldots, a_t, d_1, \ldots, d_k)\]

ここで、$q$は質問、$a_i$は$i$番目のアクション、$d_j$は$j$番目の検索済みドキュメントです。

3つのアクション

アクション記号内容
RetrieveR現在の状態に基づいて検索クエリを生成し、関連ドキュメントを取得
GenerateG質問・検索済みドキュメント・過去の推論ステップから新たな推論ステップを生成
TerminateT探索を終了し、現在の状態から最終回答を生成

MCTSの4フェーズ

Phase 1: Selection(選択)

ルートノードから、UCB1ポリシーに従って子ノードをトラバースします。

\[\text{UCB1}(n) = \bar{v}(n) + c \sqrt{\frac{\ln N(\text{parent}(n))}{N(n)}}\]

ここで、

  • $\bar{v}(n)$: ノード$n$の平均評価値
  • $N(n)$: ノード$n$の訪問回数
  • $c$: 探索定数($c = 0.5$で最良バランス)

UCB1の第1項は活用(高い評価値のノードを優先)、第2項は探索(訪問回数が少ないノードを優先)を表します。$c$のチューニングが探索-活用のバランスを制御します。

Phase 2: Expansion(展開)

リーフノードに到達したら、$k=3$の子ノードを展開します。Retrieveアクションでは検索クエリを生成してドキュメントを取得、Generateアクションでは推論ステップを生成します。

Phase 3: Evaluation(評価)

展開された各ノードを評価関数でスコアリングします。評価関数は同じLLM(Llama-3.1-8B-Instruct)をプロンプトで使用し、「現在の推論パスが正解に向かっているか」を0〜10のスケールで評価します。

\[v(n) = \text{Evaluate}(q, \text{path}(n))\]

Phase 4: Backpropagation(逆伝播)

評価スコアをルートまで逆伝播し、祖先ノードの$\bar{v}$と$N$を更新します。

\[\bar{v}(n) \leftarrow \frac{N(n) \cdot \bar{v}(n) + v(\text{child})}{N(n) + 1}\]

実装パラメータ

パラメータ根拠
ベースモデルLlama-3.1-8B-Instruct7Bクラスでの検証
検索システムWikipedia + BM25標準的なベンチマーク設定
探索定数$c$0.5予備実験で最良バランス
最大ツリー深度6深すぎると計算コスト増大
ブランチ係数$k$3各ノードから3つの候補を展開
ノード予算3030ノード以降は収穫逓減

アルゴリズム

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import math
from dataclasses import dataclass, field


@dataclass
class MCTSNode:
    """MCTS探索木のノード"""
    state: dict
    parent: "MCTSNode | None" = None
    children: list["MCTSNode"] = field(default_factory=list)
    visits: int = 0
    value: float = 0.0
    action_type: str = ""  # "R", "G", "T"

    @property
    def ucb1(self) -> float:
        """UCB1スコアを計算"""
        if self.visits == 0:
            return float("inf")
        exploitation = self.value / self.visits
        exploration = 0.5 * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )
        return exploitation + exploration


def mcts_rag(
    question: str,
    model,
    retriever,
    evaluator,
    node_budget: int = 30,
    branch_factor: int = 3,
    max_depth: int = 6,
) -> str:
    """MCTS-RAGメインアルゴリズム

    Args:
        question: 入力質問
        model: 言語モデル(検索クエリ生成・推論ステップ生成)
        retriever: ドキュメント検索器(BM25)
        evaluator: 推論パス評価関数
        node_budget: 探索ノード数の上限
        branch_factor: 各ノードの子ノード数
        max_depth: 探索木の最大深度

    Returns:
        最終回答文字列
    """
    root = MCTSNode(state={"question": question, "docs": [], "steps": []})
    root.visits = 1

    for _ in range(node_budget):
        # Phase 1: Selection
        node = root
        while node.children:
            node = max(node.children, key=lambda n: n.ucb1)

        # 深度チェック
        depth = 0
        tmp = node
        while tmp.parent:
            depth += 1
            tmp = tmp.parent
        if depth >= max_depth:
            continue

        # Phase 2: Expansion
        for _ in range(branch_factor):
            action = model.select_action(node.state)

            if action == "R":
                query = model.generate_query(node.state)
                docs = retriever.retrieve(query, top_k=3)
                new_state = {
                    **node.state,
                    "docs": node.state["docs"] + docs,
                }
            elif action == "G":
                step = model.generate_reasoning_step(node.state)
                new_state = {
                    **node.state,
                    "steps": node.state["steps"] + [step],
                }
            else:  # "T"
                new_state = {**node.state, "terminated": True}

            child = MCTSNode(
                state=new_state,
                parent=node,
                action_type=action,
            )
            node.children.append(child)

            # Phase 3: Evaluation
            score = evaluator.evaluate(question, new_state)
            child.value = score
            child.visits = 1

            # Phase 4: Backpropagation
            current = child
            while current.parent:
                current.parent.visits += 1
                current.parent.value += score
                current = current.parent

    # 最良リーフから回答抽出
    best_leaf = _find_best_leaf(root)
    return model.generate_final_answer(question, best_leaf.state)


def _find_best_leaf(node: MCTSNode) -> MCTSNode:
    """再帰的に最良リーフノードを探索"""
    if not node.children:
        return node
    best_child = max(
        node.children,
        key=lambda n: n.value / max(n.visits, 1),
    )
    return _find_best_leaf(best_child)

実装のポイント(Implementation)

LangGraphとの統合

MCTS-RAGの3アクション(R/G/T)は、LangGraphのStateGraphで表現可能です。ただし、MCTSの探索木構造はLangGraphの線形グラフとは異なるため、カスタムの木探索ロジックが必要です。

Zenn記事のiteration_count(最大5回)はMCTS-RAGのノード予算(30ノード)に対応しますが、MCTSはより広い探索空間を体系的に調べるため、同じ計算予算でもより良い検索パスを発見できる可能性があります。

計算コスト

  • LLM呼び出し回数: 1問あたり20〜30回(標準RAGの1〜3回に対して約10倍)
  • レイテンシ: 7Bモデルで数秒〜数十秒(GPUの種類とノード予算に依存)
  • コスト削減策: 重要度の低いステップ(Evaluation)に小規模モデルを使用

評価関数の重要性

自己評価(同じ7Bモデル)でPopQA 63.9%に対し、外部評価(GPT-4o-mini)では65.2%に向上。評価関数の品質がMCTS全体の探索効率を大きく左右します。

実験結果(Results)

データセットMCTS-RAG (7B)GPT-4oStandard RAGMCTS only
PopQA63.9%56.3%46.2%38.7%
GPQA52.5%50.0%37.8%35.2%
HotpotQA61.8%65.2%52.3%44.1%
MuSiQue42.3%45.6%33.1%27.4%
2WikiMultihop71.2%73.8%58.4%52.9%

注目すべき点:

  • PopQAとGPQAでは7BモデルがGPT-4oを上回る。これは構造化された探索が、モデルサイズの不足を補えることを示唆
  • Retrieveアクション除去で-25.2%の性能低下(63.9%→38.7%)。事実的接地の重要性が顕著
  • ノード予算30が費用対効果の最適点。50ノードでの改善は限定的

実運用への応用(Practical Applications)

MCTS-RAGのアプローチは以下のユースケースに特に適しています。

  1. 希少エンティティの検索: PopQAのようなロングテール質問で最大の効果を発揮
  2. 専門分野QA: GPQA(大学院レベル物理・化学)での高精度は、医療・法律QAへの応用可能性を示唆
  3. 小規模モデルの活用: GPUコストを抑えつつ高精度を実現したい場面で、7Bモデル+MCTS-RAGの組み合わせが有効

一方、レイテンシが許容されない場面(リアルタイムチャット)では、Zenn記事のようなプロンプトベースの階層的検索がより適切です。

関連研究(Related Work)

  • RAP (Hao et al., 2023): MCTSをLLMの計画タスクに適用。MCTS-RAGの直接的な先行研究で、MCTS-RAGはRAPに検索機能を追加
  • IRCoT (Trivedi et al., 2022): 検索とChain-of-Thought推論のインターリービング。MCTS-RAGはIRCoTの固定パターンをMCTS探索に発展
  • Search-o1 (Li et al., 2025): o1型推論に検索を統合。MCTS-RAGと類似のアプローチだが、MCTSの原則的な探索フレームワークを持たない
  • StructRAG (Li et al., 2024): 構造化知識でRAGを強化。NeurIPS 2024採択。MCTS-RAGとは相補的

まとめと今後の展望

MCTS-RAGは、「検索」と「推論」を体系的に探索するフレームワークにより、小規模モデルでフロンティアLLMと競合する性能を実現しました。計算コスト(20-30回のLLM呼び出し/問)と引き換えに、従来のRAGやMCTS単体では到達できない精度を達成しています。今後は、学習ベースのアクションポリシー、効率的な探索戦略、マルチモーダルへの拡張が研究方向として示されています。

参考文献

この投稿は CC BY 4.0 でライセンスされています。

論文解説: MindSearch — DAGベース並列検索エージェントによるマルチソース情報統合

論文解説: OctoTools — DAG並列実行で推論時間47%削減の拡張可能なエージェントフレームワーク