Home 論文解説: AgentPRM — プロセス報酬モデルによるLLMエージェントの段階的改善フレームワーク
投稿
キャンセル

📄 論文解説: AgentPRM — プロセス報酬モデルによるLLMエージェントの段階的改善フレームワーク

論文概要

Process Reward Models for LLM Agents: Practical Framework and Directions(Sanjiban Choudhury, Cornell University, 2025年2月)は、LLMエージェントにProcess Reward Model(PRM)を統合する実用的なフレームワーク「AgentPRM」を提案する。従来のOutcome Reward Model(ORM)がエピソード終端でのみ報酬を与えるのに対し、PRMは各ステップで密な報酬信号を提供し、信用割当問題を緩和する。

AgentPRMはMonte Carloロールアウトで自動的にステップレベルの報酬ターゲットを計算し、既存のRLHFパイプラインに最小限の変更で統合できる。ALFWorldベンチマークでLlama3.2-3Bモデルが91.0%の成功率を達成し、GPT-4o(65.7%)やClaude-3.5-Sonnet(76.1%)を大幅に上回った。

背景と動機

LLMエージェントの学習において、報酬信号の疎さは根本的な課題である。ReActパラダイムでエージェントが10〜30ステップの行動系列を実行する場合、ORMではタスク成否という終端報酬しか得られず、どのステップが成功・失敗に寄与したかの信用割当が困難になる。

数学的推論タスクではPRM(MATH-Shepherdなど)がステップレベル報酬で大きな改善を示していたが、エージェントタスクへの適用には固有の課題がある。数学では遷移が決定的であるのに対し、エージェント環境では遷移が確率的で、ビームサーチのような探索手法が直接適用できない。本論文はこのギャップを埋める。

主要な貢献

1. AgentPRM: MC ロールアウトによる自動アノテーション

PRMの訓練ラベルを人手なしで生成する手法を提案する。状態-行動ペア$(s_t, a_t)$に対し、現在のポリシーから複数のロールアウトを完了させ、その平均報酬をQ値ターゲットとする。

\[\hat{Q}(s, a) = \frac{1}{|\mathcal{G}(s,a)|} \sum_{(s_t, a_t) \in \mathcal{D}(s,a)} \sum_{k=t}^{T-1} \gamma^{k-t} r_k\]

ここで$\mathcal{G}(s,a)$は$(s,a)$を通過する軌跡の集合、$\gamma$は割引率である。MCTSと異なり非同期に収集でき、スケーラビリティに優れる。

2. InversePRM: デモンストレーションからの報酬学習

明示的な成否ラベルなしで、エキスパートデモンストレーションからPRMを学習するInversePRMも提案されている。ポリシーのロールアウトとデモンストレーションの遷移を比較し、デモに近い遷移を正例、離れた遷移を負例として分類学習する。1回のイテレーションでSFTを大幅に上回る性能を達成する。

3. 既存RLHFパイプラインへの最小変更での統合

フレームワークの3ステージ(ロールアウト収集→PRM訓練→ポリシーRL更新)のうち、Stage 2・3は標準的なRLHFパイプラインと同一であり、Stage 1の報酬アノテーションのみが新規コンポーネントとなる。OpenInstructのGymラッパーとして実装されている。

技術的詳細

ターンレベルMDPの定式化

エージェント環境をターンレベルMDPとしてモデル化する。ターン$t$の状態$s_t$は観測と行動の履歴${o_0, a_0, \ldots, o_{t-1}}$であり、ポリシー$\pi(a_ts_t)$が行動を生成する。PRMはQ関数として機能する。
\[Q^\pi(s_t, a_t) = \mathbb{E}_\pi\left[\sum_{k=t}^{T} \gamma^{k-t} r(s_k, a_k) \mid s_t, a_t\right]\]

PRM訓練: ソフトBCE損失

PRMはソフト二値交差エントロピー損失で訓練される。MCロールアウトから計算した$\hat{Q}$を$[0,1]$に正規化し、ソフトラベルとして使用する。

\[\mathcal{L}(Q_\phi) = -\mathbb{E}_{(s,a,\hat{Q}) \sim \mathcal{D}} \left[\hat{Q} \log Q_\phi(s,a) + (1 - \hat{Q}) \log(1 - Q_\phi(s,a))\right]\]

ポリシー更新: Online DPO

ポリシー更新にはOnline DPOを採用し、PRMスコアを最大化しつつ前回ポリシーとのKL正則化を行う。

\[\pi_i = \arg\max_{\pi_\theta} \mathbb{E}_{s \sim \mathcal{D}, a \sim \pi_\theta}[Q_\phi(s,a)] - \beta \mathbb{D}_{\text{KL}}[\pi_\theta \| \pi_{i-1}]\]

前回ポリシー$\pi_{i-1}$への正則化は、PRMの訓練分布からの逸脱を防ぐ保守的ポリシー反復の原理に基づく。

Best-of-N推論

テスト時には各ターンで$N$個の候補応答をサンプリングし、PRMスコアが最高のものを選択する。$N$を増やすことでテスト時スケーリングが可能になる。

実装のポイント

AgentPRMの訓練ループの擬似コードを以下に示す。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from dataclasses import dataclass
from typing import Any

@dataclass
class Transition:
    """ターンレベルMDPの遷移データ."""
    state: str          # 観測履歴
    action: str         # エージェント出力(思考+行動)
    reward: float       # 環境からの報酬
    q_target: float     # MCロールアウトによるQ値推定

def compute_mc_targets(
    policy: Any,
    env: Any,
    trajectories: list[list[Transition]],
    gamma: float = 0.99,
) -> dict[tuple[str, str], float]:
    """MC ロールアウトからPRMターゲットを計算する.

    Args:
        policy: 現在のLLMポリシー
        env: 環境インスタンス
        trajectories: 収集済み軌跡のリスト
        gamma: 割引率

    Returns:
        (state, action) -> Q値推定の辞書
    """
    # 状態-行動ペアごとにロールアウトを集約
    g: dict[tuple[str, str], list[float]] = {}
    for traj in trajectories:
        for t, trans in enumerate(traj):
            key = (trans.state, trans.action)
            # 割引累積報酬を計算
            discounted_return = sum(
                gamma ** (k - t) * traj[k].reward
                for k in range(t, len(traj))
            )
            g.setdefault(key, []).append(discounted_return)

    # 平均して正規化
    q_targets = {}
    for key, returns in g.items():
        q_targets[key] = sum(returns) / len(returns)
    return q_targets

各イテレーションでは10,000本のロールアウト軌跡を並列収集する。高速推論にはSGLangやvLLMを使用し、バッチ処理で環境と対話する。

実験結果

ALFWorld ベンチマーク(136 OODゲーム)

手法成功率平均行動数
ReAct GPT-4o65.7%20.2
ReAct Claude-3.5-Sonnet76.1%19.0
Reflexion GPT-3(複数試行)88.0%
AdaPlanner GPT-3(複数試行)91.7%
AgentPRM π₃(3B, 直接推論)88.1%12.7
AgentPRM BoN(π₃,Q₂) N=1691.0%12.5

3Bパラメータモデルが、単一試行でGPT-4oを22.4ポイント上回り、複数試行を許容するReflexion(88.0%)に匹敵する性能を達成した。タスク別ではClean(87.1%)、Heat(91.3%)、Cool(91.3%)、Look(100%)、Pick 2(82.4%)と全カテゴリで高い成功率を示す。

イテレーションごとの改善推移

イテレーションポリシー成功率BoN成功率
π₀(SFT初期化)64.9%67.9%
π₁73.9%84.3%
π₂85.8%88.8%
π₃88.1%91.0%

π₁→π₂間で最大の改善(+11.9ポイント)が発生し、π₂はClaude-3.5-Sonnet(76.1%)を単一試行で上回る。これはQ₁がより成功した軌跡で訓練されたため、ポリシー改善の勾配がより効果的であったと分析されている。

テスト時スケーリング

Best-of-NのNを1→32に増加させると、初期ポリシー(π₀, π₁)では大幅な性能向上が見られるが、後期ポリシー(π₂, π₃)ではN=16以降で飽和する。これはポリシー自体の品質向上に伴い、PRMによる選択の余地が減少するためである。

実運用への応用

AgentPRMの実用的な意義は以下の点にある。

  1. 小型モデルでの高性能: 3Bモデルでも適切なPRM訓練により大規模モデル以上の性能を達成でき、推論コストの大幅削減が可能
  2. テスト時スケーリング: Best-of-N推論により、計算予算に応じた性能調整が可能。レイテンシ許容度に応じてN=1(低遅延)〜N=16(高精度)を選択できる
  3. 継続的改善: イテレーティブ訓練によりデプロイ後もオンラインデータから段階的に性能を改善できる
  4. RLHFインフラの再利用: 既存のRLHFパイプライン(OpenInstruct等)に最小限の変更で統合できるため、導入障壁が低い

関連研究

  • Reflexion(Shinn et al., 2023): 言語フィードバックによるエピソード記憶型自己改善。AgentPRMは明示的なステップ報酬で訓練レベルの改善を行う点で異なる
  • ExpeL(Zhao et al., 2023): 経験からの学習。プロンプトベースであり、重み更新は行わない
  • AgentQ(Putta et al., 2024): Q学習ベースのエージェント訓練。AgentPRMはQ関数をPRMとして分離し、ソフトBCE損失で訓練する点が異なる
  • MATH-Shepherd(Wang et al., 2024): 数学推論向けPRM。AgentPRMは確率的環境遷移を持つエージェントタスクに拡張

まとめ

AgentPRMは、MCロールアウトによる自動アノテーションとソフトBCE損失によるPRM訓練を組み合わせ、LLMエージェントの段階的改善を実現するフレームワークである。3Bモデルで91.0%(ALFWorld)という結果は、適切な報酬モデル設計によりモデルサイズの制約を克服できることを示す。InversePRMによるデモンストレーションからの学習、テスト時スケーリングなど、実用展開に向けた複数の方向性も提示されている。


Production Deployment Guide

概要

AgentPRMのPRM推論をAWSで提供し、Best-of-N推論によるエージェント品質向上を実現する構成。

アーキテクチャ

1
2
3
4
5
[API Gateway] → [Lambda/ECS] → [vLLM on GPU Instance]
                                  ├── Policy Model (3B)
                                  └── PRM Model (3B)
                     ↓
              [S3: Rollout Data] → [SageMaker Training Job]

構成パターン

Small(検証環境): g5.xlarge(A10G×1)でポリシーとPRMを同一GPUに配置。N=4のBest-of-N推論で約80%の性能を実現。月額約$800。

Medium(本番環境): g5.2xlarge×2でポリシーとPRMを分離配置。N=16のBest-of-N推論。SageMaker Training Jobでの定期的なPRM再訓練。月額約$3,000。

Large(大規模運用): p4d.24xlarge(A100×8)でのマルチモデル推論 + EKSでのオートスケーリング。継続的なオンライン学習パイプライン。月額約$15,000。

Terraform構成例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# vLLM推論エンドポイント
resource "aws_sagemaker_endpoint_configuration" "agentprm" {
  name = "agentprm-inference"

  production_variants {
    variant_name           = "primary"
    model_name             = aws_sagemaker_model.agentprm.name
    instance_type          = "ml.g5.2xlarge"
    initial_instance_count = 2
    container_startup_health_check_timeout_in_seconds = 600
  }
}

# ロールアウトデータ保存
resource "aws_s3_bucket" "rollout_data" {
  bucket = "agentprm-rollout-data"

  lifecycle_rule {
    enabled = true
    expiration {
      days = 30
    }
  }
}

# PRM再訓練ジョブ(週次)
resource "aws_sagemaker_training_job" "prm_retrain" {
  training_job_name = "prm-retrain-weekly"

  algorithm_specification {
    training_image = "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:2.1-gpu-py310"
    training_input_mode = "File"
  }

  resource_config {
    instance_type  = "ml.g5.12xlarge"
    instance_count = 1
    volume_size_in_gb = 200
  }

  stopping_condition {
    max_runtime_in_seconds = 86400
  }
}

モニタリング

1
2
3
4
5
6
7
8
9
10
11
12
13
# CloudWatch メトリクス
metrics:
  - name: prm_score_distribution
    description: "PRM スコアのヒストグラム(報酬ハッキング検知用)"
    threshold_high: 0.95  # スコアが高すぎる場合は報酬ハッキングの兆候
  - name: best_of_n_improvement
    description: "BoN選択による改善率"
    threshold_low: 0.02   # 改善がない場合はPRM劣化の兆候
  - name: policy_kl_divergence
    description: "現在ポリシーと参照ポリシーのKL距離"
    threshold_high: 0.5   # KLが大きすぎる場合はPRM範囲外
  - name: task_success_rate
    description: "タスク成功率の推移"

コスト最適化

  • 推論コスト: 3Bモデルはg5.xlargeで十分動作し、GPT-4o API呼び出しと比較して1/10以下のコスト
  • BoNのN値調整: レイテンシ要件に応じてN=1(リアルタイム)〜N=16(バッチ処理)を動的に切替
  • Spot Instances: 訓練ジョブにはSpot Instances(最大70%割引)を活用
  • ロールアウトデータのライフサイクル: S3ライフサイクルポリシーで30日後に自動削除し、ストレージコストを抑制
この投稿は CC BY 4.0 でライセンスされています。