date
slug
status
tags
summary
type
本文记录了我做 cs336 assignment 1 的过程中的一些要点和 key takeaway
assignment 1 主要的内容就是从零开始搭建一整个模型和训练的 pipeline:
  • tokenizer
  • model
  • optimizer
  • training utilility

Tokenizer

这里我们实现的是 byte level bpe tokenizer, 根据算法的基本描述我们可以实现最简单的版本:
  • 统计 pair 频率
  • merge 最大频率的 pair
  • 重复直到达到目标 vocab 大小
这个实现在实际的使用中效率非常低,主要的一个原因是:每次统计 pair 频率的时候都要扫描整个语料。
为了提高效率,一个常见的做法是先做 pretokenization,也就是先做一个粗粒度的 tokenization(比如先根据空格分割就是一种方法),后续的统计 pair 频率和 merge pair 只在 pretoken 的内部做,消除了跨边界的情况。
应用了 pretokenization 之后只是消除了跨边界的情况,仍然大概要扫一遍整个语料来统计 pair 频率和 merge 最大频率的 pair,此外,pretokenization 本身也成为一个性能瓶颈。对于这两个点,可以进行如下优化:
  • 以 special token (比如 <|endoftext|>)为分界线的不同 chunk 可以独立地做 pretokenization,因此可以通过 multiprocessing 来并行化处理
  • 当 merge 一个 pair 的时候,实际上只有包含这个 pair 的 pretoken 里面的其他 pair 会被影响到,因此这里可以事先建立一个索引,每次 merge 的时候只需要 merge 这些并更新索引。
以上是训练 tokenizer 的部分,对于 tokenizer 的推理同样存在问题:如何处理超大文件。我在这里的处理是按照 line 流式读取结合批处理来控制内存使用峰值。

Model

notion image
模型架构使用类似 llama 的架构 (不过没有使用 group query attention):
  • pre-norm
  • SwiGLU
  • RoPE

Optimizer

实现了 AdamW optimizer

Training Utility

  • cross entropy
  • gradient clipping
  • learning rate scheduling
  • data loader
  • evaluation

Experiments and Key Findings

Experiment set #1: Ablation (on TinyStories)

  • 去掉 rms norm:维持和去掉之前一样大小的 learning rate 会直接 loss 起飞,调小学习旅可以练但是 loss 下降慢
  • 改成 post norm:loss 下降 behavior 类似,但是略高一点,稳定性上看起来差别不大(这个似乎和经验上的结论不一样,可能是因为模型和数据都太小了;下图青色是 post norm 蓝色是full model)
notion image
 
  • no positional encoding: 和 post norm 类似,loss 下降 behavior 类似,但是略高一些,说明没有位置编码也能学习一些位置信息
notion image
  • swiglu 换成 silu + mlp: 和上面两个类似,loss 下降 behavior 类似但是略高一些
notion image
下图是总的比较,可以看出影响的大小从高到低依次是:
  • 去掉 rms norm
  • 去掉位置编码
  • swiglu 换 silu + mlp
  • 使用 post norm
notion image

Experiment set #2: hyperparameter selection

这里主要考量了两个超参数的影响:learning rate 和 batch size
  • 对于 learning rate:实验中我们调整最大 learning rate,最小 learning rate 设置为最大的 10%,实验的结论是当最大学习率小于某个阈值,学习率越大有利于 loss 快速下降,超过了之后就会发散
  • 对于 batch sizebatch size 越大 loss 下降越稳定,但是每个 step 看的 token 多了,如果控制 token 数量不变,最后达到的 loss 不一定有小 batch 好(对于这个结论的一个思考是可能对于小模型使用小 batch 有可能更值得,而对于大模型和大数据可能本来就很难练所以用大 batch 尽量稳定训练比较重要;此外从系统的角度考虑,由于到目前为止我们的系统没有实现数据并行,大 batch 受单卡显存大小限制,只能用 gradient accumulation 来增大有效 batch 大小,但是这个会有点拖慢训练)

Experiment on OWT

最后在 openwebtext 上进行了训练,选择超参数主要参考了 llama 等模型的技术报告:
  • max learning rate 一开始用了 3e-4,后来调大到 3e-3
  • batch size 用了 1M token
  • beta1 0.9 beta2 0.95
  • weight decay 0.1
  • gradient clipping 1.0
  • model 参数量 100m 数量级,数据量根据 chinchilla 的 compute optimal 比例选择

Future Work

在最后的 openwebtext 的训练中我在我自己的这个 pipeline 上感觉到的一些痛点和可改进的地方:
  • 用的是 naive 的 attention,应该可以用 flash attention,以及考虑推理的话可以用 group query attention
  • 为了增大有效 batch size 现在用的是 gradient accumulation,但是有点拖慢速度,如果实现数据并行会更好一点
Class Loading in JavaReading Notes: “DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving”
Loading...