Home 論文解説: LLaDA — マスク拡散で実現する大規模言語モデルの新パラダイム
投稿
キャンセル

📄 論文解説: LLaDA — マスク拡散で実現する大規模言語モデルの新パラダイム

本記事は arXiv:2502.09992 “Large Language Diffusion Models” の解説記事です。

論文概要(Abstract)

LLaDA(Large Language Diffusion with mAsking)は、マスク拡散フレームワークに基づく大規模言語モデルである。著者らは、前方過程でテキストトークンをランダムにマスクし、逆過程でTransformerがマスクトークンを予測する構成により、8Bパラメータのモデルをスクラッチから学習した。LLaDA 8Bは、in-context learningにおいてLLaMA3 8Bと同等の性能を示し、教師ありファインチューニング後には多ターン対話を含む指示追従能力を獲得している。さらに、逆順詩の補完タスクではGPT-4oを上回る性能を報告している。

この記事は Zenn記事: 拡散言語モデル2026年動向:Mercury・LLaDA・MoE統合の実装と展望 の深掘りです。

情報源

  • arXiv ID: 2502.09992
  • URL: https://arxiv.org/abs/2502.09992
  • 著者: Shen Nie, Fengqi Zhu, Zebin You, Xiaolu Zhang, Jingyang Ou, Jun Hu, Jun Zhou, Yankai Lin, Ji-Rong Wen, Chongxuan Li
  • 所属: Renmin University of China / Ant Group
  • 発表年: 2025年2月(NeurIPS 2025 Oral採択)
  • 分野: cs.CL, cs.LG

背景と動機(Background & Motivation)

大規模言語モデル(LLM)の発展は、自己回帰(Autoregressive: AR)アーキテクチャに依存してきた。GPT系列、LLaMA、Claudeなど主要モデルはすべてARフレームワークを採用しており、トークンを左から右へ逐次的に生成する。しかし、この一方向性には構造的制約がある。

第一に、ARモデルはシーケンス長に対して線形にレイテンシが増大する。100トークンの生成には100回の前方パスが必要となり、GPU並列性を十分に活用できない。第二に、因果マスク(causal mask)によりモデルは将来のトークンを参照できず、双方向の文脈理解が制限される。著者らはこの「反転の呪い(Reversal Curse)」と呼ばれる現象がARモデルの本質的な限界であると指摘している。

画像生成分野ではStable Diffusionに代表される拡散モデルが大きな成功を収めていた。しかし、テキストの離散性(discrete tokens)が画像の連続表現と根本的に異なるため、言語生成への拡散モデルの適用は困難とされてきた。LLaDAは、マスクベースの離散拡散(Masked Diffusion)を採用することでこの課題を解決し、ARモデルなしで大規模言語モデルを構築できることを実証した研究である。

主要な貢献(Key Contributions)

  • 貢献1: マスク拡散フレームワークで8Bパラメータの言語モデルをスクラッチから事前学習し、2.3兆トークンで学習を完了した。これは拡散言語モデルとしては当時最大規模である
  • 貢献2: in-context learningにおいてLLaMA3 8Bと同等の性能を達成し、拡散モデルでもARモデルの中核能力が獲得可能であることを実証した
  • 貢献3: 教師ありファインチューニング(SFT)により指示追従能力を獲得し、逆順タスク(Reversal Curse)ではGPT-4oを上回る性能を報告した

技術的詳細(Technical Details)

マスク拡散の数学的定式化

LLaDAの核心は、離散トークン列に対するマスクベースの拡散プロセスである。

前方過程(Forward Process)

入力テキスト $x_0 = (x_0^1, x_0^2, \ldots, x_0^L)$ に対し、各トークンを独立にマスクトークン $[M]$ に置換する。時刻 $t \in [0, 1]$ における遷移確率は以下で定義される:

\[q(x_t^i \mid x_0^i) = \begin{cases} [M] & \text{確率 } t \\ x_0^i & \text{確率 } 1 - t \end{cases}\]

ここで、

  • $x_0^i$: 元のテキストの$i$番目のトークン
  • $x_t^i$: 時刻$t$における$i$番目のトークン(マスクまたは元トークン)
  • $t$: マスク比率($t=0$で元テキスト、$t=1$で全マスク)

逆過程(Reverse Process)

マスクされたテキスト $x_t$ から元のテキストを復元する。モデル $p_\theta$ は以下を学習する:

\[p_\theta(x_0^i \mid x_t) = \text{Transformer}_\theta(x_t, t)[i]\]

各マスク位置$i$で、Transformerがシーケンス全体の文脈を参照して元トークンを予測する。ARモデルとの決定的な違いは、因果マスクを使用しない双方向Transformerを用いる点である。

学習目標(Training Objective)

以下の重み付きクロスエントロピー損失を最小化する:

\[\mathcal{L}(\theta) = -\mathbb{E}_{t \sim U(0,1), x_0, x_t \sim q(x_t \mid x_0)} \left[ \frac{1}{t} \sum_{i: x_t^i = [M]} \log p_\theta(x_0^i \mid x_t) \right]\]

ここで、

  • $U(0,1)$: $t$の一様分布
  • $\frac{1}{t}$: マスク比率$t$による正規化重み。マスクが少ない($t$が小さい)場合、各マスク位置の予測がより重要になるため重みが大きくなる
  • 和はマスクされた位置のみを対象とする

この損失関数は変分下界(ELBO)から導出されており、理論的にはデータの対数尤度の下界を最大化していることに相当する。

アーキテクチャの設計選択

LLaDAは標準的なTransformerアーキテクチャをベースとし、以下の変更を加えている:

  1. 双方向注意機構: 因果マスクを除去し、全トークン間で注意を計算。各位置が将来のトークンも参照できる
  2. RoPE位置埋め込み: LLaMA3と同様のRotary Position Embeddingを採用。相対位置情報の効率的なエンコーディングを実現
  3. マスクトークン埋め込み: 語彙に$[M]$トークンを追加し、専用の埋め込みベクトルを学習
  4. 時刻埋め込み: 拡散の時刻$t$をモデルに入力するための追加埋め込み層

推論アルゴリズム

推論時は以下の反復的アンマスキングを行う:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from dataclasses import dataclass

@dataclass
class LLaDAInferenceConfig:
    """LLaDA推論の設定パラメータ"""
    num_steps: int = 64  # 拡散ステップ数
    temperature: float = 1.0  # サンプリング温度
    mask_token_id: int = 32000  # [M]トークンのID

def llada_generate(
    model: torch.nn.Module,
    prompt_ids: torch.Tensor,
    gen_length: int,
    config: LLaDAInferenceConfig,
) -> torch.Tensor:
    """LLaDAのマスク拡散ベース生成

    Args:
        model: 学習済みLLaDAモデル
        prompt_ids: プロンプトのトークンID列 (1, prompt_len)
        gen_length: 生成するトークン数
        config: 推論設定

    Returns:
        生成されたトークンID列 (1, prompt_len + gen_length)
    """
    device = prompt_ids.device
    T = config.num_steps

    # 生成部分を全マスクで初期化
    masked = torch.full(
        (1, gen_length), config.mask_token_id,
        dtype=torch.long, device=device
    )
    x_t = torch.cat([prompt_ids, masked], dim=1)
    prompt_len = prompt_ids.shape[1]

    for step in range(T, 0, -1):
        t = step / T

        # 双方向Transformerで全位置の予測確率を取得
        logits = model(x_t, timestep=t)  # (1, seq_len, vocab_size)
        probs = torch.softmax(logits / config.temperature, dim=-1)

        # 現在マスクされている位置を特定
        mask_positions = (x_t == config.mask_token_id)
        num_masked = mask_positions.sum().item()

        if num_masked == 0:
            break

        # このステップでアンマスクするトークン数を決定
        num_to_unmask = max(1, int(num_masked / step))

        # 各マスク位置の予測信頼度を計算
        confidence = probs.max(dim=-1).values  # (1, seq_len)
        confidence[~mask_positions] = -float("inf")

        # 信頼度の高い位置からアンマスク
        _, top_indices = confidence[0].topk(min(num_to_unmask, num_masked))
        predicted_tokens = probs[0].argmax(dim=-1)

        for idx in top_indices:
            if idx >= prompt_len:  # プロンプト部分は変更しない
                x_t[0, idx] = predicted_tokens[idx]

    return x_t

推論の計算量: ARモデルが$L$トークンの生成に$L$回の前方パスを必要とするのに対し、LLaDAは$T$回の前方パス($T$は拡散ステップ数、通常10〜64)で$L$トークンを生成する。$T < L$の場合、理論上はARモデルより少ない前方パス数で生成が完了する。ただし、各前方パスでは全トークン位置を処理するため、1回あたりの計算量はARモデルの推論時より大きい。

実装のポイント(Implementation)

双方向Transformerの実装: 既存のLLaMAコードベースからcausal maskを除去するだけでは不十分である。Flash Attention 2を使用する場合、is_causal=Falseに設定する必要がある。また、KVキャッシュはARモデル専用の最適化であり、LLaDAでは使用できない。

マスクスケジュールの重要性: 前方過程のマスク率$t$のサンプリング分布が生成品質に大きく影響する。著者らは一様分布$U(0,1)$を推奨しているが、余弦スケジュール(cosine schedule)も有効であると報告している。

SFTの注意点: 教師ありファインチューニング時には、プロンプト部分をマスクせず、応答部分のみをマスクして学習する。これにより、プロンプトの情報を保持しつつ応答生成を最適化する。

ハイパーパラメータ: 推論ステップ数は品質と速度のトレードオフである。著者らの実験では$T=64$で十分な品質が得られるが、$T=128$ではさらなる改善は微小であったと報告している。

Production Deployment Guide

AWS実装パターン(コスト最適化重視)

LLaDAのような拡散言語モデルをAWS上で推論サービスとしてデプロイする場合、ARモデルとは異なるリソース特性を考慮する必要がある。KVキャッシュが不要なためメモリ効率は良いが、複数回の前方パスが必要なためGPU演算時間が増加する。

トラフィック量別の推奨構成:

規模月間リクエスト推奨構成月額コスト概算主要サービス
Small~3,000 (100/日)Serverless$50-150Lambda + Bedrock + DynamoDB
Medium~30,000 (1,000/日)Hybrid$400-1,000ECS Fargate + S3 + ElastiCache
Large300,000+ (10,000/日)Container$3,000-8,000EKS + GPU Instances + Karpenter

Small構成の詳細(月額$50-150):

  • Lambda: 推論リクエストのルーティング ($20/月)
  • Bedrock: マスク拡散モデルのカスタムモデルインポート、または代替としてHaiku使用 ($80/月)
  • DynamoDB: プロンプトキャッシュ、On-Demand ($10/月)
  • CloudWatch: 基本監視 ($5/月)

Medium構成の詳細(月額$400-1,000):

  • ECS Fargate: 0.5 vCPU, 4GB RAM × 2タスク、GPU不使用のCPU推論 ($150/月)
  • S3: モデル重みストレージ ($20/月)
  • ElastiCache Redis: cache.t3.micro, プロンプトキャッシュ ($15/月)
  • ALB: ロードバランシング ($20/月)

Large構成の詳細(月額$3,000-8,000):

  • EKS: コントロールプレーン ($72/月)
  • EC2 Spot: g5.xlarge × 2-4台、拡散モデル推論 ($800-1,600/月)
  • Karpenter: Spot自動スケーリング(追加コストなし)
  • S3 + CloudFront: モデル配信 ($50/月)

コスト試算の注意事項: 上記は2026年3月時点のAWS ap-northeast-1(東京)リージョン料金に基づく概算値です。拡散モデルはARモデルと異なりKVキャッシュ不要だがステップ数分の演算が必要なため、GPU利用時間が増加する傾向があります。最新料金は AWS料金計算ツール で確認してください。

Terraformインフラコード

Small構成 (Serverless): Lambda + DynamoDB

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
module "vpc" {
  source  = "terraform-aws-modules/vpc/aws"
  version = "~> 5.0"

  name = "llada-inference-vpc"
  cidr = "10.0.0.0/16"
  azs  = ["ap-northeast-1a", "ap-northeast-1c"]
  private_subnets = ["10.0.1.0/24", "10.0.2.0/24"]

  enable_nat_gateway   = false
  enable_dns_hostnames = true
}

resource "aws_iam_role" "lambda_llada" {
  name = "lambda-llada-inference-role"

  assume_role_policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Action = "sts:AssumeRole"
      Effect = "Allow"
      Principal = { Service = "lambda.amazonaws.com" }
    }]
  })
}

resource "aws_iam_role_policy" "bedrock_invoke" {
  role = aws_iam_role.lambda_llada.id
  policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Effect   = "Allow"
      Action   = ["bedrock:InvokeModel", "bedrock:InvokeModelWithResponseStream"]
      Resource = "arn:aws:bedrock:ap-northeast-1::foundation-model/*"
    }]
  })
}

resource "aws_lambda_function" "llada_handler" {
  filename      = "lambda.zip"
  function_name = "llada-inference-handler"
  role          = aws_iam_role.lambda_llada.arn
  handler       = "index.handler"
  runtime       = "python3.12"
  timeout       = 120
  memory_size   = 2048

  environment {
    variables = {
      DYNAMODB_TABLE = aws_dynamodb_table.prompt_cache.name
    }
  }
}

resource "aws_dynamodb_table" "prompt_cache" {
  name         = "llada-prompt-cache"
  billing_mode = "PAY_PER_REQUEST"
  hash_key     = "prompt_hash"

  attribute {
    name = "prompt_hash"
    type = "S"
  }

  ttl {
    attribute_name = "expire_at"
    enabled        = true
  }
}

resource "aws_cloudwatch_metric_alarm" "lambda_duration" {
  alarm_name          = "llada-lambda-duration-spike"
  comparison_operator = "GreaterThanThreshold"
  evaluation_periods  = 1
  metric_name         = "Duration"
  namespace           = "AWS/Lambda"
  period              = 3600
  statistic           = "Sum"
  threshold           = 100000

  dimensions = {
    FunctionName = aws_lambda_function.llada_handler.function_name
  }
}

Large構成 (Container): EKS + Karpenter + Spot

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
module "eks" {
  source  = "terraform-aws-modules/eks/aws"
  version = "~> 20.0"

  cluster_name    = "llada-inference-cluster"
  cluster_version = "1.31"
  vpc_id          = module.vpc.vpc_id
  subnet_ids      = module.vpc.private_subnets

  cluster_endpoint_public_access = true
  enable_cluster_creator_admin_permissions = true
}

resource "kubectl_manifest" "karpenter_provisioner" {
  yaml_body = <<-YAML
    apiVersion: karpenter.sh/v1
    kind: NodePool
    metadata:
      name: gpu-spot-pool
    spec:
      template:
        spec:
          requirements:
            - key: karpenter.sh/capacity-type
              operator: In
              values: ["spot"]
            - key: node.kubernetes.io/instance-type
              operator: In
              values: ["g5.xlarge", "g5.2xlarge"]
          limits:
            cpu: "32"
            memory: "128Gi"
      disruption:
        consolidateAfter: 30s
  YAML
}

resource "aws_budgets_budget" "llada_monthly" {
  name         = "llada-monthly-budget"
  budget_type  = "COST"
  limit_amount = "8000"
  limit_unit   = "USD"
  time_unit    = "MONTHLY"

  notification {
    comparison_operator       = "GREATER_THAN"
    threshold                 = 80
    threshold_type            = "PERCENTAGE"
    notification_type         = "ACTUAL"
    subscriber_email_addresses = ["ops@example.com"]
  }
}

運用・監視設定

CloudWatch Logs Insights クエリ:

1
2
3
4
5
6
7
-- 拡散ステップ数と推論レイテンシの相関分析
fields @timestamp, diffusion_steps, duration_ms, token_count
| stats avg(duration_ms) as avg_latency,
        pct(duration_ms, 95) as p95,
        pct(duration_ms, 99) as p99
  by diffusion_steps
| sort diffusion_steps asc

CloudWatch アラーム(Python):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import boto3

cloudwatch = boto3.client('cloudwatch')

cloudwatch.put_metric_alarm(
    AlarmName='llada-inference-latency-p99',
    ComparisonOperator='GreaterThanThreshold',
    EvaluationPeriods=2,
    MetricName='Duration',
    Namespace='LLaDA/Inference',
    Period=300,
    Statistic='p99',
    Threshold=60000,  # 60秒超過でアラート
    ActionsEnabled=True,
    AlarmActions=['arn:aws:sns:ap-northeast-1:123456789:llada-alerts'],
    AlarmDescription='LLaDA推論P99レイテンシ異常'
)

X-Ray トレーシング:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from aws_xray_sdk.core import xray_recorder, patch_all

patch_all()

@xray_recorder.capture('llada_inference')
def run_llada_inference(prompt: str, num_steps: int = 64):
    xray_recorder.put_annotation('model', 'LLaDA-8B')
    xray_recorder.put_annotation('diffusion_steps', num_steps)
    xray_recorder.put_metadata('prompt_length', len(prompt))

    output = model.generate(prompt, num_steps=num_steps)

    xray_recorder.put_metadata('output_length', len(output))
    return output

コスト最適化チェックリスト

アーキテクチャ選択:

  • ~100 req/日 → Lambda + Bedrock (Serverless) - $50-150/月
  • ~1000 req/日 → ECS Fargate (Hybrid) - $400-1,000/月
  • 10000+ req/日 → EKS + Spot Instances (Container) - $3,000-8,000/月

リソース最適化:

  • EC2 Spot Instances優先(最大90%削減、Karpenter自動管理)
  • Reserved Instances: 1年コミットで最大72%削減
  • Lambda: メモリ2048MBで演算性能とコストのバランス最適化
  • ECS/EKS: 夜間のスケールダウン設定

拡散モデル固有の最適化:

  • 推論ステップ数の動的調整(簡易プロンプトは少ステップ)
  • バッチ推論: 複数リクエストをバッチ化しGPU利用率向上
  • KVキャッシュ不要のため、メモリ効率はARモデルより良い
  • プロンプトキャッシュ: DynamoDBで同一プロンプトの結果を再利用

監視・アラート:

  • AWS Budgets: 月額予算設定(80%で警告)
  • CloudWatch: 推論レイテンシ・ステップ数の相関監視
  • Cost Anomaly Detection: GPU利用量の自動異常検知
  • 日次コストレポート: SNS/Slackへ自動送信

リソース管理:

  • 未使用GPUインスタンスの自動停止
  • タグ戦略: 環境別・モデル別でコスト可視化
  • S3モデル重みのライフサイクル管理
  • 開発環境は夜間停止(Auto Start/Stop)

実験結果(Results)

LLaDA 8Bの主要ベンチマーク結果を以下に示す(論文Table 1, Table 2より)。

ベンチマークLLaDA 8B-BaseLLaMA3 8B-Base差分
MMLU (5-shot)65.066.7-1.7
ARC-Challenge79.379.7-0.4
HellaSwag79.682.0-2.4
GSM8K (8-shot)55.356.7-1.4
HumanEval26.232.9-6.7

分析: 著者らは、ベースモデル同士の比較ではLLaMA3 8Bと概ね同等の性能であると報告している。特にMMLU(知識理解)とARC-Challenge(推論)ではほぼ同等である。一方、HumanEvalでは6.7%の差があり、コード生成タスクではARモデルが依然として優位性を持つ。

SFT後の性能(論文Table 3より):

  • MT-Bench: 6.27(LLaMA3 8B-Instruct: 7.24)
  • 逆順詩の補完: LLaDAがGPT-4oを上回る性能を報告

著者らは、逆順タスクでのLLaDAの優位性は双方向注意機構に起因すると分析している。ARモデルでは「A→B」の関係を学習しても「B→A」の推論が困難だが、LLaDAは双方向文脈を利用できるため、逆方向の推論が自然に行える。

スケーリング実験(論文Figure 2より): 著者らは1B/3B/8Bの3つのスケールで実験を行い、パラメータ数の増加に伴い性能が一貫して向上することを示している。この結果は、マスク拡散言語モデルにもARモデルと同様のスケーリング則が成立する可能性を示唆している。

実運用への応用(Practical Applications)

LLaDAの実運用における主な応用シナリオは以下である。

逆方向推論が必要なタスク: コード補完(関数シグネチャと本体の同時生成)、テンプレートベースのテキスト生成、双方向の依存関係を持つ構造化データ生成などでARモデルより優位性を発揮する可能性がある。

バッチ推論: 複数リクエストを同時処理する場合、拡散モデルの並列性が活きる。各ステップで全トークンを同時に処理するため、GPU演算ユニットの利用率が高く、バッチサイズの大きい推論でスループットが向上する。

制約: 2026年3月時点では、LLaDAの推論を効率化するフレームワーク(vLLMやTGI相当)は限定的である。dLLMフレームワークが開発されているが、ARモデル向けの推論インフラと比較してエコシステムの成熟度に大きな差がある。プロダクション利用には、自前での推論パイプライン構築が必要になるケースが多い。

関連研究(Related Work)

  • MDLM(Sahoo et al., 2024): マスク拡散言語モデルの理論的基盤を構築。LLaDAはMDLMのフレームワークを大規模にスケールさせた位置づけにある
  • SEDD(Score Entropy Discrete Diffusion)(Lou et al., 2024): 離散拡散の損失関数にスコアエントロピーを導入。LLaDAとは異なるアプローチだが、離散拡散の有効性を示した先行研究
  • GPT-2/3/4、LLaMA系列: ARモデルの代表例。LLaDAはこれらとの性能比較を通じて拡散モデルの競合可能性を実証した
  • Stable Diffusion / DALL-E: 画像生成分野での拡散モデルの成功。LLaDAは画像→テキストへの拡散モデル適用を進めた研究

まとめと今後の展望

LLaDAは、マスク拡散フレームワークで8Bパラメータの言語モデルをスクラッチ学習し、ARモデルと同等のin-context learning能力を達成した。LLMの中核能力がARアーキテクチャに依存しないことを実証した点が本研究の最大の意義である。

今後の方向性として、著者らは以下を挙げている:より大規模なモデルへのスケーリング(LLaDA 2.0で100Bに到達済み)、推論ステップ数の適応的調整による高速化、MoEとの統合によるパラメータ効率の改善である。2026年3月時点で、LLaDAを起点とする拡散言語モデルのエコシステムは急速に拡大している。

参考文献

この投稿は CC BY 4.0 でライセンスされています。