Reading Notes: “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints”

date
Mar 1, 2025
slug
gqa
status
Published
tags
NLP
summary
type
Post

Motivation

在原版的注意力机制中(MHA, Multi-Head Attention),解码时由于对于每个 query 都需要加载所有的 key 和 value,这可能带来内存瓶颈;多查询注意力(MQA, Multi-Query Attention)通过多个 query 头共享 key 和 value 来缓解这个问题,但是常常带来模型性能的损失。
本文提出的组查询注意力(GQA, Group-Query Attention)在 MHA 和 MQA 之间做了一个权衡,让一组 query 头共用 key 和 value,在减少内存通信压力的同时保持了模型的性能。

Approach

notion image

Uptraining: MHA to MQA

将使用 MHA 的模型的 checkpoint 转化为使用 MQA 需要如下步骤:
  • 首先将所有键和值头的投影矩阵平均池化成单一投影矩阵
  • 然后使用原始预训练计算量的5%继续预训练,使模型适应新结构

Uptraining: MHA to GQA

和 MQA 类似:
  • 将查询头分成G组,每组共享一个键头和值头
  • 通过平均池化组内所有原始头来构建每个组的键和值头
可以看到,MQA 和 MHA 都可以看作 GQA 的特殊情况,当组大小为 H 就变成 MQA,组大小为 1 时就是 MHA.
notion image

Results

使用 T5 进行了实验,评估了摘要生成、翻译和问答任务的性能:
  • uptraining 后的 MQA-XXL 模型比原始 MHA-Large 模型提供了更好的质量和更快的推理速度
  • GQA-XXL 达到了接近 MHA-XXL 的质量,同时保持了接近 MQA-XXL 的速度
  • 仅使用 5% 的原始预训练计算量即可有效地将 MHA 模型转换为 MQA 或 GQA 模型
  • GQA 在转换后立即能达到合理性能,而 MQA 需要更多训练才能获得良好效果
 
 
 

© Lifan Sun 2023 - 2025