Home 論文解説: Matryoshka Representation Learning — 可変次元埋め込みによる適応的表現学習
投稿
キャンセル

📄 論文解説: Matryoshka Representation Learning — 可変次元埋め込みによる適応的表現学習

論文概要(Abstract)

Matryoshka Representation Learning(MRL)は、University of Washingtonの Kusupatiらが2022年に提案した表現学習手法である。通常の埋め込みモデルは固定次元のベクトルを出力するが、MRLでは単一のモデルで複数の粒度(8, 16, 32, …, 2048次元)の埋め込みを同時に学習する。学習済みベクトルの先頭 $m$ 次元を切り出すだけで、その次元数に最適化された表現として利用できる。ImageNet-1K分類やMS-COCO検索において、独立に訓練した各次元モデルと同等以上の精度を達成しつつ、適応的検索で最大14倍の計算削減を実現したと報告されている。

この記事は Zenn記事: MTEB×JMTEBで選ぶEmbeddingモデル:精度評価の実践ガイド の深掘りです。

情報源

  • arXiv ID: 2212.09741
  • URL: https://arxiv.org/abs/2212.09741
  • 著者: Aditya Kusupati, Gantavya Bhatt, Aniket Rege et al.(University of Washington, Google Research)
  • 発表年: 2022(NeurIPS 2022 Poster)
  • 分野: cs.LG, cs.AI, cs.CV, cs.IR, stat.ML

背景と動機(Background & Motivation)

テキストや画像の埋め込みモデルは、固定次元のベクトルを出力するのが一般的である。例えばOpenAIのtext-embedding-ada-002は1536次元、BERTのCLS出力は768次元で固定されている。しかし、実運用では以下のようなトレードオフが存在する。

  1. 精度 vs ストレージ: 高次元ベクトルは精度が高いが、大規模コーパスでのインデックスサイズが膨大になる
  2. 精度 vs レイテンシ: ベクトル検索のコサイン類似度計算は次元数に比例する
  3. デバイス制約: エッジデバイスではメモリ・計算リソースが限られる

従来、異なる次元数が必要な場合は次元ごとに独立したモデルを訓練するか、PCA等の後処理で次元削減する必要があった。前者はストレージ・管理コストが高く、後者は精度低下が大きい。MRLはこの問題を、単一モデルの訓練時に複数粒度の損失を同時に最適化することで解決する。

主要な貢献(Key Contributions)

  • MRL損失関数: 先頭 $m$ 次元が常に有効な表現となるよう、マルチスケール損失を定義。エンコーダのアーキテクチャ変更なしで適用可能
  • 適応的検索(Adaptive Retrieval): 低次元で粗い候補を絞り込み、高次元で精密にリランキングする2段階検索で最大14倍の計算削減を達成
  • 幅広い適用性: ResNet-50, ViT-B/16, ALIGN, BERTなど異なるエンコーダ・モダリティで有効性を確認

技術的詳細(Technical Details)

MRL損失関数

MRLの核心は、複数の次元数における分類損失を同時に最適化する損失関数にある。

\[\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} c_m \cdot L\bigl(W_m^\top \mathbf{z}_{1:m},\ y\bigr)\]

ここで、

  • $\mathcal{M} = {8, 16, 32, 64, 128, 256, 512, 1024, 2048}$: 表現サイズの集合
  • $\mathbf{z} \in \mathbb{R}^d$: エンコーダの出力ベクトル($d = 2048$)
  • $\mathbf{z}_{1:m}$: $\mathbf{z}$ の先頭 $m$ 次元のスライス
  • $W_m \in \mathbb{R}^{m \times K}$: 次元 $m$ 用の線形分類器($K$: クラス数)
  • $c_m$: 各スケールの重み係数(全実験で $c_m = 1$)
  • $y$: 正解ラベル
  • $L$: Softmaxクロスエントロピー損失

この設計のポイントは以下の通りである。

  1. エンコーダのアーキテクチャは変更しない: 追加されるのは $\mathcal{M}$ 個の線形ヘッドのみ
  2. 推論時は線形ヘッド不要: 推論時はエンコーダ出力 $\mathbf{z}$ の先頭 $m$ 次元を切り出すだけでよい
  3. 先頭次元に粗い情報が集約: 損失関数の構造上、先頭の少数次元にはデータの大まかな構造が、後方の次元には詳細な情報がエンコードされる

学習アルゴリズム

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
import torch.nn.functional as F


class MRLHead(nn.Module):
    """Matryoshka Representation Learning用のマルチスケール線形ヘッド

    Args:
        max_dim: エンコーダ出力の最大次元数
        num_classes: 分類クラス数
        granularities: 表現サイズの集合
    """

    def __init__(
        self,
        max_dim: int = 2048,
        num_classes: int = 1000,
        granularities: tuple[int, ...] = (8, 16, 32, 64, 128, 256, 512, 1024, 2048),
    ):
        super().__init__()
        self.granularities = granularities
        # 各次元用の線形分類器
        self.classifiers = nn.ModuleDict({
            str(m): nn.Linear(m, num_classes)
            for m in granularities
        })

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """MRL損失を計算

        Args:
            z: エンコーダ出力 (batch_size, max_dim)
            y: 正解ラベル (batch_size,)

        Returns:
            MRL損失値
        """
        total_loss = torch.tensor(0.0, device=z.device)
        for m in self.granularities:
            z_m = z[:, :m]  # 先頭m次元をスライス
            logits = self.classifiers[str(m)](z_m)
            total_loss += F.cross_entropy(logits, y)  # c_m = 1
        return total_loss

推論時の次元切り出しは以下のように行う。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def encode_with_mrl(
    encoder: nn.Module,
    x: torch.Tensor,
    target_dim: int,
) -> torch.Tensor:
    """MRL学習済みエンコーダで任意次元の埋め込みを取得

    Args:
        encoder: MRL学習済みエンコーダ
        x: 入力テンソル
        target_dim: 出力次元数(8, 16, 32, ...)

    Returns:
        L2正規化された埋め込み (batch_size, target_dim)
    """
    with torch.no_grad():
        z_full = encoder(x)  # (batch_size, 2048)
        z_truncated = z_full[:, :target_dim]  # 先頭target_dim次元を切り出し
        z_normalized = F.normalize(z_truncated, p=2, dim=-1)
    return z_normalized

MRL-E(Efficient MRL)

MRL-Eは、ネスト構造の中間層を導入したパラメータ効率版である。分類器の総パラメータ数を削減しつつ同等の柔軟性を維持する。小次元での精度がMRLよりやや低いが、ストレージ・計算コストで有利である。

適応的検索(Adaptive Retrieval)

MRLの最大の実用的価値は、2段階の適応的検索にある。

  1. Stage 1(Shortlist): 低次元(例: 64次元)で高速にコサイン類似度を計算し、候補を$k$件に絞り込む
  2. Stage 2(Rerank): 高次元(例: 2048次元)で候補 $k$ 件を精密にリランキングする

これにより、全データに対して2048次元のベクトル検索を行うのと比較して、精度をほぼ維持しながら計算量を大幅に削減できる。

実装のポイント(Implementation)

重要な制約: MRLの「先頭 $m$ 次元が有効」という保証は、PCA等の次元変換を適用すると崩れる。次元削減には必ずスライス(先頭 $m$ 次元の切り出し)を使用すること。

$\mathcal{M}$ の設計: 2の累乗を推奨する。論文の実験では ${8, 16, 32, 64, 128, 256, 512, 1024, 2048}$ だが、用途に応じて最小・最大を調整できる。例えばテキスト埋め込みでは ${64, 128, 256, 512, 768}$ が実用的。

重み $c_m$: 全実験で $c_m = 1$(均等重み)が使われており、非均等化の感度は低いと著者らは述べている。

訓練コスト: 線形ヘッドの追加分のみ増加し、バックボーンのフォワードパスは1回で済む。$\mathcal{M}= 9$ の場合、訓練時間の増加は数%程度と報告されている。

L2正規化: コサイン類似度で比較する場合、スライス後にL2正規化を適用する。

Production Deployment Guide

AWS実装パターン(コスト最適化重視)

MRLを活用した2段階適応的検索システムをAWSにデプロイする構成を示す。

トラフィック量別の推奨構成:

構成トラフィック月額コスト概算サービス構成
Small~100 req/日$50-150Lambda + OpenSearch Serverless
Medium~1000 req/日$300-800ECS Fargate + OpenSearch + ElastiCache
Large10000+ req/日$2,000-5,000EKS + Spot + OpenSearch + ElastiCache

MRL固有のコスト削減効果:

  • 64次元での初期フィルタリングにより、OpenSearchのストレージを最大32倍削減(2048→64次元)
  • 適応的検索で計算量を最大14倍削減し、CPUノードでも実用的なレイテンシを実現
  • 単一モデルで複数次元に対応するため、モデル管理コストが削減される

Small構成の詳細:

  • Lambda: ARM64, 512MB RAM(低次元推論はメモリ消費が少ない)
  • OpenSearch Serverless: 2 OCU, kNNインデックス(64次元 + 2048次元の2種類)
  • 月額: Lambda $5 + OpenSearch $50 + その他 = $55-100

Large構成の詳細:

  • EKS: c6i.xlarge × 2(CPU推論で十分)
  • OpenSearch: r6g.large.search × 2(マルチAZ、64次元+2048次元インデックス)
  • ElastiCache: r6g.large(2048次元埋め込みキャッシュ、64次元はインラインで計算)
  • 月額: EKS $150 + EC2 $300 + OpenSearch $500 + ElastiCache $200 = $1,500-3,000

注意: 上記コストはAWS ap-northeast-1(東京)リージョンの2026年2月時点の概算値です。実際のコストはトラフィックパターンにより変動します。最新料金はAWS料金計算ツールで確認してください。

Terraformインフラコード

Small構成(Serverless + 2段階検索):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# MRL適応的検索 Serverless構成
# Lambda + OpenSearch Serverless(64次元 + 2048次元インデックス)

terraform {
  required_version = ">= 1.9"
  required_providers {
    aws = {
      source  = "hashicorp/aws"
      version = "~> 5.80"
    }
  }
}

provider "aws" {
  region = "ap-northeast-1"
}

# VPC(NAT Gateway不使用でコスト削減)
resource "aws_vpc" "main" {
  cidr_block           = "10.0.0.0/16"
  enable_dns_hostnames = true
  tags = { Name = "mrl-search-vpc" }
}

resource "aws_subnet" "private" {
  count             = 2
  vpc_id            = aws_vpc.main.id
  cidr_block        = "10.0.${count.index + 1}.0/24"
  availability_zone = data.aws_availability_zones.available.names[count.index]
  tags = { Name = "mrl-private-${count.index}" }
}

data "aws_availability_zones" "available" {
  state = "available"
}

# KMS暗号化キー
resource "aws_kms_key" "main" {
  description             = "MRL search service encryption"
  deletion_window_in_days = 7
}

# Lambda関数(2段階検索ハンドラ)
resource "aws_lambda_function" "mrl_search" {
  function_name = "mrl-adaptive-search"
  runtime       = "python3.12"
  handler       = "handler.search"
  architectures = ["arm64"]
  memory_size   = 512  # 低次元推論は軽量
  timeout       = 15

  role = aws_iam_role.lambda_role.arn

  environment {
    variables = {
      SHORTLIST_DIM    = "64"    # Stage 1: 粗い検索
      RERANK_DIM       = "2048"  # Stage 2: 精密リランク
      SHORTLIST_TOP_K  = "100"   # Stage 1の候補数
      FINAL_TOP_K      = "10"    # 最終結果数
    }
  }

  vpc_config {
    subnet_ids         = aws_subnet.private[*].id
    security_group_ids = [aws_security_group.lambda.id]
  }
}

# IAMロール(最小権限)
resource "aws_iam_role" "lambda_role" {
  name = "mrl-search-lambda-role"
  assume_role_policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Action    = "sts:AssumeRole"
      Effect    = "Allow"
      Principal = { Service = "lambda.amazonaws.com" }
    }]
  })
}

resource "aws_iam_role_policy_attachment" "lambda_vpc" {
  role       = aws_iam_role.lambda_role.name
  policy_arn = "arn:aws:iam::aws:policy/service-role/AWSLambdaVPCAccessExecutionRole"
}

# CloudWatchアラーム(検索レイテンシ監視)
resource "aws_cloudwatch_metric_alarm" "search_latency" {
  alarm_name          = "mrl-search-high-latency"
  comparison_operator = "GreaterThanThreshold"
  evaluation_periods  = 3
  metric_name         = "Duration"
  namespace           = "AWS/Lambda"
  period              = 300
  statistic           = "p95"
  threshold           = 10000
  alarm_actions       = [aws_sns_topic.alerts.arn]
  dimensions = {
    FunctionName = aws_lambda_function.mrl_search.function_name
  }
}

resource "aws_sns_topic" "alerts" {
  name = "mrl-search-alerts"
}

Large構成(Container + 適応的検索):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# EKS + Karpenter(CPU推論、MRLにより GPU不要)
module "eks" {
  source  = "terraform-aws-modules/eks/aws"
  version = "~> 20.31"

  cluster_name    = "mrl-search-cluster"
  cluster_version = "1.31"
  vpc_id          = aws_vpc.main.id
  subnet_ids      = aws_subnet.private[*].id

  cluster_endpoint_public_access = false
  enable_irsa                    = true
}

# Karpenter Provisioner(CPU Spot優先 — MRLにより低次元はGPU不要)
resource "kubectl_manifest" "karpenter_cpu" {
  yaml_body = yamlencode({
    apiVersion = "karpenter.sh/v1beta1"
    kind       = "NodePool"
    metadata   = { name = "mrl-cpu-spot" }
    spec = {
      template = {
        spec = {
          requirements = [
            { key = "karpenter.sh/capacity-type", operator = "In", values = ["spot", "on-demand"] },
            { key = "node.kubernetes.io/instance-type", operator = "In", values = ["c6i.xlarge", "c6i.2xlarge", "c7i.xlarge"] },
          ]
          nodeClassRef = { name = "default" }
        }
      }
      limits   = { cpu = "64", memory = "256Gi" }
      disruption = { consolidationPolicy = "WhenUnderutilized" }
    }
  })
}

# AWS Budgets
resource "aws_budgets_budget" "monthly" {
  name         = "mrl-search-monthly"
  budget_type  = "COST"
  limit_amount = "3000"
  limit_unit   = "USD"
  time_unit    = "MONTHLY"

  notification {
    comparison_operator       = "GREATER_THAN"
    threshold                 = 80
    threshold_type            = "PERCENTAGE"
    notification_type         = "FORECASTED"
    subscriber_email_addresses = ["ops@example.com"]
  }
}

運用・監視設定

CloudWatch Logs Insights — 2段階検索パフォーマンス分析:

1
2
3
4
5
6
7
8
fields @timestamp, shortlist_ms, rerank_ms, total_ms, shortlist_dim, rerank_dim
| stats avg(shortlist_ms) as avg_shortlist,
        avg(rerank_ms) as avg_rerank,
        avg(total_ms) as avg_total,
        percentile(total_ms, 95) as p95_total,
        count(*) as request_count
  by bin(1h) as hour
| sort hour desc

CloudWatchアラーム設定(Python):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import boto3

cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")

# 2段階検索の合計レイテンシ監視
cloudwatch.put_metric_alarm(
    AlarmName="mrl-total-latency-p95",
    MetricName="TotalSearchDurationMs",
    Namespace="MRLSearch",
    Statistic="p95",
    Period=300,
    EvaluationPeriods=3,
    Threshold=3000,
    ComparisonOperator="GreaterThanThreshold",
    AlarmActions=["arn:aws:sns:ap-northeast-1:123456789:mrl-alerts"],
)

X-Rayトレーシング設定:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from aws_xray_sdk.core import xray_recorder, patch_all

patch_all()

@xray_recorder.capture("adaptive_search")
def adaptive_search(
    query_embedding_full: list[float],
    shortlist_dim: int = 64,
    shortlist_k: int = 100,
    final_k: int = 10,
) -> list[dict]:
    """MRL適応的2段階検索"""
    subsegment = xray_recorder.current_subsegment()
    subsegment.put_annotation("shortlist_dim", shortlist_dim)
    subsegment.put_annotation("rerank_dim", len(query_embedding_full))

    # Stage 1: 低次元で候補絞り込み
    q_short = query_embedding_full[:shortlist_dim]
    candidates = opensearch_knn(q_short, k=shortlist_k, index="mrl-64d")
    subsegment.put_metadata("shortlist_count", len(candidates))

    # Stage 2: 高次元で精密リランク
    results = rerank_by_full_dim(query_embedding_full, candidates, k=final_k)
    return results

Cost Explorer日次レポート:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import boto3
from datetime import date, timedelta

ce = boto3.client("ce", region_name="us-east-1")

def get_daily_cost() -> dict:
    """日次コストレポート取得"""
    end = date.today().isoformat()
    start = (date.today() - timedelta(days=1)).isoformat()

    response = ce.get_cost_and_usage(
        TimePeriod={"Start": start, "End": end},
        Granularity="DAILY",
        Metrics=["UnblendedCost"],
        Filter={
            "Tags": {
                "Key": "Project",
                "Values": ["mrl-search"],
            }
        },
    )
    return response["ResultsByTime"][0]

コスト最適化チェックリスト

アーキテクチャ選択:

  • トラフィック量に基づきServerless/Hybrid/Containerを選択
  • MRL対応モデルを選定(OpenAI text-embedding-3, nomic-embed等)
  • 初期フィルタリング次元を精度要件に応じて決定(64/128/256)

リソース最適化:

  • EC2: CPU Spotインスタンス優先(MRLの低次元推論はGPU不要)
  • Reserved Instances: 安定ワークロードには1年コミット
  • Savings Plans: コンピュートワークロード全体で検討
  • Lambda: ARM64でコスト20%削減
  • Lambda: メモリ512MBで十分(低次元推論は軽量)

MRL固有のコスト削減:

  • 64次元Shortlist → 2048次元Rerankで計算量14倍削減
  • OpenSearchインデックスを低次元/高次元に分離し、ストレージ最適化
  • 低次元埋め込みのインメモリキャッシュで検索レイテンシ短縮
  • 単一モデルでの複数次元対応により、モデル管理コスト削減

監視・アラート:

  • AWS Budgets: 月額予算アラート
  • CloudWatch: 2段階検索合計レイテンシP95監視
  • Cost Anomaly Detection: 日次異常検知
  • 日次コストレポート: SNS通知

リソース管理:

  • 未使用OpenSearchインデックスの定期削除
  • タグ戦略: Project/Environment/Dimensionタグ必須
  • 開発環境: 夜間・週末のEKSノード停止
  • CloudTrail/Config有効化
  • S3ライフサイクルポリシー(ログのGlacier移行)

実験結果(Results)

ImageNet-1K分類

論文Table 1より、Fixed Feature(FF)での線形プローブ精度を示す。

次元数独立訓練 FFMRL FF差分
820.0%29.2%+9.2
1645.4%55.7%+10.3
3257.4%67.6%+10.2
6465.7%73.9%+8.2
12870.6%76.5%+5.9
25674.1%78.2%+4.1
51276.1%79.3%+3.2
102477.5%79.7%+2.2
204878.5%80.0%+1.5

低次元ほどMRLの優位性が顕著であり、8次元で+9.2%、16次元で+10.3%の改善を示す。2048次元(最大次元)でも独立訓練を+1.5%上回っており、MRLによる多粒度最適化が全次元で精度を向上させている。

適応的検索での計算削減

著者らの報告によると、ImageNet検索において64次元でShortlist→2048次元でRerankの2段階構成で、精度損失0.3%以内に最大14倍の計算削減を達成している。BERTベースのテキスト埋め込みでも2倍以上の計算削減が確認されている。

エンコーダバリエーション

ResNet-50、ViT-B/16、ALIGN(マルチモーダル)、BERTの各エンコーダでMRLの有効性が確認されており、アーキテクチャに依存しない汎用的な手法であると報告されている。

実運用への応用(Practical Applications)

OpenAI text-embedding-3での採用: OpenAIは2024年1月リリースのtext-embedding-3-small/largeでMRLを採用した。ユーザーはdimensionsパラメータで出力次元を256〜3072の間で指定でき、ストレージコストと精度のトレードオフを調整できる。Zenn記事で紹介されているMTEB/JMTEBベンチマークにおいて、text-embedding-3-largeは高い性能を示している。

大規模ベクトル検索のコスト削減: 数百万〜数十億件のコーパスに対するkNN検索では、ベクトル次元がストレージとレイテンシに直結する。MRLにより64次元でインデックスを構築し、上位100件のみ2048次元でリランキングすることで、インデックスサイズを32分の1に圧縮しつつ精度を維持できる。

エッジ・クラウドのハイブリッドデプロイ: エッジデバイスでは低次元(64〜256次元)で高速推論し、クラウドでは高次元(1024〜2048次元)で精密推論する、単一モデルでのハイブリッド構成が可能になる。

関連研究(Related Work)

  • Product Quantization (Jégou et al., 2011): ベクトルの量子化による検索高速化手法。MRLとは直交する技術であり、MRLで次元削減した後にPQを適用することでさらなる効率化が可能
  • Matryoshka Multimodal Models (Cai et al., 2024): MRLの概念をLarge Multimodal Models(LMM)のビジュアルトークン数に適用。16トークンで256トークン相当の性能を達成
  • MTEB (Muennighoff et al., 2022): 埋め込みモデルのベンチマーク。MRL対応モデルの次元別性能評価に活用されている(別記事で解説: MTEB論文解説

まとめと今後の展望

MRLは単一モデルの訓練に多粒度損失を追加するだけで、推論時に任意の次元で高精度な埋め込みを取得できる手法である。エンコーダのアーキテクチャ変更が不要であり、OpenAI text-embedding-3への採用が示す通り、産業的にも広く受け入れられている。

今後の方向性として、日本語埋め込みモデル(Ruri等、別記事で解説: Ruri論文解説)へのMRL適用や、LoRAとの組み合わせによるファインチューニング効率化が挙げられる。また、2段階検索のShortlist次元の自動選択(クエリ難易度に応じた動的切り替え)も実用化に向けた課題である。

参考文献

この投稿は CC BY 4.0 でライセンスされています。

CIDR 2025論文解説: Text2SQL is Not Enough — TAGフレームワークによるDB×LLM推論の統合

論文解説: MemGPT — OS仮想メモリ概念でLLMエージェントの長期記憶を実現する