date
Feb 12, 2026
slug
cs336-as2
status
Published
tags
MLSys
summary
type
Post
本文记录了我做 cs336 assignment 2 的要点和 key takeaway
assignment 2 主要是在 assignment 1 的 pipeline 中做一些 system 上的优化:
  • 做 benchmark 和 profile 来理解系统的性能瓶颈
  • 用 triton 实现 flash attention 2 kernel
  • 实现 ddp + optimizer state sharding (也就是 ZeRO1)

Benchmarking & Profiling

Experiment #1: end-to-end benchmarking

分别 benchmark 单独的 forward 和 forward + backward, 基本符合 backward 计算量是 forward 的两倍左右的结论。

Experiment #2: profiling with nsight system

nsys profiling timeline view 的一个例子
nsys profiling timeline view 的一个例子
上图是 medium size model 的 nsys 的一个 timeline view,如果我们 zoom in 到一个 block 的 view:
notion image
可以看到 multihead-attention 是一个 block 内耗时最多的,其次是 ffn,最后是两个 rms norm,加起来也和 ffn 类似
notion image
再 zoom in 到 attention,我们可以看到三个子步骤的耗时,其中最大的是计算注意力分数。此外,我们可以看到在 attention 的过程中 launch 了好几个 kernel (最后一行)

Experiment #3: mix precision

为了更好地利用 tensor core,需要我们使用低精度数据类型,但是全低精度可能会使训练不稳定,或者让模型准确率下降。对此的一个解决办法是使用混合精度,对于对精度敏感的算子(比如 reduce 过程中的 accumulator 最好保持高精度),保持高精度,对于精度不敏感的算子(matmul),使用低精度。这可以使用 pytorch 的 autocast 来简单地实现。
可以看到混合精度获得了普遍的提升,并且模型越大,带来的提升越大。

Experiment #4: profiling memory

notion image
上图是 2.7B 模型在 fp32 下的 forward + backward 的 active memory timeline,随着 forward 进行 activation 使得显存占用增加,随着 backward 开始 activation 的显存下降但是 backward 占用的显存增加,因此后续显存保持基本稳定,第三阶段 backward 占用的显存开始释放,但是 optimizer 占用的显存增加,使得显存占用继续稳定。

Takeaway

从上述的 benchmarking 和 profiling 我们可以得到以下结论:
  • attention 是耗时最多的算子,rms norm 也占有不可忽视的比例
  • 混合精度可以显著加速
    • notion image
由此我们可以得到下一步可能的优化是优化 attention 算子(以及 rms norm)。通过 timeline view 我们可以看到目前它们都包含好几个 kernel invocation,因此最简单的思路是做 kernel fusion。

Optimizing Attention with FlashAttention-2

Experiment #1: benchmarking attention implementation

为了理解目前 attention 的性能瓶颈,我们对 attention 算子做单独的 benchmark
notion image
可以看到在 assignment 1 中实现的 attention 在显存占用上随着序列长度是平方增长的,而在耗时上也是超线性增长。

Experiment #2: benchmarking torch.compiled attention

一个很简单的优化尝试是直接使用 torch.compile 来进行优化,结果如下表。
notion image
在端到端速度上总体上是有一定的提升的,但是显存占用上还是随着序列长度平方增长。

Experiment #3: benchmarking flash attention 2

可以看到在 fp32,长 context 加速更明显;并且在长 context 避免了 oom
notion image
notion image
notion image
notion image

Key takeaway

flash attention 的核心思想:kernel fusion 和 tiling
 

Distributed Data Parallel

DDP 策略
Attention
Context
E2E (s)
Grad Sync (s)
Sync 开销
Naive DDP
Naive
512
1.9773
0.2646
13.38%
Naive DDP
Triton FA
512
1.8100
0.2654
14.66%
Naive DDP
Naive
1024
OOM
Naive DDP
Triton FA
1024
3.4338
0.4839
14.09%
Flat DDP
Naive
512
1.9819
0.2686
13.55%
Flat DDP
Triton FA
512
1.8009
0.2703
15.01%
Flat DDP
Naive
1024
OOM
Flat DDP
Triton FA
1024
3.4149
0.4832
14.15%
Overlap Individual
Naive
512
1.9467
0.2107
10.82%
Overlap Individual
Triton FA
512
1.7674
0.2138
12.10%
Overlap Individual
Naive
1024
OOM
Overlap Individual
Triton FA
1024
3.4086
0.4293
12.59%
Overlap Bucketed (256MB)
Naive
512
1.9643
0.2123
10.81%
Overlap Bucketed (256MB)
Triton FA
512
1.7836
0.2105
11.80%
Overlap Bucketed (256MB)
Naive
1024
OOM
Overlap Bucketed (256MB)
Triton FA
1024
3.4008
0.3966
11.66%
Triton FA 的加速效果显著。 相同 DDP 策略下,Triton FA 相比朴素 Attention 在 context=512 时 E2E 提速约 8–9%。更关键的是,朴素 Attention 在 context=1024 下全部 OOM,而 Triton FA 均可正常运行,说明显存效率是其最核心的优势。
Overlap 策略能有效压缩梯度同步开销。 Naive/Flat DDP 的同步开销约 13–15%,Overlap 策略通过将 AllReduce 与反向传播并行执行,将开销降至 11–13%。
Bucketed 在长序列下略优。 context=1024 时,Bucketed DDP 的同步时间(0.3966s)低于 Individual Overlap(0.4293s),合并通信减少了 AllReduce 的启动次数。context=512 时两者差异不明显。
Flat DDP 无明显收益。 与 Naive DDP 性能几乎相同,在 2-GPU 规模下参数展平的优化效果可忽略不计。

Sharded Optimizer

实现了一个简化的 ZeRO1, 但是由于和 ddp 分开实现,只是在 optimizer 的 step 之后广播权重。
Class Loading in JavaCS336 Assignment 1 Key Takeaway
Loading...