Home 論文解説: Mamba — 選択的状態空間モデルによる線形時間シーケンスモデリング
投稿
キャンセル

📄 論文解説: Mamba — 選択的状態空間モデルによる線形時間シーケンスモデリング

本記事は Mamba: Linear-Time Sequence Modeling with Selective State Spaces の解説記事です。

論文概要(Abstract)

Mambaは、Albert GuとTri Daoが2023年12月に発表した選択的状態空間モデル(Selective State Space Model)である。従来の固定パラメータSSM(S4等)を超え、入力に依存してパラメータを動的に選択する機構を導入することで、Transformerの二次計算量 $O(n^2)$ を線形 $O(n)$ に削減しながら、言語モデリングでTransformerと同等以上の精度を達成したと著者らは報告している。Hardware-aware parallel scanアルゴリズムにより、GPU上での実用的な推論速度を確保している。Zenn記事で紹介したNemotron 3 Nano OmniのアーキテクチャはこのMambaの後継版であるMamba-2をSSM層として採用しており、本論文はその技術的基盤を理解するための必須文献である。

この記事は Zenn記事: Nemotron 3 Nano Omniで構築するマルチモーダルAIエージェント実践ガイド の深掘りです。

情報源

  • arXiv ID: 2312.00752
  • URL: https://arxiv.org/abs/2312.00752
  • 著者: Albert Gu, Tri Dao
  • 発表年: 2023
  • 分野: cs.LG, cs.AI

背景と動機(Background & Motivation)

シーケンスモデリングにおいて、Transformerは長距離依存関係の捕捉において優れた性能を示してきた。しかし、Self-Attentionの計算量 $O(n^2)$ とKVキャッシュのメモリ消費 $O(n)$ は、シーケンス長の増大に伴い実用上の障壁となっている。

構造化状態空間モデル(Structured State Space Model; S4)は線形時間で動作する代替手法として注目されたが、S4は固定パラメータであるため入力内容に応じた情報の選択ができず、言語モデリングでTransformerに劣る結果であった。具体的には、S4はすべての入力トークンを等しく処理してしまい、重要な情報と無関係な情報を区別できないという根本的な制約があった。

著者らは、SSMのパラメータを入力に依存させる(selective)ことでこの制約を克服し、同時にhardware-awareなアルゴリズム設計でGPUメモリ階層を効率的に活用することを提案した。

主要な貢献(Key Contributions)

  • 貢献1: 入力依存のパラメータ選択機構(Selection Mechanism)をSSMに導入。$\mathbf{B}$, $\mathbf{C}$, $\Delta$ を入力から動的に生成することで、トークンごとの情報のフィルタリングを実現
  • 貢献2: Hardware-aware parallel scanアルゴリズムの設計。GPU SRAM(高速メモリ)でSSM状態を計算し、HBM(高帯域メモリ)へのアクセスを最小化
  • 貢献3: 言語・DNA・音声等の多様なドメインで、同スケールのTransformerと同等以上の精度を達成。The Pile上の言語モデリングではperplexityで上回る結果を報告

技術的詳細(Technical Details)

連続時間SSMの離散化

SSMの出発点は、連続時間の線形常微分方程式(ODE)である。

\[\frac{d\mathbf{h}(t)}{dt} = \mathbf{A}\mathbf{h}(t) + \mathbf{B}\mathbf{x}(t)\] \[\mathbf{y}(t) = \mathbf{C}\mathbf{h}(t)\]

ここで、

  • $\mathbf{h}(t) \in \mathbb{R}^{N}$: 隠れ状態ベクトル($N$: state dimension)
  • $\mathbf{A} \in \mathbb{R}^{N \times N}$: 状態遷移行列
  • $\mathbf{B} \in \mathbb{R}^{N \times 1}$: 入力射影行列
  • $\mathbf{C} \in \mathbb{R}^{1 \times N}$: 出力射影行列
  • $\mathbf{x}(t)$: 入力信号、$\mathbf{y}(t)$: 出力信号

離散時間に変換するために、タイムステップ $\Delta$ を用いてZero-Order Hold(ZOH)離散化を行う。

\[\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})\] \[\bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}(\exp(\Delta \mathbf{A}) - \mathbf{I}) \cdot \Delta \mathbf{B}\]

離散化後のSSMは再帰的に計算可能となる。

\[\mathbf{h}_t = \bar{\mathbf{A}} \mathbf{h}_{t-1} + \bar{\mathbf{B}} x_t\] \[y_t = \mathbf{C} \mathbf{h}_t\]

選択機構(Selection Mechanism)

Mambaの核心は、$\mathbf{B}$, $\mathbf{C}$, $\Delta$ を固定パラメータではなく入力 $x_t$ の関数にすることである。

\[\mathbf{B}_t = \text{Linear}_B(\mathbf{x}_t) \in \mathbb{R}^{N}\] \[\mathbf{C}_t = \text{Linear}_C(\mathbf{x}_t) \in \mathbb{R}^{N}\] \[\Delta_t = \text{softplus}(\text{Linear}_\Delta(\mathbf{x}_t)) \in \mathbb{R}^{+}\]

ここで、$\text{softplus}(z) = \log(1 + e^z)$ は $\Delta_t$ が正の値をとることを保証する。

この選択機構により、各トークンは独立したパラメータセットで処理される。直感的には以下のように理解できる。

  • $\Delta_t$ が大きい → 新しい入力を重視(過去の情報を「忘却」)
  • $\Delta_t$ が小さい → 過去の状態を保持(新しい入力を「無視」)
  • $\mathbf{B}_t$ → 現在の入力のどの側面を状態に書き込むかを制御
  • $\mathbf{C}_t$ → 状態のどの側面を出力として読み出すかを制御

この機構により、モデルはシーケンス中の重要なトークン(例: 固有名詞、キーワード)を選択的に記憶し、不要なトークン(例: 冠詞、句読点)を忘却できる。

Hardware-Aware Parallel Scan

選択機構の導入により、SSMのパラメータがステップごとに異なるため、畳み込みモードでの高速計算が不可能になる(固定パラメータS4では畳み込みカーネルに変換して $O(n \log n)$ で計算できた)。

著者らは、この問題をGPUのメモリ階層を考慮したparallel scan(prefix sum)アルゴリズムで解決している。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch

def selective_scan(
    x: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    delta: torch.Tensor,
) -> torch.Tensor:
    """Selective SSMのスキャン操作(簡略化版)

    Args:
        x: 入力テンソル (batch, seq_len, d_model)
        A: 状態遷移行列 (d_inner, d_state)
        B: 入力射影 (batch, seq_len, d_state)
        C: 出力射影 (batch, seq_len, d_state)
        delta: タイムステップ (batch, seq_len, d_inner)

    Returns:
        出力テンソル (batch, seq_len, d_model)
    """
    batch, seq_len, d_inner = x.shape
    d_state = A.shape[1]

    delta_A = torch.exp(torch.einsum("b l d, d n -> b l d n", delta, A))
    delta_B_x = torch.einsum("b l d, b l n, b l d -> b l d n", delta, B, x)

    h = torch.zeros(batch, d_inner, d_state, device=x.device)
    outputs = []

    for t in range(seq_len):
        h = delta_A[:, t] * h + delta_B_x[:, t]
        y = torch.einsum("b d n, b n -> b d", h, C[:, t])
        outputs.append(y)

    return torch.stack(outputs, dim=1)

実際の実装では、このループはCUDAカーネルでparallel scanに変換され、work-efficient $O(n)$ で処理される。重要な設計判断として、中間状態をHBMではなくGPU SRAM上に保持し、カーネルの融合(kernel fusion)によりメモリアクセスを最小化している。

Mambaブロックの全体構成

1つのMambaブロックは以下の構成をとる。

  1. 入力を2つのパスに分岐(expand=2で次元を2倍に拡大)
  2. 一方のパスに1次元畳み込み(d_conv=4)を適用
  3. SiLU活性化関数を適用
  4. 選択的SSMを適用
  5. もう一方のパスとのゲーティング(要素積)
  6. 線形射影で元の次元に戻す
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class MambaBlock(torch.nn.Module):
    """Mambaブロックの簡略化実装"""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4, expand: int = 2):
        super().__init__()
        d_inner = d_model * expand
        self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = torch.nn.Conv1d(
            d_inner, d_inner, d_conv, padding=d_conv - 1, groups=d_inner
        )
        self.x_proj = torch.nn.Linear(d_inner, d_state * 2 + 1, bias=False)
        self.A_log = torch.nn.Parameter(torch.randn(d_inner, d_state))
        self.D = torch.nn.Parameter(torch.ones(d_inner))
        self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mambaブロックの前方計算

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            (batch, seq_len, d_model)
        """
        xz = self.in_proj(x)
        x_branch, z = xz.chunk(2, dim=-1)

        x_branch = x_branch.transpose(1, 2)
        x_branch = self.conv1d(x_branch)[:, :, :x.shape[1]]
        x_branch = x_branch.transpose(1, 2)
        x_branch = torch.nn.functional.silu(x_branch)

        proj = self.x_proj(x_branch)
        B = proj[:, :, :self.A_log.shape[1]]
        C = proj[:, :, self.A_log.shape[1]:2*self.A_log.shape[1]]
        delta = torch.nn.functional.softplus(proj[:, :, -1:].squeeze(-1))

        A = -torch.exp(self.A_log)
        y = selective_scan(x_branch, A, B, C, delta.unsqueeze(-1).expand_as(x_branch))
        y = y + x_branch * self.D

        y = y * torch.nn.functional.silu(z)
        return self.out_proj(y)

デフォルトハイパーパラメータ

著者らが報告しているデフォルト値は以下の通りである。

パラメータ意味
d_state16SSM隠れ状態の次元数
d_conv41次元畳み込みのカーネルサイズ
expand2入力次元に対する内部次元の拡大率

Nemotron 3 Nano OmniではMamba-2を採用しており、d_state=128 と大幅に拡大されている。これは、マルチモーダル入力の複雑なパターンを捕捉するために必要な状態容量の増加を反映している。

実験結果(Results)

言語モデリング

著者らは、The Pile上での言語モデリングでMambaとTransformerを比較している。

モデルパラメータPerplexity (↓)
Transformer (GPT-3 arch.)125M26.1
Mamba130M25.3
Transformer (GPT-3 arch.)1.3B13.0
Mamba1.4B12.5

(出典: 論文Table 3。著者らの報告に基づく)

すべてのスケールでMambaがTransformerを上回っている。著者らは、これは選択機構によりモデルが関連する文脈情報を効率的に保持できるためであると分析している。

マルチドメイン評価

DNA配列モデリング、音声分類等のドメインでも有効性が確認されている。特にDNA配列では、HyenaDNA等のSSMベースラインを大幅に上回ったと報告されている。

推論スループット

著者らは、シーケンス長に対するスループットの変化を測定している。長いシーケンス(1K〜16Kトークン)において、Transformerの5倍のスループットを達成したと報告されている。この差はシーケンス長が長いほど顕著であり、Attentionの $O(n^2)$ 計算量とKVキャッシュの線形メモリ増加がボトルネックになるTransformerとの差が拡大する。

実運用への応用(Practical Applications)

Mambaアーキテクチャは、Nemotron 3 Nano Omniにおいて以下の形で活用されている。

  • 動画処理の効率化: 動画入力は大量のビジョントークンを生成するが、SSMの線形時間計算により処理コストがシーケンス長に対して線形にスケールする。これがMediaPerfベンチマークでの9.91 h/hスループットの技術的基盤
  • 音声ストリーミング処理: SSMの再帰的な状態更新は、ストリーミング入力に自然に適合する。Parakeet-TDTエンコーダからの音声トークンをリアルタイムで処理可能
  • 長コンテキストドキュメント分析: KVキャッシュが不要なため、256Kトークンのコンテキストでもメモリ効率的に動作

ただし、SSMはAttentionメカニズムと異なり、任意の2トークン間の直接的な相互参照ができない。これがIn-Context Learningの制約の原因であり、Nemotron-Hではこの制約をAttention層との混合で緩和している。

関連研究(Related Work)

  • S4(Structured State Spaces for Sequence Modeling, Gu et al., 2021): Mambaの直接的な前身。固定パラメータのSSMをHiPPO初期化と対角構造で安定化したが、入力依存の選択機構を持たない
  • H3(Hungry Hungry Hippos, Fu et al., 2023): SSMとAttentionの混合を初期的に探索した研究。Mambaは純粋なSSMで同等以上の性能を達成
  • RWKV(Peng et al., 2023): 線形AttentionベースのRNNモデル。Mambaとは異なるアプローチで線形時間推論を実現しているが、著者らの実験ではMambaが上回る精度を報告

まとめと今後の展望

Mambaは、選択的状態空間モデルの導入によりSSMの表現力を大幅に向上させ、Transformerに対する実用的な代替手法としての地位を確立した研究である。著者らが示した「入力依存パラメータ」と「hardware-awareアルゴリズム」の組み合わせは、後続のMamba-2やNemotron-Hファミリーへと発展している。Nemotron 3 Nano Omniの効率的なマルチモーダル処理は、この基盤技術なしには実現しえなかったものである。

参考文献

  • arXiv: https://arxiv.org/abs/2312.00752
  • Code: https://github.com/state-spaces/mamba
  • Related Zenn article: https://zenn.dev/0h_n0/articles/fabaf781f4158d
  • S4原論文: https://arxiv.org/abs/2111.00396
  • Mamba-2: https://arxiv.org/abs/2405.21060
この投稿は CC BY 4.0 でライセンスされています。