date
Feb 12, 2026
slug
cs336-as2
status
Published
tags
MLSys
summary
type
Post
本文记录了我做 cs336 assignment 2 的要点和 key takeaway
assignment 2 主要是在 assignment 1 的 pipeline 中做一些 system 上的优化:
- 做 benchmark 和 profile 来理解系统的性能瓶颈
- 用 triton 实现 flash attention 2 kernel
- 实现 ddp + optimizer state sharding (也就是 ZeRO1)
Benchmarking & Profiling
Experiment #1: end-to-end benchmarking
分别 benchmark 单独的 forward 和 forward + backward, 基本符合 backward 计算量是 forward 的两倍左右的结论。
Experiment #2: profiling with nsight system

上图是 medium size model 的 nsys 的一个 timeline view,如果我们 zoom in 到一个 block 的 view:

可以看到 multihead-attention 是一个 block 内耗时最多的,其次是 ffn,最后是两个 rms norm,加起来也和 ffn 类似

再 zoom in 到 attention,我们可以看到三个子步骤的耗时,其中最大的是计算注意力分数。此外,我们可以看到在 attention 的过程中 launch 了好几个 kernel (最后一行)
Experiment #3: mix precision
为了更好地利用 tensor core,需要我们使用低精度数据类型,但是全低精度可能会使训练不稳定,或者让模型准确率下降。对此的一个解决办法是使用混合精度,对于对精度敏感的算子(比如 reduce 过程中的 accumulator 最好保持高精度),保持高精度,对于精度不敏感的算子(matmul),使用低精度。这可以使用 pytorch 的 autocast 来简单地实现。
可以看到混合精度获得了普遍的提升,并且模型越大,带来的提升越大。
Experiment #4: profiling memory

上图是 2.7B 模型在 fp32 下的 forward + backward 的 active memory timeline,随着 forward 进行 activation 使得显存占用增加,随着 backward 开始 activation 的显存下降但是 backward 占用的显存增加,因此后续显存保持基本稳定,第三阶段 backward 占用的显存开始释放,但是 optimizer 占用的显存增加,使得显存占用继续稳定。
Takeaway
从上述的 benchmarking 和 profiling 我们可以得到以下结论:
- attention 是耗时最多的算子,rms norm 也占有不可忽视的比例
- 混合精度可以显著加速

由此我们可以得到下一步可能的优化是优化 attention 算子(以及 rms norm)。通过 timeline view 我们可以看到目前它们都包含好几个 kernel invocation,因此最简单的思路是做 kernel fusion。
Optimizing Attention with FlashAttention-2
Experiment #1: benchmarking attention implementation
为了理解目前 attention 的性能瓶颈,我们对 attention 算子做单独的 benchmark

可以看到在 assignment 1 中实现的 attention 在显存占用上随着序列长度是平方增长的,而在耗时上也是超线性增长。
Experiment #2: benchmarking torch.compiled attention
一个很简单的优化尝试是直接使用 torch.compile 来进行优化,结果如下表。

在端到端速度上总体上是有一定的提升的,但是显存占用上还是随着序列长度平方增长。
Experiment #3: benchmarking flash attention 2
可以看到在 fp32,长 context 加速更明显;并且在长 context 避免了 oom




Key takeaway
flash attention 的核心思想:kernel fusion 和 tiling
Distributed Data Parallel
DDP 策略 | Attention | Context | E2E (s) | Grad Sync (s) | Sync 开销 |
Naive DDP | Naive | 512 | 1.9773 | 0.2646 | 13.38% |
Naive DDP | Triton FA | 512 | 1.8100 | 0.2654 | 14.66% |
Naive DDP | Naive | 1024 | OOM | — | — |
Naive DDP | Triton FA | 1024 | 3.4338 | 0.4839 | 14.09% |
Flat DDP | Naive | 512 | 1.9819 | 0.2686 | 13.55% |
Flat DDP | Triton FA | 512 | 1.8009 | 0.2703 | 15.01% |
Flat DDP | Naive | 1024 | OOM | — | — |
Flat DDP | Triton FA | 1024 | 3.4149 | 0.4832 | 14.15% |
Overlap Individual | Naive | 512 | 1.9467 | 0.2107 | 10.82% |
Overlap Individual | Triton FA | 512 | 1.7674 | 0.2138 | 12.10% |
Overlap Individual | Naive | 1024 | OOM | — | — |
Overlap Individual | Triton FA | 1024 | 3.4086 | 0.4293 | 12.59% |
Overlap Bucketed (256MB) | Naive | 512 | 1.9643 | 0.2123 | 10.81% |
Overlap Bucketed (256MB) | Triton FA | 512 | 1.7836 | 0.2105 | 11.80% |
Overlap Bucketed (256MB) | Naive | 1024 | OOM | — | — |
Overlap Bucketed (256MB) | Triton FA | 1024 | 3.4008 | 0.3966 | 11.66% |
Triton FA 的加速效果显著。 相同 DDP 策略下,Triton FA 相比朴素 Attention 在 context=512 时 E2E 提速约 8–9%。更关键的是,朴素 Attention 在 context=1024 下全部 OOM,而 Triton FA 均可正常运行,说明显存效率是其最核心的优势。
Overlap 策略能有效压缩梯度同步开销。 Naive/Flat DDP 的同步开销约 13–15%,Overlap 策略通过将 AllReduce 与反向传播并行执行,将开销降至 11–13%。
Bucketed 在长序列下略优。 context=1024 时,Bucketed DDP 的同步时间(0.3966s)低于 Individual Overlap(0.4293s),合并通信减少了 AllReduce 的启动次数。context=512 时两者差异不明显。
Flat DDP 无明显收益。 与 Naive DDP 性能几乎相同,在 2-GPU 规模下参数展平的优化效果可忽略不计。
Sharded Optimizer
实现了一个简化的 ZeRO1, 但是由于和 ddp 分开实现,只是在 optimizer 的 step 之后广播权重。
- Author:Lifan Sun
- URL:stevensun.site/article/cs336-as2
- Copyright:All articles in this blog, except for special statements, adopt BY-NC-SA agreement. Please indicate the source!
Relate Posts
CS336 Assignment 1 Key Takeaway

Reading Notes: “DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving”

Reading Notes: “Preble: Efficient Distributed Prompt Scheduling for LLM Serving”

Reading Note: “ORCA: A Distributed Serving System for Transformer-Based Generative Models”

Reading Notes: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”

Reading Notes: “Efficient Memory Management for Large Language Model Serving with PagedAttention”






