Home 論文解説: Jamba — Transformer-Mamba-MoEハイブリッド言語モデルの設計原理
投稿
キャンセル

📄 論文解説: Jamba — Transformer-Mamba-MoEハイブリッド言語モデルの設計原理

本記事は Jamba: A Hybrid Transformer-Mamba Language Model の解説記事です。

論文概要(Abstract)

Jambaは、AI21 Labsが2024年3月に発表した、Transformer Attention、Mamba SSM、Mixture of Experts(MoE)の3つの要素を組み合わせたハイブリッド言語モデルである。52Bの総パラメータ数のうち12Bのみがアクティブであり、256Kトークンのコンテキスト長を単一のA100 80GB GPUで処理可能であると著者らは報告している。Nemotron 3 Nano OmniもMamba SSM層・MoE層・GQA層を組み合わせたハイブリッド構成を採用しており、本論文はその設計上の先行事例として重要な位置づけにある。

この記事は Zenn記事: Nemotron 3 Nano Omniで構築するマルチモーダルAIエージェント実践ガイド の深掘りです。

情報源

  • arXiv ID: 2403.19887
  • URL: https://arxiv.org/abs/2403.19887
  • 著者: Opher Lieber, Barak Lenz, Hofit Bata et al.(AI21 Labs)
  • 発表年: 2024
  • 分野: cs.LG, cs.CL

背景と動機(Background & Motivation)

2024年初頭の時点で、TransformerとSSM(Mamba)はそれぞれ固有の強みと弱みを持つ競合アーキテクチャとして認識されていた。Transformerは豊富な研究蓄積と高いIn-Context Learning能力を持つが、$O(n^2)$ の計算量と線形に増大するKVキャッシュが長コンテキスト処理の障壁であった。一方、MambaはO(n)の計算量とKVキャッシュ不要の特性を持つが、グローバルな情報参照能力がTransformerに劣るという課題があった。

著者らは、「これらのアーキテクチャは競合するものではなく、相補的に組み合わせるべきである」という仮説の下、3つの要素を混合したハイブリッドモデルを構築した。さらにMoEを統合することで、パラメータ数のスケーリングと計算コストの抑制を両立させた。

この設計思想は、後にNVIDIAがNemotron 3 Nano Omni(30B-A3B)で採用した構成——Mamba-2 SSM層23層 + MoE層23層 + GQA層6層——に直接的な影響を与えたと考えられる。

主要な貢献(Key Contributions)

  • 貢献1: Transformer、Mamba、MoEの3要素をブロックレベルで混合するハイブリッドアーキテクチャを提案し、各要素の比率が性能に与える影響を体系的に分析
  • 貢献2: 52Bパラメータ(12Bアクティブ)で256Kコンテキストを単一A100 80GB GPUで処理可能であることを実証。同規模のTransformerモデルと比較して最大3倍の推論スループット向上を報告
  • 貢献3: Transformer:Mamba比率のアブレーション研究を実施し、1:7(12.5% Attention)が精度と効率の最適なバランスであることを特定

技術的詳細(Technical Details)

ハイブリッドブロック構成

Jambaのアーキテクチャは、Transformer層とMamba層を交互に配置する「ブロック」構造を基本とする。各ブロック内のレイヤー構成は以下の通りである。

\[\text{Block}_i = \begin{cases} \text{Attention}(\mathbf{x}) + \text{MoE-FFN}(\mathbf{x}) & \text{if } i \bmod 8 = 0 \\ \text{Mamba}(\mathbf{x}) + \text{FFN}(\mathbf{x}) & \text{otherwise} \end{cases}\]

著者らが選択した比率はTransformer:Mamba = 1:7であり、8層ごとに1層のAttention層が配置される。残りの7層はMamba SSM層で構成される。

重要な設計判断として、MoEはTransformer層のFFN部分にのみ適用されている。Mamba層のFFNは通常の単一FFNのままである。この理由について著者らは、Mamba層にMoEを追加しても精度の改善が見られなかったためであると説明している。

MoE構成の詳細

JambaのMoE構成は以下の通りである。

パラメータ
総エキスパート数16
アクティブエキスパート数(Top-K)2
総パラメータ数52B
アクティブパラメータ数12B

ルーターは標準的なTop-Kゲーティングを採用している。

\[g(\mathbf{x}) = \text{Top-K}(\text{softmax}(\mathbf{W}_r \mathbf{x}))\] \[\text{MoE-FFN}(\mathbf{x}) = \sum_{i \in \text{Top-K}} g_i(\mathbf{x}) \cdot \text{FFN}_i(\mathbf{x})\]

ここで、

  • $\mathbf{W}_r \in \mathbb{R}^{E \times d}$: ルーターの重み行列($E$: エキスパート数)
  • $g_i(\mathbf{x})$: トークン $\mathbf{x}$ に対するエキスパート $i$ のゲート重み
  • $\text{FFN}_i$: 第 $i$ エキスパートのFeed-Forward Network

比較として、Nemotron 3 Nano Omniは128エキスパートからTop-6を選択する、はるかに大規模なMoE構成を採用している。この差異は、Jambaが2024年初頭の技術であるのに対し、Nemotron 3 Nano Omniが2026年のモデルであり、MoEのスケーリングに関する知見が蓄積されていることを反映している。

KVキャッシュの効率性分析

Jambaの最大の実用的利点は、KVキャッシュの大幅な削減にある。Transformer層が全体の12.5%(1/8)のみであるため、KVキャッシュは同規模の純粋なTransformerの約1/8で済む。

256Kトークンのコンテキストにおけるメモリ使用量の比較(著者らの報告に基づく推定)は以下の通りである。

モデルパラメータKVキャッシュ(256K)合計VRAM
Llama-2 70B(Transformer)70B約200GB約340GB
Mixtral 8x7B(MoE Transformer)47B約120GB約200GB
Jamba(Hybrid)52B約25GB約80GB

(出典: 論文Figure 3の推定値。著者らの報告に基づく)

Jamba が単一の A100 80GB GPU に収まるのは、Mamba 層で KV キャッシュが不要であることと、MoE によりアクティブパラメータが 12B に抑えられていることの二重の効果によるものである。

Attention層の配置戦略

著者らは、Attention層を均等間隔で配置する戦略を採用している(8層ごとに1層)。これは後のNemotron-Hの知見——Attention層を後半に集中させるほうが精度が向上する——とは異なるアプローチである。

この差異は、学習データ量とモデルスケールの違いに起因する可能性がある。Jambaの分析では均等配置で十分な精度が得られたが、より大規模な学習において最適な配置パターンが変化することをNemotron-Hの結果は示唆している。

アルゴリズム実装

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import torch.nn as nn

class JambaLayer(nn.Module):
    """Jambaのハイブリッドレイヤー(簡略化実装)"""

    def __init__(
        self,
        d_model: int,
        layer_idx: int,
        n_experts: int = 16,
        top_k: int = 2,
    ):
        super().__init__()
        self.use_attention = (layer_idx % 8 == 0)
        self.use_moe = self.use_attention
        self.norm = nn.RMSNorm(d_model)

        if self.use_attention:
            self.mixer = nn.MultiheadAttention(
                d_model, num_heads=32, batch_first=True
            )
        else:
            from mamba_ssm import Mamba
            self.mixer = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2)

        if self.use_moe:
            self.router = nn.Linear(d_model, n_experts, bias=False)
            self.experts = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(d_model, d_model * 4),
                    nn.SiLU(),
                    nn.Linear(d_model * 4, d_model),
                )
                for _ in range(n_experts)
            ])
            self.top_k = top_k
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.SiLU(),
                nn.Linear(d_model * 4, d_model),
            )

        self.ffn_norm = nn.RMSNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """レイヤーの前方計算

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            (batch, seq_len, d_model)
        """
        residual = x
        x = self.norm(x)

        if self.use_attention:
            x, _ = self.mixer(x, x, x)
        else:
            x = self.mixer(x)
        x = x + residual

        residual = x
        x = self.ffn_norm(x)

        if self.use_moe:
            router_logits = self.router(x)
            weights, indices = torch.topk(
                torch.softmax(router_logits, dim=-1), self.top_k, dim=-1
            )
            weights = weights / weights.sum(dim=-1, keepdim=True)

            output = torch.zeros_like(x)
            for k in range(self.top_k):
                expert_idx = indices[:, :, k]
                expert_weight = weights[:, :, k].unsqueeze(-1)
                for e in range(len(self.experts)):
                    mask = (expert_idx == e)
                    if mask.any():
                        expert_input = x[mask]
                        expert_output = self.experts[e](expert_input)
                        output[mask] += expert_weight[mask] * expert_output
            x = output
        else:
            x = self.ffn(x)

        return x + residual

実装のポイント(Implementation)

Jambaの実装において注意すべき点は以下の通りである。

  • ロードバランシング: MoEのルーティングで特定のエキスパートに負荷が集中する問題が報告されている。著者らはauxiliary lossによるロードバランシングを採用しているが、学習初期の不安定性が課題
  • Mamba層の並列学習: Mamba層は再帰的計算であるため、学習時の並列化がAttentionより困難。畳み込みモードでの学習が推奨されるが、選択的パラメータの導入によりこのモードが使えないケースがある
  • メモリ管理: Attention層のKVキャッシュとMamba層の隠れ状態は異なるメモリ管理が必要。vLLMやHugging Face transformersでのサポートが提供されている

実験結果(Results)

アブレーション: Attention比率

著者らは、Transformer:Mamba比率を変化させたアブレーション研究を実施している。

比率 (Attn:Mamba)精度 (avg)スループット
1:1 (50%)最高
1:3 (25%)
1:7 (12.5%)
0:1 (0%)最高

(出典: 論文Section 5のアブレーション研究。著者らの報告に基づく定性的まとめ)

著者らは、1:7が精度と効率の最適なトレードオフであると結論づけている。Attention比率を12.5%以下に下げると、In-Context Learningタスクでの精度低下が顕著になると報告されている。

標準ベンチマーク

ベンチマークJamba (12B active)Mixtral 8x7B (12.9B active)Llama-2 70B
HellaSwag87.186.787.3
WinoGrande82.581.283.7
ARC-Challenge64.462.567.3

(出典: 論文Table 2。著者らの報告に基づく)

12Bのアクティブパラメータで、Mixtral 8x7B(12.9Bアクティブ)を全ベンチマークで上回り、70BのフルTransformerであるLlama-2に迫る結果を報告している。

長コンテキスト性能

著者らは、256Kトークンのコンテキストでの「Needle-in-a-Haystack」テストにおいて、Jambaが正確に情報を検索できることを実証している。これは、少数のAttention層がグローバルな情報参照を担い、Mamba層が効率的な状態伝播を担うという役割分担が機能していることを示唆している。

実運用への応用(Practical Applications)

Jambaの設計原理は、Nemotron 3 Nano Omniに以下の形で応用されている。

  • ハイブリッド構成の発展: Jambaの1:7比率はNemotron-HのMamba比率94%、さらにNemotron 3 Nano OmniのMamba-2 SSM 23層 + MoE 23層 + GQA 6層の構成へと発展
  • MoEのスケーリング: Jambaの16エキスパートTop-2からNemotron 3 Nano Omniの128エキスパートTop-6への拡大は、MoEスケーリングの実用的知見の蓄積を反映
  • 単一GPU動作の実現: Jambaが示した「KVキャッシュ削減+スパース活性化」による単一GPU動作の実現は、NVFP4量子化と組み合わせてRTX 5090等のコンシューマGPUでの動作を可能にする設計の原型

関連研究(Related Work)

  • Mamba(Gu & Dao, 2023): Jambaが採用するSSM層の基盤技術。選択的状態空間モデルの導入により、SSMの表現力を大幅に向上させた
  • Mixtral 8x7B(Mistral AI, 2024): 同時期のMoEモデル。純粋なTransformer+MoE構成であり、SSM層を含まない。Jambaは同等のアクティブパラメータ数でMixtralを上回る精度を報告
  • Nemotron-H(NVIDIA, 2025): Jambaの後継的研究。Mamba-2を採用し、Attention比率をさらに6%に削減。Attention層の配置を後半に集中させるなど、Jambaの均等配置とは異なる戦略を採用

まとめと今後の展望

Jambaは、Transformer+Mamba+MoEの三要素ハイブリッドアーキテクチャの実用性を初めて大規模に実証した研究である。著者らが示した1:7のAttention比率、MoEをAttention層のFFNのみに適用する設計判断、256Kコンテキストの単一GPU動作は、後続のNemotron-HおよびNemotron 3 Nano Omniの設計に直接的な影響を与えた。SSM+MoEハイブリッドは、マルチモーダルAIエージェントの効率的な推論基盤として今後も発展が期待される分野である。

参考文献

  • arXiv: https://arxiv.org/abs/2403.19887
  • Related Zenn article: https://zenn.dev/0h_n0/articles/fabaf781f4158d
  • Mamba: https://arxiv.org/abs/2312.00752
  • Nemotron-H: https://arxiv.org/abs/2504.11849
  • Mixtral 8x7B: https://arxiv.org/abs/2401.04088
この投稿は CC BY 4.0 でライセンスされています。