date
slug
status
tags
summary
type
本文记录了我做 cs336 assignment 1 的过程中的一些要点和 key takeaway
assignment 1 主要的内容就是从零开始搭建一整个模型和训练的 pipeline:
- tokenizer
- model
- optimizer
- training utilility
Tokenizer
这里我们实现的是 byte level bpe tokenizer, 根据算法的基本描述我们可以实现最简单的版本:
- 统计 pair 频率
- merge 最大频率的 pair
- 重复直到达到目标 vocab 大小
这个实现在实际的使用中效率非常低,主要的一个原因是:每次统计 pair 频率的时候都要扫描整个语料。
为了提高效率,一个常见的做法是先做 pretokenization,也就是先做一个粗粒度的 tokenization(比如先根据空格分割就是一种方法),后续的统计 pair 频率和 merge pair 只在 pretoken 的内部做,消除了跨边界的情况。
应用了 pretokenization 之后只是消除了跨边界的情况,仍然大概要扫一遍整个语料来统计 pair 频率和 merge 最大频率的 pair,此外,pretokenization 本身也成为一个性能瓶颈。对于这两个点,可以进行如下优化:
- 以 special token (比如 <|endoftext|>)为分界线的不同 chunk 可以独立地做 pretokenization,因此可以通过 multiprocessing 来并行化处理
- 当 merge 一个 pair 的时候,实际上只有包含这个 pair 的 pretoken 里面的其他 pair 会被影响到,因此这里可以事先建立一个索引,每次 merge 的时候只需要 merge 这些并更新索引。
以上是训练 tokenizer 的部分,对于 tokenizer 的推理同样存在问题:如何处理超大文件。我在这里的处理是按照 line 流式读取结合批处理来控制内存使用峰值。
Model

模型架构使用类似 llama 的架构 (不过没有使用 group query attention):
- pre-norm
- SwiGLU
- RoPE
Optimizer
实现了 AdamW optimizer
Training Utility
- cross entropy
- gradient clipping
- learning rate scheduling
- data loader
- evaluation
Experiments and Key Findings
Experiment set #1: Ablation (on TinyStories)
- 去掉 rms norm:维持和去掉之前一样大小的 learning rate 会直接 loss 起飞,调小学习旅可以练但是 loss 下降慢
- 改成 post norm:loss 下降 behavior 类似,但是略高一点,稳定性上看起来差别不大(这个似乎和经验上的结论不一样,可能是因为模型和数据都太小了;下图青色是 post norm 蓝色是full model)

- no positional encoding: 和 post norm 类似,loss 下降 behavior 类似,但是略高一些,说明没有位置编码也能学习一些位置信息

- swiglu 换成 silu + mlp: 和上面两个类似,loss 下降 behavior 类似但是略高一些

下图是总的比较,可以看出影响的大小从高到低依次是:
- 去掉 rms norm
- 去掉位置编码
- swiglu 换 silu + mlp
- 使用 post norm

Experiment set #2: hyperparameter selection
这里主要考量了两个超参数的影响:learning rate 和 batch size
- 对于 learning rate:实验中我们调整最大 learning rate,最小 learning rate 设置为最大的 10%,实验的结论是当最大学习率小于某个阈值,学习率越大有利于 loss 快速下降,超过了之后就会发散。
- 对于 batch size:batch size 越大 loss 下降越稳定,但是每个 step 看的 token 多了,如果控制 token 数量不变,最后达到的 loss 不一定有小 batch 好(对于这个结论的一个思考是可能对于小模型使用小 batch 有可能更值得,而对于大模型和大数据可能本来就很难练所以用大 batch 尽量稳定训练比较重要;此外从系统的角度考虑,由于到目前为止我们的系统没有实现数据并行,大 batch 受单卡显存大小限制,只能用 gradient accumulation 来增大有效 batch 大小,但是这个会有点拖慢训练)
Experiment on OWT
最后在 openwebtext 上进行了训练,选择超参数主要参考了 llama 等模型的技术报告:
- max learning rate 一开始用了 3e-4,后来调大到 3e-3
- batch size 用了 1M token
- beta1 0.9 beta2 0.95
- weight decay 0.1
- gradient clipping 1.0
- model 参数量 100m 数量级,数据量根据 chinchilla 的 compute optimal 比例选择
Future Work
在最后的 openwebtext 的训练中我在我自己的这个 pipeline 上感觉到的一些痛点和可改进的地方:
- 用的是 naive 的 attention,应该可以用 flash attention,以及考虑推理的话可以用 group query attention
- 为了增大有效 batch size 现在用的是 gradient accumulation,但是有点拖慢速度,如果实现数据并行会更好一点
- Author:Lifan Sun
- URL:stevensun.site/article/cs336-as1
- Copyright:All articles in this blog, except for special statements, adopt BY-NC-SA agreement. Please indicate the source!
Relate Posts
Reading Notes: “DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving”

Reading Notes: “Preble: Efficient Distributed Prompt Scheduling for LLM Serving”

Reading Note: “ORCA: A Distributed Serving System for Transformer-Based Generative Models”

Reading Notes: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”

Reading Notes: “Efficient Memory Management for Large Language Model Serving with PagedAttention”

Reading Notes: “Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning”

