Home 論文解説: FlashAttention-2 — GPU並列性とワーク分割の改善によるAttention高速化
投稿
キャンセル

📄 論文解説: FlashAttention-2 — GPU並列性とワーク分割の改善によるAttention高速化

本記事は arXiv:2307.08691 “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning” の解説記事です。

論文概要(Abstract)

Tri Dao(Princeton/Together AI)は、FlashAttention v1のGPU利用率を25-40%から50-73%に改善するFlashAttention-2を提案した。3つのアルゴリズム改善——非matmul FLOPsの削減、シーケンス方向の並列化、warpレベルのワーク分割最適化——により、A100 GPU上でFlashAttention v1比約2倍の高速化を達成したと報告されている。

この記事は Zenn記事: rvLLM:Rust製vLLM代替で学ぶGPU推論エンジンの実装最適化 の深掘りです。rvLLMが15個のCUDAカーネルにFlashAttention-2を実装しており、そのアルゴリズムの原理を理解する上で必読の論文です。

情報源

背景と動機(Background & Motivation)

TransformerのAttention機構は $O(N^2)$ の時間・空間計算量を持ち、長シーケンスではメモリと速度の両面でボトルネックとなる。FlashAttention v1(Dao et al., 2022)は、GPUのメモリ階層(HBM↔SRAM)のI/O特性を活用したタイリング手法でメモリを $O(N)$ に削減し、2-4倍の実測高速化を達成した。

しかし、FlashAttention v1はGPUの理論性能(FLOP/s)の25-40%しか達成できていなかった。著者は2つの原因を特定している:

  1. warp間の非最適なワーク分割: 共有メモリへの不要な読み書きが発生
  2. シーケンス方向の並列化不足: バッチ×ヘッド数のみで並列化し、単一ヘッド内でのシーケンス方向並列化がない

主要な貢献(Key Contributions)

  • 非matmul FLOPsの削減: backward passのリスケーリング操作を効率化し、約15% FLOPs削減
  • ループ順序の変更: 外側ループをQ方向、内側ループをK/V方向に変更し、スレッドブロック間の独立性を確保
  • シーケンス方向並列化: (batch, head, seq_block) の3次元で並列化し、長シーケンスでのGPU occupancyを向上
  • warpレベルワーク分割の改善: warpがK/V列方向を分担し、Qはレジスタに保持することで共有メモリアクセスを削減

技術的詳細(Technical Details)

GPUメモリ階層

A100 GPUのメモリ構成:

  • SRAM(共有メモリ): 約19MB、帯域幅約19TB/s
  • HBM(高帯域幅メモリ): 40-80GB、帯域幅約2TB/s

SRAMはHBMの約10倍高速であり、FlashAttentionの核心はHBMアクセスを最小化するタイリング戦略にある。

標準AttentionのI/O問題

クエリ $Q$、キー $K$、バリュー $V \in \mathbb{R}^{N \times d}$ に対する標準Attention:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) V\]

標準実装では、$S = QK^\top$ をHBMに書き出し($O(N^2)$ メモリ)、softmaxを計算し、再びHBMから読み出して $PV$ を計算する。HBMアクセスは $O(N^2)$ となる。

FlashAttention v1のタイリング戦略

FlashAttention v1はQ, K, VをSRAMに収まるブロックに分割し、online softmax(タイル間で正規化定数を累積更新)でブロック単位にAttentionを計算する。

online softmaxの更新式(論文Section 3.1):

\[m_i^{\text{new}} = \max(m_i, \text{rowmax}(S_{ij}))\] \[\ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + \text{rowsum}(e^{S_{ij} - m_i^{\text{new}}})\] \[O_i^{\text{new}} = \text{diag}(e^{m_i - m_i^{\text{new}}})^{-1} O_i + e^{S_{ij} - m_i^{\text{new}}} V_j\]

ここで、$m_i$ は行方向最大値(数値安定化用)、$\ell_i$ はsoftmax正規化定数の累積和、$O_i$ は出力の累積和である。

HBMアクセス量: $O(N^2 d / M)$($M$: SRAM容量)。$N^2 \gg M$ の場合、標準の $O(N^2)$ より大幅に削減。

FlashAttention-2の3つの改善

改善1: ループ順序の変更

FlashAttention v1: 外側K/Vブロック、内側Qブロック

FlashAttention-2: 外側Qブロック、内側K/Vブロック

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# FlashAttention-2 forward pass(概念的な擬似コード)
def flash_attention_2_forward(
    Q: torch.Tensor,  # (N, d)
    K: torch.Tensor,  # (N, d)
    V: torch.Tensor,  # (N, d)
    block_size_q: int,
    block_size_kv: int,
) -> torch.Tensor:
    """FlashAttention-2のフォワードパス

    Args:
        Q, K, V: クエリ・キー・バリュー行列
        block_size_q: Qブロックサイズ
        block_size_kv: K/Vブロックサイズ

    Returns:
        Attention出力
    """
    N, d = Q.shape
    O = torch.zeros_like(Q)

    # 外側ループ: Qブロック(並列化可能)
    for i in range(0, N, block_size_q):
        Q_i = Q[i:i+block_size_q]
        m_i = torch.full((block_size_q,), float('-inf'))
        l_i = torch.zeros(block_size_q)
        O_i = torch.zeros(block_size_q, d)

        # 内側ループ: K/Vブロック(逐次処理)
        for j in range(0, N, block_size_kv):
            K_j = K[j:j+block_size_kv]
            V_j = V[j:j+block_size_kv]

            S_ij = Q_i @ K_j.T / (d ** 0.5)

            # Online softmax更新
            m_i_new = torch.max(m_i, S_ij.max(dim=-1).values)
            P_ij = torch.exp(S_ij - m_i_new.unsqueeze(-1))
            l_i_new = torch.exp(m_i - m_i_new) * l_i + P_ij.sum(dim=-1)

            O_i = torch.diag(torch.exp(m_i - m_i_new)) @ O_i + P_ij @ V_j
            m_i, l_i = m_i_new, l_i_new

        O[i:i+block_size_q] = O_i / l_i.unsqueeze(-1)

    return O

利点: 各Qブロックが独立に計算可能なため、異なるスレッドブロック間でQブロックを並列処理できる。FlashAttention v1では同一Qブロックが複数のK/Vイテレーションで更新されるため、スレッドブロック間の同期が必要だった。

改善2: シーケンス方向並列化

FlashAttention v1: (batch, head) の2次元で並列化

FlashAttention-2: (batch, head, seq_block) の3次元で並列化

長シーケンス・小バッチの推論では、batch × head数だけではGPUのSM(Streaming Multiprocessor)を埋め切れない。シーケンス方向にも並列化することで、GPU occupancyが向上する。

改善3: warpレベルワーク分割

FlashAttention v1: 各warpがO(出力)の異なる行を担当。K, VはすべてのwarpがSRAMから読み取るため、共有メモリ帯域幅がボトルネック。

FlashAttention-2: warpがK/Vの列方向を分担。Qはレジスタに保持し、warp shuffle命令で交換。共有メモリへの書き込みを大幅に削減。

I/O計算量の理論分析

FlashAttentionのHBMアクセス量(論文Theorem 1より):

\[\Theta\left(\frac{N^2 d}{M}\right)\]

これは情報理論的な下界と一致し、I/O最適であることが証明されている。FlashAttention-2は漸近的な計算量は同一だが、定数係数を改善している。

MQA/GQAサポート

FlashAttention-2はMulti-Query Attention(MQA)とGrouped-Query Attention(GQA)をネイティブサポートする:

  • MQA: 複数のQヘッドが単一のK/Vヘッドを共有。K/Vを1回ロードし、複数Qヘッドで再利用
  • GQA: $h$ 個のQヘッドを $g$ グループに分割し、各グループが1つのK/Vヘッドを共有

KVキャッシュサイズが $h/g$ 倍に削減されるため、推論時のメモリ効率が向上する。

実装のポイント(Implementation)

CUDA vs Triton: FlashAttention-2はCUDA C++(高性能)とTriton(可読性・ポータビリティ)の両実装を提供。warp shuffle命令やレジスタ割り当ての直接制御が必要なため、CUDA版がTriton版を上回る。

rvLLMとの関連: rvLLMは15個のCUDAカーネルをPTXで手書きしており、そのうちFlashAttention-2カーネルが最も複雑な実装である。共有メモリ並列リダクション、warpレベル削減、ベクトル化float4ロードを組み合わせ、GQA対応のPagedAttention V2カーネルとして統合している。

精度要件: FP16またはBF16が必要。FP32は非対応。rvLLMがPhase 2でFP32→FP16に切り替えた際にFlashAttention-2の恩恵を受けた理由はここにある。

Production Deployment Guide

AWS実装パターン(コスト最適化重視)

規模月間リクエスト推奨構成月額コスト
Small~3,000Lambda + Bedrock$50-150
Medium~30,000ECS Fargate + Bedrock$300-800
Large300,000+EKS + Karpenter + Spot$2,000-5,000

コスト試算の注意事項: 上記は2026年3月時点のAWS ap-northeast-1料金に基づく概算値です。最新料金は AWS料金計算ツール で確認してください。

Terraformインフラコード

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
module "vpc" {
  source  = "terraform-aws-modules/vpc/aws"
  version = "~> 5.0"

  name = "flashattn-vpc"
  cidr = "10.0.0.0/16"
  azs  = ["ap-northeast-1a", "ap-northeast-1c"]
  private_subnets = ["10.0.1.0/24", "10.0.2.0/24"]

  enable_nat_gateway   = false
  enable_dns_hostnames = true
}

resource "aws_lambda_function" "inference" {
  filename      = "lambda.zip"
  function_name = "flashattn-inference"
  role          = aws_iam_role.lambda_bedrock.arn
  handler       = "index.handler"
  runtime       = "python3.12"
  timeout       = 60
  memory_size   = 1024
}

resource "aws_budgets_budget" "monthly" {
  name         = "flashattn-monthly"
  budget_type  = "COST"
  limit_amount = "5000"
  limit_unit   = "USD"
  time_unit    = "MONTHLY"

  notification {
    comparison_operator        = "GREATER_THAN"
    threshold                  = 80
    threshold_type             = "PERCENTAGE"
    notification_type          = "ACTUAL"
    subscriber_email_addresses = ["ops@example.com"]
  }
}

コスト最適化チェックリスト

  • Spot Instances優先(最大90%削減)
  • Reserved Instances: 1年コミットで72%削減
  • Bedrock Batch API使用で50%割引
  • Prompt Caching有効化で30-90%削減
  • AWS Budgets設定(80%で警告)
  • CloudWatch: GPU利用率・トークン使用量モニタリング
  • Cost Anomaly Detection有効化
  • タグ戦略: 環境別コスト可視化
  • S3ライフサイクル: キャッシュ自動削除
  • 開発環境: 夜間GPU停止

実験結果(Results)

フォワードパス速度

A100 80GB上でのAttention計算速度(論文Table 1, causal attention):

手法TFLOPs/s理論性能比
PyTorch naive~3~3%
xFormers memory-efficient~10-12~15%
FlashAttention v120-3025-40%
FlashAttention-240-6050-73%

シーケンス長別の高速化

FlashAttention v1に対するFlashAttention-2の高速化倍率(論文Figure 3, A100):

シーケンス長FA v1 (TFLOPs/s)FA-2 (TFLOPs/s)高速化
51225401.6×
102430501.7×
409630581.9×
819228602.1×
1638425582.3×

シーケンス長が長くなるほど高速化倍率が増大する。これはシーケンス方向並列化の効果であり、長コンテキスト推論でFlashAttention-2の恩恵が特に大きいことを示している。

エンドツーエンドの訓練高速化

GPT-3相当モデルの4×A100訓練(論文Table 2):

モデルPyTorch比 FA v1PyTorch比 FA-2
GPT3-1.3B1.8×2.0×
GPT3-2.7B2.0×2.2×
GPT3-6.7B2.2×2.4×

実運用への応用(Practical Applications)

FlashAttention-2は事実上すべての現代的LLM推論・訓練で採用されている標準技術である。PyTorch 2.0以降のF.scaled_dot_product_attentionはFlashAttention-2を内部的に呼び出す。

rvLLMへの影響: rvLLMがPhase 2でFP16に切り替えた際、cuBLASのhgemmとFlashAttention-2の組み合わせにより、Phase 1(FP32)の8,339 tok/sから2.6倍の高速化を達成している。FlashAttention-2の恩恵は、Tensor Coreの活用とメモリ帯域幅の削減の両面で現れる。

関連研究(Related Work)

  • FlashAttention v1 (Dao et al., 2022): I/O-awareなタイリング手法の原論文。FlashAttention-2はこの上に構築
  • FlashAttention-3 (Dao & Gu, 2024): H100のWGMMAとTMAを活用し、最大740 TFLOPs/sを達成。FlashInferがMLSys 2025最優秀論文賞を受賞
  • PagedAttention (Kwon et al., SOSP 2023): FlashAttention-2のタイリングはPagedAttentionの非連続KVレイアウトと互換性がある

まとめと今後の展望

FlashAttention-2は、GPU利用率を25-40%→50-73%に改善した実用的に重要な最適化である。3つの改善(ループ順序変更、シーケンス並列化、warp分割改善)はいずれもアルゴリズムの正確性を保ちながらハードウェア利用率のみを改善している点が特徴的である。rvLLMのようなRust製推論エンジンがPTXカーネルでFlashAttention-2を再実装していることは、この手法がハードウェア効率の観点で不可欠であることを示している。

参考文献

この投稿は CC BY 4.0 でライセンスされています。