Home ICLR 2025論文解説: Jamba — Transformer-Mambaハイブリッドアーキテクチャの設計と実装
投稿
キャンセル

📄 ICLR 2025論文解説: Jamba — Transformer-Mambaハイブリッドアーキテクチャの設計と実装

本記事は Jamba: A Hybrid Transformer-Mamba Language Model (arXiv:2403.19887) の解説記事です。ICLR 2025 採択論文。

論文概要(Abstract)

Jambaは、AI21 Labsが提案したTransformerとMamba(State Space Model, SSM)のハイブリッド言語モデルである。Attention層とMamba層を1:7の比率で交互に配置し、さらにMixture-of-Experts(MoE)を組み合わせることで、52Bパラメータ(アクティブ12B)を単一80GB GPUで動作可能にしている。256Kトークンのコンテキスト長をサポートし、Mixtral 8x7B比で3倍のスループットを報告している。

この記事は Zenn記事: Attention機構の全史 Bahdanauから FlashAttention4・MLAまでの数学と実装 の深掘りです。

情報源

カンファレンス情報

Jambaは2024年3月にプレプリントとして公開され、ICLR 2025に採択された。SSM(State Space Model)とTransformerのハイブリッドアーキテクチャとしては初のプロダクションレベルのモデルであり、オープンウェイト(Apache 2.0ライセンス)で公開されている。

技術的詳細(Technical Details)

なぜハイブリッドか:AttentionとSSMの相補性

AttentionとSSMはそれぞれ異なる強みと弱みを持つ。

特性AttentionSSM(Mamba)
計算量(シーケンス長 $n$)$O(n^2 d)$$O(n d)$
情報検索精度高い(全トークンペアを明示的に比較)低い(再帰的に圧縮された状態から検索)
長距離依存性KVキャッシュが線形増加固定サイズの状態で表現
並列学習効率高い(全位置を同時計算可能)中程度(再帰計算にシーケンシャル成分あり)

著者らは、2025年のアブレーション研究でAttention層を完全に除去すると検索精度が0%に低下することを確認しており、SSMだけでは「正確な情報の取り出し」が困難であることが実験的に裏付けられている。

Mambaの基本メカニズム

Jambaが採用するMambaは、Selective State Space Model(S6)とも呼ばれ、入力に応じてSSMのパラメータを動的に変化させる機構を持つ。

離散時間SSMの定式化:

\(h_t = \bar{A} h_{t-1} + \bar{B} x_t\) \(y_t = C h_t\)

ここで、

  • $h_t \in \mathbb{R}^{N}$: 隠れ状態(状態次元 $N$)
  • $x_t \in \mathbb{R}^{D}$: 入力
  • $\bar{A} \in \mathbb{R}^{N \times N}$: 離散化された状態遷移行列
  • $\bar{B} \in \mathbb{R}^{N \times D}$: 入力射影行列
  • $C \in \mathbb{R}^{D \times N}$: 出力射影行列

Mambaの選択機構(Selective Scan): 標準SSMでは $A, B, C$ は入力非依存の固定パラメータだが、Mambaでは入力 $x_t$ に依存して $B_t, C_t, \Delta_t$(離散化ステップ)を動的に生成する。

\[B_t = \text{Linear}(x_t), \quad C_t = \text{Linear}(x_t), \quad \Delta_t = \text{softplus}(\text{Linear}(x_t))\]

これにより、モデルは入力の内容に応じて「何を記憶し、何を忘れるか」を選択的に制御できる。

Jambaのブロック構成

Jambaは以下の階層構造で構成される。

graph TD
    subgraph "Jamba Block(4層で1ブロック)"
        L1["Layer 1: Mamba + MLP"]
        L2["Layer 2: Mamba + MLP"]
        L3["Layer 3: Mamba + MLP"]
        L4["Layer 4: Attention + MLP"]
    end
    subgraph "MoE適用(2ブロックに1回)"
        MOE["MoE: 16 experts<br>top-2 routing"]
    end
    L1 --> L2 --> L3 --> L4
    L4 --> MOE

具体的な構成:

  • Attention : Mamba比率 = 1 : 7(8層のうち1層がAttention)
  • MoE適用頻度: 2ブロック(8層)に1回
  • エキスパート数: 16(うち2を活性化、top-2 routing)
  • 全パラメータ数: 52B
  • アクティブパラメータ数: 12B(トークンあたり)
  • コンテキスト長: 256K トークン

メモリ効率の分析

Jambaのメモリ効率が高い理由は、KVキャッシュがAttention層のみに限定されることにある。

KVキャッシュ比較(256Kコンテキスト時、論文Table 3より概算):

モデルKVキャッシュサイズ(概算)備考
Llama-2 70B(全層Attention)~128 GB全80層でKV保持
Mixtral 8x7B(全層Attention)~48 GB32層、GQA使用
Jamba 52B(1/8がAttention)~8 GB7/8層はSSMで固定サイズ状態のみ

Mamba層は固定サイズの隠れ状態($h_t$)のみを保持するため、シーケンス長に依存しないメモリ消費となる。これがJambaの256Kコンテキスト対応を可能にしている核心である。

実装例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn

class JambaBlock(nn.Module):
    """Jamba Block: 3 Mamba layers + 1 Attention layer"""
    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 32,
        ssm_state_dim: int = 16,
        ssm_expand: int = 2,
    ):
        super().__init__()
        # 3 Mamba layers + 1 Attention layer
        self.mamba_layers = nn.ModuleList([
            MambaLayer(d_model, ssm_state_dim, ssm_expand)
            for _ in range(3)
        ])
        self.attention_layer = AttentionLayer(d_model, n_heads)

        self.layer_norms = nn.ModuleList([
            nn.RMSNorm(d_model) for _ in range(4)
        ])

    def forward(
        self,
        x: torch.Tensor,
        ssm_states: list[torch.Tensor] | None = None,
        kv_cache: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
        new_ssm_states = []

        # Mamba layers (3/4)
        for i, (mamba, norm) in enumerate(
            zip(self.mamba_layers, self.layer_norms[:3])
        ):
            residual = x
            x = norm(x)
            state_in = ssm_states[i] if ssm_states else None
            x, state_out = mamba(x, state_in)
            x = x + residual
            new_ssm_states.append(state_out)

        # Attention layer (1/4)
        residual = x
        x = self.layer_norms[3](x)
        x, new_kv_cache = self.attention_layer(x, kv_cache)
        x = x + residual

        return x, new_ssm_states, new_kv_cache


class MambaLayer(nn.Module):
    """Simplified Mamba (Selective SSM) layer"""
    def __init__(self, d_model: int, state_dim: int, expand: int):
        super().__init__()
        d_inner = d_model * expand
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(d_inner, d_inner, kernel_size=4,
                                padding=3, groups=d_inner)
        # Selective parameters
        self.x_proj = nn.Linear(d_inner, state_dim * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, d_inner, bias=True)
        self.A_log = nn.Parameter(torch.randn(d_inner, state_dim))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        ssm_state: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, T, D = x.shape
        xz = self.in_proj(x)
        x_in, z = xz.chunk(2, dim=-1)

        x_conv = self.conv1d(x_in.transpose(1, 2))[:, :, :T].transpose(1, 2)
        x_conv = torch.silu(x_conv)

        # Selective scan (simplified)
        ssm_params = self.x_proj(x_conv)
        A = -torch.exp(self.A_log)

        # Recurrent computation
        if ssm_state is None:
            ssm_state = torch.zeros(B, x_in.size(-1), A.size(-1),
                                     device=x.device)

        outputs = []
        for t in range(T):
            ssm_state = ssm_state * torch.exp(A.unsqueeze(0)) + \
                        x_conv[:, t:t+1, :].unsqueeze(-1)
            y_t = (ssm_state @ ssm_params[:, t:t+1, :A.size(-1)].unsqueeze(-1)
                   ).squeeze(-1)
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)
        y = y * torch.silu(z)
        output = self.out_proj(y)

        return output, ssm_state

実装のポイント(Implementation)

Mamba SSMライブラリへの依存: Jambaの実行にはMamba SSMライブラリ(mamba-ssmパッケージ)が必要であり、これはTritonに依存する。CUDAバージョンとの互換性に注意が必要。

Attention/Mamba比率の選択: 著者らは1:7(8層中1層がAttention)を推奨しているが、タスクの性質によって最適な比率は異なる。検索精度が重要なRAGタスクでは、Attention層の比率を高める(1:3や1:5)ことが有効な場合がある。

メモリ管理: 推論時、Mamba層のSSM状態は固定サイズ($d_{inner} \times N_{state}$)であるのに対し、Attention層のKVキャッシュはシーケンス長に比例して増大する。長コンテキスト推論では、Attention層のKVキャッシュがメモリボトルネックとなりうる。

Hugging Face Hub: ai21labs/Jamba-v0.1 としてモデル重みが公開されており、transformers ライブラリ経由でロード可能(Apache 2.0ライセンス)。

Production Deployment Guide

AWS実装パターン(コスト最適化重視)

Jambaモデルの推論構成を示す。Jambaは52Bパラメータだが12Bアクティブのため、単一GPUでの推論が可能である。

規模月間リクエスト推奨構成月額コスト概算主要サービス
Small~3,000SageMaker Serverless$200-500SageMaker (g5.2xlarge)
Medium~30,000SageMaker Real-time$2,500-5,000SageMaker (g5.12xlarge × 2)
Large300,000+EKS Self-managed$8,000-18,000EKS + g5.48xlarge / p5

単一GPU動作のメリット: Jamba 52Bは80GB GPU 1枚で動作するため、マルチGPU分散推論のオーバーヘッドが不要。g5.48xlarge(A10G×8)でバッチ推論を行う場合、Mixtral 8x7B比で3倍のスループットが期待できる。

コスト試算の注意事項: 上記は2026年4月時点のAWS東京リージョン料金に基づく概算値。Jambaの長コンテキスト対応(256K)を活用する場合、KVキャッシュ用メモリが追加で必要となる。

Terraformインフラコード

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
resource "aws_sagemaker_model" "jamba" {
  name               = "jamba-hybrid-inference"
  execution_role_arn = aws_iam_role.sagemaker.arn

  primary_container {
    image = "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu22.04-sagemaker"
    model_data_url = "s3://${aws_s3_bucket.model.bucket}/jamba/model.tar.gz"
    environment = {
      MODEL_NAME       = "ai21labs/Jamba-v0.1"
      MAX_CONTEXT_LEN  = "262144"
      MAMBA_SSM_VERSION = "1.2.0"
    }
  }
}

resource "aws_sagemaker_endpoint_configuration" "jamba" {
  name = "jamba-config"

  production_variants {
    variant_name           = "default"
    model_name             = aws_sagemaker_model.jamba.name
    instance_type          = "ml.g5.12xlarge"
    initial_instance_count = 1
  }
}

resource "aws_cloudwatch_metric_alarm" "ssm_state_memory" {
  alarm_name          = "jamba-memory-usage"
  comparison_operator = "GreaterThanThreshold"
  evaluation_periods  = 2
  metric_name         = "GPUMemoryUtilization"
  namespace           = "AWS/SageMaker"
  period              = 300
  statistic           = "Maximum"
  threshold           = 90.0
  alarm_description   = "GPUメモリ90%超過: コンテキスト長の制限またはインスタンスサイズ変更を検討"
}

コスト最適化チェックリスト

  • Jamba単一GPU推論でマルチGPU分散コスト回避
  • Mamba層の固定メモリ消費を活用した長コンテキスト推論
  • SageMaker Autoscaling設定
  • Spot Instances活用(g5: 最大60%削減)
  • Continuous Batching有効化(vLLM Jamba対応設定)
  • CloudWatch GPU メモリ/使用率アラーム設定
  • AWS Budgets月額予算アラート
  • 開発環境エンドポイント夜間削除
  • Cost Anomaly Detection有効化
  • 長コンテキスト推論のバッチサイズ最適化

実験結果(Results)

著者らが報告している主要な実験結果を以下にまとめる(論文Table 1-4より)。

ベンチマークMixtral 8x7BLlama-2 70BJamba 52B (12B active)
HellaSwag86.787.387.1
WinoGrande81.283.782.5
ARC-Challenge65.767.364.4
PIQA83.282.883.2

スループット比較(論文Figure 6より):

  • Jamba: Mixtral 8x7B比で約3倍のスループット(128Kコンテキスト時)
  • メモリ使用量: 長コンテキストでの増加率がTransformerのみのモデルと比較して大幅に低い

制約事項: 短文タスク(ARC-Challenge等)ではLlama-2 70BやMixtral 8x7Bに対して若干劣るケースがある。これはアクティブパラメータ数の差(12B vs 47B/70B)に起因すると考えられる。長コンテキストタスクでの優位性が、Jambaの主要な価値となっている。

実運用への応用(Practical Applications)

Jambaの実運用上の主要な価値は、単一GPUで256Kトークンのコンテキストを処理できる点にある。これは以下のユースケースで特に有用である。

ドキュメント分析: 長い法律文書や技術文書(100ページ以上)を1回のコンテキストで処理可能。Attention層が精密な情報検索を担保し、Mamba層が長距離の文脈把握を担う。

コード理解: 大規模コードベース(数千行)をコンテキストに含めた上でのコード生成・レビュー。

マルチターン対話: 長い対話履歴を保持しつつ効率的な推論が可能。

ただし、Mamba SSMライブラリの成熟度はTransformerエコシステムと比較して発展途上であり、推論フレームワーク(vLLM等)のJamba対応も限定的である点は実運用上の考慮事項である。

まとめと今後の展望

Jambaは、TransformerとMamba(SSM)を1:7の比率で組み合わせたハイブリッドアーキテクチャにより、長コンテキスト処理のメモリ効率と推論スループットを大幅に改善した。52Bパラメータを単一80GB GPUで動作させた点は、デプロイメントコストの観点からも意義がある。

今後の方向性としては、Attention/SSM比率の動的調整(タスクに応じた適応的切り替え)、MambaとMLAの組み合わせによるさらなるKVキャッシュ削減、およびJamba 1.5以降のスケールアップが注目される。

参考文献

この投稿は CC BY 4.0 でライセンスされています。

Anthropic公式ブログ解説: Claude Opus 4.7 — 高解像度ビジョンとアダプティブ推論の技術詳細

論文解説: SafeChain — 長Chain-of-Thought推論の安全性分析と安全性訓練データセット