Home 論文解説: One Router, Many Models — Cross-Attentionによるコスト考慮型LLMルーティング
投稿
キャンセル

📄 論文解説: One Router, Many Models — Cross-Attentionによるコスト考慮型LLMルーティング

論文概要(Abstract)

One Router, Many Modelsは、Microsoft Researchが提案したCross-Attentionベースのソフトルーティングメカニズムです。従来の分類ベースルーティング(1モデルを選択)とは異なり、クエリ表現とモデル能力埋め込みの間でCross-Attentionを計算し、ソフトルーティング重みを生成します。さらにコスト考慮パラメータ $\lambda$ を導入し、$\lambda = 0$(品質最大化)から $\lambda = 1$(コスト最小化)まで連続的なトレードオフ制御を可能にしました。MT-Benchでは $\lambda = 0.1$ 設定でMoAの99%品質を30%のコストで達成しています。

この記事は Zenn記事: GeminiとClaudeを使い分けるマルチLLMルーティング実装ガイド の深掘りです。

情報源

  • arXiv ID: 2506.09157
  • URL: https://arxiv.org/abs/2506.09157
  • 著者: Roshini Pulishetty, Xiaodong Liu, Jianfeng Gao, Weizhu Chen(Microsoft Research)
  • 発表年: 2025
  • 分野: cs.CL, cs.LG, cs.AI

背景と動機(Background & Motivation)

LLMルーティングの既存手法には以下の限界があります。

RouteLLM/REGROUP: 分類器がクエリを1つのモデルにハード割り当てする。クエリとモデルの間の微妙な親和性を捉えきれず、「コード生成だが一部要約も必要」といった複合タスクで最適選択を逃す。

MoA: 全モデルを使い品質は最高だが、コストがN倍。品質とコストの連続的な制御ができない

本論文は、TransformerのCross-Attention機構をルーティングに転用することで、クエリ-モデル間の粒度の高い親和性を捉え、$\lambda$ パラメータで品質-コストのトレードオフを連続制御する手法を提案します。Zenn記事のLiteLLMルーターにおける classify_task() の高精度版として位置づけられます。

主要な貢献(Key Contributions)

  • ソフトルーティング: Cross-Attentionによるアテンション重みがルーティング確率を表現。ハード分類(1モデル選択)とソフトアンサンブル(top-k重み付き統合)の両方に対応
  • コスト考慮型損失関数: $\mathcal{L}{\text{total}} = \mathcal{L}{\text{quality}} - \lambda \cdot \mathcal{L}_{\text{cost}}$ により、単一パラメータ $\lambda$ で品質-コストを連続制御
  • モデル能力埋め込み: 各候補LLMの特性(推論力、知識量、コスト等)を学習可能な埋め込みベクトルとして表現。ルーターと共同訓練

技術的詳細(Technical Details)

Cross-Attentionルーティングの数学的定式化

候補モデル集合 $\mathcal{M} = {m_1, m_2, \ldots, m_n}$ に対し、各モデルの能力埋め込みを $\mathbf{K} = [\mathbf{k}_1, \mathbf{k}_2, \ldots, \mathbf{k}_n] \in \mathbb{R}^{n \times d}$ とします。クエリ $q$ のBERTエンコーダ出力を $\mathbf{Q} \in \mathbb{R}^{L \times d}$($L$: シーケンス長)とすると、Cross-Attentionは以下で計算されます:

\[\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}}\right)\mathbf{V}\]

ここで $\mathbf{V} = \mathbf{K}$(Self-Attention型のKey-Value共有)。ルーティング重みはアテンション重みのシーケンス方向平均として得られます:

\[\mathbf{w} = \frac{1}{L} \sum_{l=1}^{L} \text{softmax}\left(\frac{\mathbf{q}_l \cdot \mathbf{K}^T}{\sqrt{d}}\right) \in \mathbb{R}^n\]

ここで、

  • $\mathbf{q}_l$: クエリの $l$ 番目のトークン表現($d$ 次元)
  • $\mathbf{K}^T$: モデル能力埋め込み行列の転置($d \times n$)
  • $\mathbf{w}$: 各モデルへのルーティング重み(確率分布)
  • $w_j$: モデル $m_j$ に割り当てられる確率

コスト考慮型目的関数

コストペナルティを導入した目的関数:

\[\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{quality}} - \lambda \cdot \sum_{j=1}^{n} w_j \cdot c_j\]

ここで、

  • $\mathcal{L}_{\text{quality}}$: 報酬モデルスコアに基づく品質損失
  • $\lambda$: コスト-品質トレードオフパラメータ($0 \leq \lambda \leq 1$)
  • $c_j$: モデル $m_j$ のトークンあたりコスト
  • $w_j$: モデル $m_j$ へのルーティング重み

$\lambda$ の効果:

  • $\lambda = 0$: 品質のみ最大化(MoA相当の品質、高コスト)
  • $\lambda = 0.1$: 推奨値 — MoAの99%品質、30%コスト
  • $\lambda = 0.5$: 単一モデルコスト、品質はやや向上
  • $\lambda = 1.0$: コスト最小化(品質低下のリスク)

ハードルーティング vs ソフトルーティング

ハードルーティング(1モデル選択):

\[m^* = \arg\max_{j} w_j\]

コストは単一モデル呼び出しのみ(1.0x)。

ソフトルーティング(top-k重み付き統合):

上位 $k$ モデルを選択し、MoA的にAggregatorで統合:

\[r_{\text{final}} = \text{Aggregate}\left(\{(w_j, m_j(q))\}_{j \in \text{top-k}(\mathbf{w})}\right)\]

コストは $k$ 倍だが、品質はハードルーティングより高い。

アルゴリズム実装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

class CrossAttentionRouter(nn.Module):
    """Cross-Attentionベースのコスト考慮型LLMルーター

    Args:
        n_models: 候補モデル数
        embed_dim: 埋め込み次元(BERT-base: 768)
        model_costs: 各モデルのトークンあたりコスト($/1Kトークン)
    """
    def __init__(
        self,
        n_models: int = 4,
        embed_dim: int = 768,
        model_costs: list[float] | None = None,
    ):
        super().__init__()
        self.query_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.model_embeddings = nn.Embedding(n_models, embed_dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=1,  # "One Head" — 論文タイトルの由来
            batch_first=True,
        )
        self.cost_weights = nn.Parameter(
            torch.tensor(model_costs or [1.0] * n_models)
        )

    def forward(
        self, query_tokens: dict, lambda_cost: float = 0.1
    ) -> torch.Tensor:
        """ルーティング重みを計算

        Args:
            query_tokens: BERTトークナイザ出力
            lambda_cost: コスト-品質トレードオフ (0=品質最大, 1=コスト最小)

        Returns:
            routing_weights: shape (batch, n_models) の確率分布
        """
        # クエリをBERTでエンコード
        q = self.query_encoder(**query_tokens).last_hidden_state

        # モデル能力埋め込みをKey/Valueとして使用
        k = v = self.model_embeddings.weight.unsqueeze(0).expand(
            q.size(0), -1, -1
        )

        # Cross-Attention: クエリがモデル埋め込みに注目
        _, attn_weights = self.attention(q, k, v)
        routing_weights = attn_weights.mean(dim=1)  # シーケンス平均

        # コストペナルティ適用
        cost_penalty = lambda_cost * F.softmax(self.cost_weights, dim=0)
        routing_weights = routing_weights - cost_penalty.unsqueeze(0)
        routing_weights = F.softmax(routing_weights, dim=-1)

        return routing_weights

訓練設定

パラメータ推奨値備考
Router学習率1e-4BERTエンコーダのfine-tuning
埋め込み学習率1e-3モデル能力埋め込みの学習
バッチサイズ32メモリに応じて調整
エポック数10早期停止あり
報酬モデルArmoRM-Llama3-8Bコスト効率と精度のバランス
$\lambda$ 初期値0.1予算制約に応じて0.1-0.5で調整
候補モデル数3-55以上は効果逓減

実装のポイント(Implementation)

Zenn記事のLiteLLMとの統合: 本手法のCross-Attentionルーターは、LiteLLMのRouterクラスのカスタムrouting_strategyとして組み込めます。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from litellm import Router

# Cross-Attentionルーターの推論ラッパー
ca_router = CrossAttentionRouter(
    n_models=3,
    model_costs=[0.005, 0.003, 0.0008],  # GPT-4o, Sonnet, LLaMA
)

async def cross_attention_route(prompt: str) -> str:
    """Cross-Attention + LiteLLMのハイブリッド"""
    tokens = tokenizer(prompt, return_tensors="pt")
    weights = ca_router(tokens, lambda_cost=0.1)
    model_idx = weights.argmax(dim=-1).item()
    model_names = ["high-quality", "cost-optimized", "budget"]

    resp = await litellm_router.acompletion(
        model=model_names[model_idx],
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.choices[0].message.content

モデル追加時の注意: 新しいLLMを候補に追加する場合、model_embeddingsの再訓練が必要です。既存の埋め込みは初期値として再利用できますが、新モデルの埋め込みはランダム初期化から学習します。

$\lambda$ のチューニング: まず $\lambda = 0.1$ で運用を開始し、月次コストレビューで調整します。品質低下が許容範囲内であれば $\lambda$ を0.05ずつ増加させてコスト削減を進めます。

Production Deployment Guide

AWS実装パターン(コスト最適化重視)

Cross-AttentionルーターはBERTベースのため、CPU推論でも実用的な速度が出ます。

規模月間リクエスト推奨構成月額コスト主要サービス
Small~3,000 (100/日)Serverless$80-200Lambda (ルーター) + Bedrock
Medium~30,000 (1,000/日)Hybrid$500-1,500ECS Fargate + Bedrock + ElastiCache
Large300,000+ (10,000/日)Container$3,000-8,000EKS + SageMaker Endpoint

コスト試算の注意事項: 上記は2026年2月時点のAWS ap-northeast-1料金に基づく概算値です。ルーター自体のコスト(BERT推論)はLLM API呼び出しの1%未満です。

Terraformインフラコード

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# --- Cross-Attentionルーター用Lambda ---
resource "aws_lambda_function" "ca_router" {
  filename      = "ca_router.zip"
  function_name = "cross-attention-llm-router"
  role          = aws_iam_role.router_role.arn
  handler       = "router.handler"
  runtime       = "python3.12"
  timeout       = 20
  memory_size   = 1024

  environment {
    variables = {
      MODEL_PATH    = "s3://ml-models/ca-router/model.onnx"
      EMBED_PATH    = "s3://ml-models/ca-router/embeddings.pt"
      LAMBDA_COST   = "0.1"
      N_MODELS      = "3"
    }
  }
}

# --- λパラメータ動的調整用 ---
resource "aws_ssm_parameter" "lambda_cost" {
  name  = "/ca-router/lambda-cost"
  type  = "String"
  value = "0.1"
  description = "コスト-品質トレードオフパラメータ (0=品質最大, 1=コスト最小)"
}

運用・監視設定

1
2
3
4
5
6
7
8
9
-- λ別の品質-コスト分布
fields @timestamp, lambda_cost, avg_quality, total_cost
| stats avg(quality_score) as avg_quality,
        sum(token_cost) as total_cost
  by lambda_cost, bin(1d)

-- モデル選択分布(アテンション重み可視化)
fields @timestamp, model_name, attention_weight
| stats avg(attention_weight) as avg_weight by model_name, bin(1h)

コスト最適化チェックリスト

λチューニング:

  • $\lambda = 0.1$ から開始
  • 週次で品質スコアとコストを確認
  • 品質が目標値以上なら $\lambda$ を0.05ずつ増加
  • SSM Parameterで動的変更(再デプロイ不要)

ルーター最適化:

  • ONNX変換でCPU推論を高速化
  • ルーティング結果キャッシュ(同一クエリの再計算回避)
  • 定期再訓練(月次推奨)

LLMコスト削減:

  • Prompt Caching有効化
  • Batch API活用(非リアルタイム処理)
  • max_tokens設定
  • Bedrock Savings Plans検討

監視・アラート:

  • AWS Budgets: λ値に応じた予算設定
  • CloudWatch: アテンション重み分布の異常検知
  • 品質スコアのダッシュボード化

実験結果(Results)

MT-Bench

手法スコア相対コスト
Best Single (GPT-4o)9.061.0x
Classification Router8.821.0x
REGROUP9.08~1.5x
MoA(全モデル)9.254.0x
本手法 ($\lambda = 0.1$)9.181.3x
本手法 ($\lambda = 0.5$)9.051.1x

AlpacaEval 2.0 LC

手法Win Rate相対コスト
GPT-4o57.5%1.0x
REGROUP57.8%~1.5x
MoA65.1%4.0x
本手法 ($\lambda = 0.1$)61.3%1.4x

分析: $\lambda = 0.1$ 設定で、MoA(4.0xコスト)の品質の99%を1.3xコストで達成。Cross-Attention vs 分類ルーター比較ではMT-Benchで+3.2%、ソフト vs ハードルーティングでは+1.8%の改善。コスト-品質のパレート最適フロンティアを $\lambda$ で連続的にトレースできる点が最大の強みです。

実運用への応用(Practical Applications)

Zenn記事で紹介した3段階のルーティング(high-quality / cost-optimized / budget)を、本手法の $\lambda$ パラメータで連続スペクトル化できます。

時間帯別の動的 $\lambda$ 調整:

1
2
3
4
5
6
7
8
9
def get_lambda_for_time() -> float:
    """時間帯に応じたλ値を返す"""
    hour = datetime.now().hour
    if 9 <= hour <= 18:   # ビジネスアワー: 品質重視
        return 0.05
    elif 18 < hour <= 23:  # 夜間: バランス
        return 0.2
    else:                  # 深夜: コスト重視
        return 0.5

関連研究(Related Work)

  • RouteLLM (Ong et al., 2024): 行列分解による2モデル択一。本手法はAttentionでN候補を同時評価
  • REGROUP (Ren et al., 2024): クラスタ別ハードルーティング。本手法はソフト重みで柔軟な制御
  • MoA (Wang et al., 2024): 全モデルアンサンブル。本手法はコスト制約付きの部分アンサンブル

まとめと今後の展望

Cross-Attentionルーティングは、ソフトルーティング$\lambda$ コスト制御の2つの革新により、ハードルーティングとフルアンサンブルの間のギャップを埋める手法です。Zenn記事のLiteLLMルーターに classify_task() の代替として組み込むことで、品質とコストの精密な制御が実現します。今後はモデル追加時の増分学習(既存埋め込みの再利用)や、オンライン $\lambda$ 適応が研究方向として期待されます。

参考文献

この投稿は CC BY 4.0 でライセンスされています。

論文解説: Agentic AI Systems in Financial Services — マルチエージェントLLMの信頼性設計パターン

Anthropic解説: Effective Context Engineering for AI Agents — LLMの注意予算を最適化する実践戦略