本記事は arXiv:2205.13147 Matryoshka Representation Learning(NeurIPS 2022採択)の解説記事です。
論文概要(Abstract)
Kusupati et al.(University of Washington, Google)は、単一の埋め込みモデルから可変サイズの表現を生成するMatryoshka Representation Learning(MRL)を提案している。MRLでは、$d$次元の埋め込みベクトルの最初の$m$次元($m < d$)が、独立に訓練された$m$次元の埋め込みと同等以上の情報量を持つよう訓練される。著者らの実験によると、MRLは近似最近傍(ANN)検索で14倍の高速化を達成しつつ、精度低下を2%未満に抑えたと報告されている。NeurIPS 2022で採択された本論文は、OpenAIのtext-embedding-3シリーズやsentence-transformersのMatryoshkaLossの基礎となっている。
この記事は Zenn記事: Embedding Fine-tuning実践:合成データと評価ループでRAG検索精度を改善する の深掘りです。
情報源
- arXiv ID: 2205.13147
- URL: https://arxiv.org/abs/2205.13147
- 著者: Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham Kakade, Prateek Jain, Ali Farhadi
- 発表年: 2022(NeurIPS 2022)
- 分野: cs.LG, cs.CV, cs.IR
背景と動機(Background & Motivation)
機械学習システムにおいて、埋め込み(embedding)は分類・検索・クラスタリング等の多くのタスクの基盤である。しかし、従来の埋め込みモデルは固定サイズの表現を出力する。例えばBERT-baseは768次元、ResNet-50は2048次元の固定ベクトルを出力する。
この固定サイズアプローチには以下の問題がある。
- 非効率性: 精度要件が低いタスクでも全次元を使用する必要がある
- 柔軟性の欠如: 異なるハードウェア制約やレイテンシ予算に対応できない
- 画一的: 精度と効率のトレードオフを単一モデルで制御できない
RAGシステムではベクトルDBのストレージコストと検索レイテンシがスケーリングの障壁となるため、次元数を柔軟に選択できるMRLの恩恵が大きい。Zenn記事で紹介されているMatryoshkaLossは、まさにこの論文の手法をsentence-transformersに実装したものである。
主要な貢献(Key Contributions)
- 貢献1: 入れ子構造の表現学習フレームワーク(MRL)の提案。追加パラメータなしで、既存の任意の表現学習手法に適用可能
- 貢献2: 画像(ResNet, ViT)、テキスト(BERT)、マルチモーダル(CLIP)での有効性の実証
- 貢献3: ANN検索で14倍の高速化、2段階検索(64→2048次元)で12倍の高速化を達成しつつ精度低下を2%未満に抑制
技術的詳細(Technical Details)
MRLの訓練目標
MRLの核心は、入れ子次元の集合 $\mathcal{M} = {m_1, m_2, \ldots, m_k}$ に対して、各次元のプレフィックスが独立した表現として有用になるよう訓練する点にある。
分類タスクの場合:
\[\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} c_m \cdot \mathcal{L}(\mathbf{W}_m \cdot \mathbf{z}[1\!:\!m], \, y)\]ここで、
- $\mathcal{M}$: 入れ子次元の集合(例: ${8, 16, 32, 64, 128, 256, 512, 1024, 2048}$)
- $\mathbf{z} \in \mathbb{R}^d$: バックボーンからの完全な$d$次元表現
- $\mathbf{z}[1!:!m]$: $\mathbf{z}$の最初の$m$次元のプレフィックス
- $\mathbf{W}_m \in \mathbb{R}^{C \times m}$: 次元$m$に対応する分類ヘッド
$c_m$: 損失の重み(通常 $c_m = 1/ \mathcal{M} $ で均等) - $\mathcal{L}$: クロスエントロピー損失
- $y$: 正解ラベル
コントラスト学習(検索タスク)の場合:
\[\mathcal{L}_{\text{MRL-metric}} = \sum_{m \in \mathcal{M}} c_m \cdot \mathcal{L}_{\text{contrastive}}(\mathbf{z}[1\!:\!m])\]Zenn記事で紹介されているsentence-transformersのMatryoshkaLossは、この $\mathcal{L}{\text{MRL-metric}}$ を実装したものであり、内部の$\mathcal{L}{\text{contrastive}}$としてMultipleNegativesRankingLoss等をラップして使用する。
なぜMRLが機能するか
MRLの入れ子制約は、モデルに情報を階層的に組織化することを強制する。著者らの分析によると、各次元帯は以下の粒度の情報を担う。
| 次元帯 | 担う情報の粒度 | 画像の例 |
|---|---|---|
| 8-16次元 | 粗いカテゴリ | 「動物」vs「乗り物」 |
| 32-64次元 | サブカテゴリ | 「犬」vs「猫」 |
| 128-256次元 | インスタンスレベル | 特定の犬種 |
| 512-2048次元 | 最細粒度の詳細 | 個体識別 |
この階層構造は、脳の情報処理が複数のスケールで行われることと類似していると著者らは述べている。
PyTorch実装例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
def create_mrl_loss(
model: SentenceTransformer,
matryoshka_dims: list[int] | None = None,
) -> MatryoshkaLoss:
"""MatryoshkaLoss + MultipleNegativesRankingLossを構成する
Args:
model: 学習対象のSentenceTransformerモデル
matryoshka_dims: 入れ子次元のリスト
Returns:
MatryoshkaLossインスタンス
"""
if matryoshka_dims is None:
matryoshka_dims = [768, 512, 256, 128, 64]
base_loss = MultipleNegativesRankingLoss(model)
return MatryoshkaLoss(
model,
base_loss,
matryoshka_dims=matryoshka_dims,
)
class MRLClassificationHead(nn.Module):
"""MRL用の分類ヘッド(複数次元対応)
Args:
full_dim: 完全な埋め込み次元数
num_classes: 分類クラス数
nested_dims: 入れ子次元のリスト
"""
def __init__(
self,
full_dim: int,
num_classes: int,
nested_dims: list[int],
):
super().__init__()
self.nested_dims = nested_dims
self.heads = nn.ModuleDict({
str(d): nn.Linear(d, num_classes)
for d in nested_dims
})
def forward(
self,
z: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""各入れ子次元での損失の平均を計算する
Args:
z: バックボーンの出力 (batch_size, full_dim)
labels: 正解ラベル (batch_size,)
Returns:
MRL損失値
"""
total_loss = torch.tensor(0.0, device=z.device)
for d in self.nested_dims:
z_prefix = z[:, :d]
logits = self.heads[str(d)](z_prefix)
loss = nn.functional.cross_entropy(logits, labels)
total_loss += loss
return total_loss / len(self.nested_dims)
2段階検索(Adaptive Retrieval)
MRLの実用的な応用として、2段階検索が挙げられる。
graph LR
A[クエリ] --> B[小次元<br/>64次元で<br/>ANN検索]
B --> C[Top-K候補<br/>取得]
C --> D[大次元<br/>2048次元で<br/>リランキング]
D --> E[最終結果]
第1段階で小次元(64次元等)の埋め込みを使ってANNインデックスから候補を高速に取得し、第2段階で大次元(2048次元等)の埋め込みで精密にリランキングする。これは同一モデルの異なるプレフィックスを使い分けるだけであり、追加のモデルは不要である。
実装のポイント(Implementation)
バックボーンの変更不要: MRLは損失関数の変更のみで実装可能。ResNet、ViT、BERT等のバックボーンアーキテクチャはそのまま使える
訓練メモリの増加: 各入れ子次元に対応する分類ヘッドが必要なため、訓練時のメモリが$ \mathcal{M} $倍に増加する。ただし推論時は不要 次元セット $\mathcal{M}$ の選択: 使用予定の次元を必ず$\mathcal{M}$に含める。Zenn記事のコード例では$[768, 512, 256, 128, 64]$が使用されている
- 推論時の次元切り替え:
model.encode()後にベクトルをスライスするだけで任意の次元に切り替え可能。特別な処理は不要
実験結果(Results)
ImageNet-1K分類(ResNet50)
著者らが報告している次元別の精度比較は以下のとおりである(論文Table 1より)。
| 表現次元 | 標準Top-1精度 | MRL Top-1精度 | 差分 |
|---|---|---|---|
| 2048(フル) | 76.2% | 77.1% | +0.9% |
| 512 | 75.8% | 76.8% | +1.0% |
| 128 | 71.5% | 75.3% | +3.8% |
| 64 | 68.0% | 73.5% | +5.5% |
| 16 | 53.0% | 64.9% | +11.9% |
| 8 | 42.0% | 56.8% | +14.8% |
MRLは全次元で標準的な固定サイズ表現を上回っており、特に小次元での改善が顕著である。
MS-MARCOテキスト検索(MRR@10)
テキスト検索タスクでの結果は以下のとおりである(論文Table 3より)。
| 埋め込み次元 | MRL | 標準 | 差分 |
|---|---|---|---|
| 768(フル) | 35.0 | 34.8 | +0.2 |
| 256 | 34.3 | 33.5 | +0.8 |
| 128 | 33.4 | 32.0 | +1.4 |
| 64 | 31.9 | 29.5 | +2.4 |
| 32 | 29.2 | 26.1 | +3.1 |
ANN検索の高速化
100万件のデータセットでのANN検索の結果は以下のとおりである(論文Table 5より)。
| 手法 | 速度倍率 | 精度 |
|---|---|---|
| フル(2048次元) | 1x | 100% |
| ANN(2048次元) | 3x | 99.5% |
| ANN(64次元MRL) | 14x | 97.0% |
| 2段階(64→2048次元MRL) | 12x | 98.5% |
| 2段階(128→2048次元MRL) | 8x | 99.2% |
2段階検索(64→2048次元)では12倍の高速化を達成しつつ、精度は98.5%を維持している。
Zenn記事との関連
Philschmid氏の実験(Zenn記事で引用)では、BAAI/bge-base-en-v1.5をMatryoshkaLossでFine-tuningした結果、128次元でもFine-tuning前の768次元ベースライン(NDCG@10 = 0.7684)を上回る0.8184を記録したと報告されている。これはMRL論文の知見と一致しており、次元削減と精度向上の両立が実証されている。
実運用への応用(Practical Applications)
ベクトルDBのストレージ削減: 768次元→128次元への削減で83%のストレージ削減が可能。数百万件以上のドキュメントを持つRAGシステムでは、ベクトルDBのコストが大幅に削減される。
レイテンシの改善: ANN検索のレイテンシは次元数にほぼ比例するため、次元数を1/6にすることで検索レイテンシも大幅に短縮される。
段階的なデプロイ戦略: 同一モデルから異なる次元の埋め込みを生成できるため、精度要件の異なるユースケースに単一モデルで対応できる。開発初期は低次元でプロトタイピングし、本番では高次元に切り替えるといった運用が可能になる。
関連研究(Related Work)
- Product Quantization(PQ): Jegou et al.(2011)が提案した埋め込み圧縮手法。MRLは次元削減、PQはベクトル量子化であり、両者は補完的に使用可能
- OpenAI text-embedding-3: OpenAIのtext-embedding-3-smallおよびtext-embedding-3-largeは、MRLの手法を採用しており、API呼び出し時にdimensionsパラメータで出力次元を指定できる
- 2D Matryoshka Sentence Embeddings(Li et al., 2023, arXiv:2309.07597): MRLを文埋め込みに特化して拡張し、2次元の入れ子構造(次元と層)を導入した研究
まとめと今後の展望
Matryoshka Representation Learning(MRL)は、単一モデルから可変サイズの埋め込みを生成する手法であり、追加パラメータなしに既存の訓練パイプラインに組み込むことができる。著者らは、ANN検索で14倍の高速化を達成しつつ精度低下を2%未満に抑えたと報告しており、RAGシステムのスケーリングにおけるストレージコスト・検索レイテンシの課題を解決する有力な手法である。
Zenn記事で紹介されているMatryoshkaLossを使ったFine-tuningは、MRLの手法をsentence-transformersで直接利用する実装であり、MultipleNegativesRankingLossやGISTEmbedLossと組み合わせることで、精度向上と次元削減を同時に実現できる。
参考文献
- arXiv: https://arxiv.org/abs/2205.13147
- Code: https://github.com/RAIVNLab/MRL(Apache 2.0 License)
- NeurIPS 2022: https://proceedings.neurips.cc/paper_files/paper/2022
- sentence-transformers MatryoshkaLoss: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#matryoshkaloss
- Related Zenn article: https://zenn.dev/0h_n0/articles/3a80f7fd58cc8e
:::message この記事はAI(Claude Code)により自動生成されました。内容の正確性については論文の原文で検証していますが、実際の利用時は公式ドキュメントもご確認ください。 :::