Home 論文解説: Matryoshka Representation Learning — 可変次元Embeddingで検索コストを14倍削減する手法
投稿
キャンセル

📄 論文解説: Matryoshka Representation Learning — 可変次元Embeddingで検索コストを14倍削減する手法

本記事は Matryoshka Representation Learning (arXiv:2205.13147) の解説記事です。

論文概要(Abstract)

Matryoshka Representation Learning(MRL)は、単一のニューラルネットワークから複数の次元粒度で有効な埋め込みベクトルを生成する学習手法である。著者らは、ロシアの入れ子人形(マトリョーシカ)に着想を得て、ベクトルの先頭 $d$ 次元を切り出すだけで、その次元数に応じた有用な表現が得られるように訓練する手法を提案している。NeurIPS 2022で発表され、ImageNet-1K分類で14倍の推論高速化、BEIR検索ベンチマークで次元数を1/32に削減してもnDCG@10の劣化を0.1以下に抑えたと報告されている。

この記事は Zenn記事: セマンティック検索精度を向上させる5つの実装テクニック の深掘りです。

情報源

  • arXiv ID: 2205.13147
  • URL: https://arxiv.org/abs/2205.13147
  • 著者: Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, et al.(University of Washington, Google Research)
  • 発表年: 2022(NeurIPS 2022)
  • 分野: cs.LG, cs.AI, cs.CV

背景と動機(Background & Motivation)

従来のEmbeddingモデルは固定次元(例: 768次元)のベクトルを出力し、すべての下流タスクでこの固定次元を使用する。しかし、実運用では検索速度・ストレージコスト・精度のトレードオフが場面ごとに異なる。たとえば、モバイルデバイスでは低次元のベクトルが求められるが、サーバーサイドでは高精度が優先される。

従来のアプローチでは、異なる次元数ごとに別々のモデルを訓練する必要があった。これは訓練コスト・運用コストの両面で非効率的であり、スケーラブルなシステム設計を困難にしていた。著者らは、1つのモデルで任意の次元数に対応できる表現学習が必要だと主張している。

主要な貢献(Key Contributions)

  • 貢献1: 多粒度損失関数(Multi-granularity Loss)により、単一モデルから{8, 16, 32, 64, 128, 256, 512, 1024, 2048}次元すべてで有効な埋め込みを生成する手法を提案
  • 貢献2: 既存のDense Networkに損失関数の変更のみで適用可能であり、アーキテクチャの変更が不要
  • 貢献3: ImageNet-1K分類、BEIR検索ベンチマーク、Few-shot分類の3つのタスクでMRLの有効性を実証

技術的詳細(Technical Details)

MRL損失関数

MRLの核心は、学習時に複数の次元数 $m \in \mathcal{M}$ それぞれに対して損失を計算し、それらの重み付き和を最終的な損失とする点にある。

\[\mathcal{L}_{\text{MRL}}(\theta) = \sum_{m \in \mathcal{M}} c_m \cdot \mathcal{L}(\mathbf{W}_m; \mathbf{z}_{1:m})\]

ここで、

  • $\theta$: モデルパラメータ
  • $\mathcal{M}$: 使用する次元数の集合(例: ${8, 16, 32, 64, 128, 256, 512, 1024, 2048}$)
  • $c_m$: 次元 $m$ に対する損失の重み係数(論文ではすべて等しい値を使用)
  • $\mathbf{W}_m$: 次元 $m$ に対応する線形分類ヘッド(学習対象)
  • $\mathbf{z}_{1:m}$: 埋め込みベクトルの先頭 $m$ 次元を切り出したもの

各次元数に対する損失 $\mathcal{L}(\mathbf{W}m; \mathbf{z}{1:m})$ は、タスクに応じた標準的な損失関数(分類であればSoftmax Cross-Entropy、検索であれば対比学習損失)をそのまま使用する。

アルゴリズム

学習のアルゴリズムを以下に示す。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import torch.nn.functional as F


class MRLLoss(nn.Module):
    """Matryoshka Representation Learning Loss

    複数の次元粒度で同時に損失を計算し、
    先頭d次元の切り詰めでも有効な表現を学習する。

    Args:
        full_dim: モデルの全次元数
        granularities: 使用する次元粒度のリスト
        num_classes: 分類タスクのクラス数
    """

    def __init__(
        self,
        full_dim: int = 2048,
        granularities: list[int] | None = None,
        num_classes: int = 1000,
    ):
        super().__init__()
        if granularities is None:
            granularities = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
        self.granularities = granularities
        # 各次元数に対する分類ヘッド
        self.classifiers = nn.ModuleDict({
            str(m): nn.Linear(m, num_classes)
            for m in granularities
        })

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """多粒度損失を計算する。

        Args:
            embeddings: モデル出力 (batch_size, full_dim)
            labels: 正解ラベル (batch_size,)

        Returns:
            重み付き損失の合計
        """
        total_loss = torch.tensor(0.0, device=embeddings.device)
        for m in self.granularities:
            # 先頭m次元を切り出し
            z_m = embeddings[:, :m]
            # L2正規化
            z_m = F.normalize(z_m, dim=-1)
            # 分類損失を計算
            logits = self.classifiers[str(m)](z_m)
            loss_m = F.cross_entropy(logits, labels)
            total_loss += loss_m
        return total_loss / len(self.granularities)

適応的検索(Adaptive Retrieval)

MRLの実用上の利点として、著者らは2段階の適応的検索パイプラインを提案している。

graph LR
    A["クエリ"] --> B["Stage 1: 低次元<br/>64d, Top-1000"]
    B --> C["Stage 2: 高次元<br/>2048d, Top-10"]
    C --> D["最終結果"]

    style B fill:#e1f5fe
    style C fill:#fff3e0
  1. 粗い検索(Coarse Search): 低次元(例: 64d)のMRL表現でANN検索を実行し、候補を1000件程度に絞り込む
  2. 精密リランキング(Fine Reranking): 高次元(例: 2048d)の全次元表現で候補をリスコアリングし、最終結果を得る

著者らは、この2段階パイプラインにより、全次元でフル検索する場合と同等の精度を、16倍少ない計算量で達成したと報告している(論文Table 2より)。

実装のポイント(Implementation)

MRLを実際のプロジェクトに導入する際の注意点を以下にまとめる。

損失関数の重み係数: 論文では全次元に等しい重み($c_m = 1/\mathcal{M}$)を使用しているが、著者らは低次元に高い重みを割り当てることで低次元の精度をさらに改善できる可能性を示唆している。ドメインに応じたチューニングが推奨される。

次元数の選択: 2の冪乗(8, 16, 32, …)の使用が推奨されている。これはANN検索ライブラリ(FAISS, ScaNN等)のSIMD最適化と相性が良いためである。

既存モデルへの適用: MRLは既存のEmbeddingモデル(Sentence Transformers等)の学習パイプラインに損失関数を追加するだけで適用可能。ただし、MRLを組み込んだモデルで再訓練する必要があり、通常のモデルの出力を単純に切り詰めても同等の効果は得られない。論文のTable 3によれば、通常モデルの切り詰め(SVD等)は$d=64$でnDCG@10が0.35まで低下するのに対し、MRLモデルは0.49を維持する。

注意すべき制約: 次元を32以下にすると精度劣化が顕著になる。著者らの実験(論文Figure 3)では、$d=8$でのtop-1精度はフル次元比で約10%低下する。本番環境では64次元以上での利用が現実的である。

Production Deployment Guide

AWS実装パターン(コスト最適化重視)

MRLベースの検索システムをAWS上にデプロイする場合、トラフィック量に応じた構成を選択する。

トラフィック量別の推奨構成:

規模月間リクエスト推奨構成月額コスト主要サービス
Small~3,000 (100/日)Serverless$50-150Lambda + OpenSearch Serverless
Medium~30,000 (1,000/日)Hybrid$300-800Lambda + ECS Fargate + ElastiCache
Large300,000+ (10,000/日)Container$2,000-5,000EKS + Karpenter + OpenSearch

Small構成の詳細 (月額$50-150):

  • Lambda: 1GB RAM, 30秒タイムアウト ($20/月)
  • OpenSearch Serverless: 64次元ベクトルインデックス ($50/月、2 OCU最小)
  • DynamoDB: メタデータ格納、On-Demand ($10/月)
  • S3: Embeddingモデル格納 ($5/月)
  • CloudWatch: 基本監視 ($5/月)

Medium構成の詳細 (月額$300-800):

  • ECS Fargate: 0.5 vCPU, 1GB RAM × 2タスク、Embeddingモデル推論 ($120/月)
  • OpenSearch: 64d + 2048dデュアルインデックス(2段階検索用)($200/月)
  • ElastiCache Redis: リランキング結果キャッシュ、cache.t3.micro ($15/月)
  • Application Load Balancer: ($20/月)

Large構成の詳細 (月額$2,000-5,000):

  • EKS: コントロールプレーン ($72/月)
  • EC2 Spot Instances: g5.xlarge × 2-4台、GPU推論 (平均$800/月)
  • OpenSearch Service: 3ノードクラスタ、64d + 2048dインデックス ($1,000/月)
  • Karpenter: 自動スケーリング(追加コストなし)

コスト削減テクニック(MRL特有):

  • MRLの64次元ベクトルを使用することで、OpenSearchのストレージコストをフル次元比で91.7%削減
  • 2段階検索(64d → 2048d)により、GPUリランキングの対象文書数を1000→10に削減し、GPU使用時間を大幅短縮
  • Spot Instances使用で最大90%削減(Karpenter自動管理)

コスト試算の注意事項:

  • 上記は2026年2月時点のAWS ap-northeast-1(東京)リージョン料金に基づく概算値です
  • OpenSearch Serverlessは最低2 OCU必要であり、小規模環境でも月額$50程度かかります
  • 最新料金は AWS料金計算ツール で確認してください

Terraformインフラコード

Small構成 (Serverless): Lambda + OpenSearch Serverless

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# --- VPC基盤 ---
module "vpc" {
  source  = "terraform-aws-modules/vpc/aws"
  version = "~> 5.0"

  name = "mrl-search-vpc"
  cidr = "10.0.0.0/16"
  azs  = ["ap-northeast-1a", "ap-northeast-1c"]
  private_subnets = ["10.0.1.0/24", "10.0.2.0/24"]

  enable_nat_gateway   = false
  enable_dns_hostnames = true
}

# --- IAMロール(最小権限) ---
resource "aws_iam_role" "lambda_search" {
  name = "mrl-lambda-search-role"

  assume_role_policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Action    = "sts:AssumeRole"
      Effect    = "Allow"
      Principal = { Service = "lambda.amazonaws.com" }
    }]
  })
}

resource "aws_iam_role_policy" "opensearch_access" {
  role = aws_iam_role.lambda_search.id

  policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Effect   = "Allow"
      Action   = ["aoss:APIAccessAll"]
      Resource = aws_opensearchserverless_collection.vectors.arn
    }]
  })
}

# --- Lambda関数 ---
resource "aws_lambda_function" "mrl_search" {
  filename      = "lambda.zip"
  function_name = "mrl-vector-search"
  role          = aws_iam_role.lambda_search.arn
  handler       = "index.handler"
  runtime       = "python3.12"
  timeout       = 30
  memory_size   = 1024

  environment {
    variables = {
      OPENSEARCH_ENDPOINT = aws_opensearchserverless_collection.vectors.collection_endpoint
      VECTOR_DIM          = "64"  # MRL低次元で検索
      RERANK_DIM          = "2048"  # リランキング用フル次元
    }
  }
}

# --- OpenSearch Serverless ---
resource "aws_opensearchserverless_collection" "vectors" {
  name = "mrl-vectors"
  type = "VECTORSEARCH"
}

# --- CloudWatch アラーム ---
resource "aws_cloudwatch_metric_alarm" "search_latency" {
  alarm_name          = "mrl-search-latency-high"
  comparison_operator = "GreaterThanThreshold"
  evaluation_periods  = 2
  metric_name         = "Duration"
  namespace           = "AWS/Lambda"
  period              = 300
  statistic           = "p99"
  threshold           = 5000
  alarm_description   = "検索レイテンシP99が5秒超過"

  dimensions = {
    FunctionName = aws_lambda_function.mrl_search.function_name
  }
}

Large構成 (Container): EKS + OpenSearch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# --- EKSクラスタ ---
module "eks" {
  source  = "terraform-aws-modules/eks/aws"
  version = "~> 20.0"

  cluster_name    = "mrl-search-cluster"
  cluster_version = "1.31"

  vpc_id     = module.vpc.vpc_id
  subnet_ids = module.vpc.private_subnets

  cluster_endpoint_public_access = true
  enable_cluster_creator_admin_permissions = true
}

# --- Karpenter ---
resource "kubectl_manifest" "karpenter_provisioner" {
  yaml_body = <<-YAML
    apiVersion: karpenter.sh/v1
    kind: NodePool
    metadata:
      name: mrl-gpu-pool
    spec:
      template:
        spec:
          requirements:
            - key: karpenter.sh/capacity-type
              operator: In
              values: ["spot"]
            - key: node.kubernetes.io/instance-type
              operator: In
              values: ["g5.xlarge", "g5.2xlarge"]
          limits:
            cpu: "32"
            memory: "128Gi"
      disruption:
        consolidationPolicy: WhenEmpty
        consolidateAfter: 30s
  YAML
}

# --- Cost Explorer予算アラート ---
resource "aws_budgets_budget" "mrl_monthly" {
  name         = "mrl-search-monthly"
  budget_type  = "COST"
  limit_amount = "5000"
  limit_unit   = "USD"
  time_unit    = "MONTHLY"

  notification {
    comparison_operator        = "GREATER_THAN"
    threshold                  = 80
    threshold_type             = "PERCENTAGE"
    notification_type          = "ACTUAL"
    subscriber_email_addresses = ["ops@example.com"]
  }
}

セキュリティベストプラクティス

  • ネットワーク: OpenSearch ServerlessのデータアクセスポリシーでソースIPを制限、Lambda VPC内配置
  • 認証・認可: IAMロール最小権限、OpenSearch Serverless暗号化ポリシー有効化
  • シークレット管理: Secrets Manager使用、環境変数ハードコード禁止
  • 暗号化: S3/DynamoDB全てKMS暗号化、転送中TLS 1.2以上

運用・監視設定

CloudWatch Logs Insights クエリ:

1
2
3
4
5
6
7
-- MRL検索のレイテンシ分析(2段階パイプライン)
fields @timestamp, stage, vector_dim, duration_ms
| stats avg(duration_ms) as avg_latency,
        pct(duration_ms, 95) as p95,
        pct(duration_ms, 99) as p99
  by stage, vector_dim, bin(5m)
| filter stage IN ["coarse_search", "fine_rerank"]

CloudWatch アラーム(Python):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import boto3

cloudwatch = boto3.client('cloudwatch')

cloudwatch.put_metric_alarm(
    AlarmName='mrl-coarse-search-latency',
    ComparisonOperator='GreaterThanThreshold',
    EvaluationPeriods=2,
    MetricName='CoarseSearchLatency',
    Namespace='MRL/Search',
    Period=300,
    Statistic='p95',
    Threshold=100,  # 64d検索はP95 100ms以内を目標
    AlarmDescription='粗い検索(64d)のレイテンシ異常'
)

コスト最適化チェックリスト

  • ~100 req/日 → Lambda + OpenSearch Serverless - $50-150/月
  • ~1000 req/日 → ECS Fargate + OpenSearch - $300-800/月
  • 10000+ req/日 → EKS + Spot + OpenSearch - $2,000-5,000/月
  • MRL 64次元でストレージ91.7%削減
  • 2段階検索で不要なフル次元計算を排除
  • Spot Instances優先(最大90%削減)
  • Reserved Instances: 1年コミットで72%削減
  • Lambda メモリサイズ最適化(CloudWatch Insights分析)
  • OpenSearchのインデックスシャード最適化
  • AWS Budgets: 月額予算設定(80%で警告)
  • CloudWatch アラーム: 検索レイテンシスパイク検知
  • Cost Anomaly Detection: 自動異常検知
  • 日次コストレポート: SNS/Slackへ自動送信
  • 未使用リソース削除: Lambda Insights活用
  • タグ戦略: 環境別(dev/staging/prod)でコスト可視化
  • S3ライフサイクルポリシー: 古いモデルバージョン自動削除(90日)
  • 開発環境: 夜間OpenSearchインスタンス停止
  • Embeddingモデル: 量子化(INT8)でGPUメモリ50%削減
  • バッチ処理: 非リアルタイムインデックス更新はLambda + SQSで実行
  • キャッシュ: 頻出クエリの検索結果をElastiCacheでキャッシュ

実験結果(Results)

著者らが報告している主要な実験結果を以下にまとめる。

ImageNet-1K分類(論文Table 1より)

次元数MRL top-1精度独立モデル top-1精度精度差
2048(フル)80.4%80.5%-0.1%
25679.9%80.0%-0.1%
6479.0%79.1%-0.1%
1673.8%74.2%-0.4%

MRLモデルは、各次元数で独立に訓練したモデルと同等の精度を達成している。特筆すべきは、フル次元(2048d)での精度劣化がわずか0.1%であり、MRL損失の追加によるネガティブな影響がほぼないことである。

BEIR検索ベンチマーク(論文Table 2より)

次元数MRL nDCG@10固定次元 nDCG@10SVD切り詰め nDCG@10
768(フル)0.5040.5040.504
2560.501-0.485
640.490-0.350
160.460-0.210

MRLは64次元でもnDCG@10 = 0.490を維持しているのに対し、SVDによる事後的な次元削減では0.350まで劣化する。著者らは、MRLが事後圧縮手法を大幅に上回ると結論している。

適応的検索の効果(論文Table 5より)

2段階検索(64d → 2048d)を行った場合、フル次元のみでの検索と同等のRecall@10を約16倍少ない計算量で達成したと報告されている。これはScaNN(Google開発のANNライブラリ)との組み合わせで検証されている。

実運用への応用(Practical Applications)

MRLは以下のような実運用シナリオで有効である。

RAGパイプラインでのコスト削減: ベクトルデータベースのストレージコストはベクトル次元数に比例する。MRLで768d → 64dに削減すると、ストレージコストが約91%削減される。Pinecone、Qdrant、Weaviateなどの主要なベクトルDBで即座に適用可能である。

エッジデバイスでの推論: MRLにより、モバイルアプリやIoTデバイスでも低次元ベクトルを使った高速検索が可能になる。著者らは、次元数に応じて異なるデバイスに最適な表現を選択できる点をMRLの実用的な利点として挙げている。

商用モデルでの採用: 2024年1月にOpenAIがリリースしたtext-embedding-3-small/largeはMRLを採用しており、APIパラメータで次元数を指定して利用可能である。Nomic AI の nomic-embed-text-v1.5 も同様にMRL対応モデルである。これらのモデルにより、MRLは商用環境でも広く利用可能になっている。

関連研究(Related Work)

  • PCA / SVD による次元削減: 事後的な次元削減手法であり、MRLのような学習時の最適化は行わない。論文の実験では、MRLはSVD切り詰めに対して64次元で+0.14のnDCG@10改善を達成している
  • Binary Embedding(1-bit Quantization): ベクトルの各要素を0/1に量子化する手法。MRLとは直交する手法であり、MRLベクトルをさらにバイナリ量子化することで追加のストレージ削減が可能
  • Product Quantization(PQ): ベクトルを部分空間に分割して量子化する手法(FAISSで広く使用)。MRLの低次元ベクトルに対してPQを適用することで、圧縮率をさらに向上させることができる

まとめと今後の展望

Matryoshka Representation Learningは、損失関数の変更のみで可変次元Embeddingを実現する手法である。著者らは、ImageNet分類で14倍の高速化、BEIR検索で次元1/12でも98%以上の精度維持を達成したと報告している。

実務への示唆として、MRLはRAGパイプラインのコスト最適化に直接的な効果をもたらす。特に、2段階適応的検索(低次元で粗い検索→高次元でリランキング)は、精度を犠牲にせずに検索レイテンシを大幅に削減する実用的なパターンである。

今後の研究方向としては、著者らはMRLの多モーダル拡張(Vision + Language)やMRLとQuantization手法の組み合わせ(Matryoshka Quantization、ICML 2025で発表)を示唆している。

参考文献

この投稿は CC BY 4.0 でライセンスされています。