https://github.com/THUDM/slime/pull/467

这篇文章介绍一下在 Slime 上支持Context Parallelism的历程

背景知识

Ring Attn

随着上下文长度越来越大, transformer 需要的和 seq 成平方关系K, V 存储与计算也越来越大, 以下这篇文章介绍了怎么用 RingAttn 实现上下文并行 aka Context Parallelism.

arxiv.org

简单来说就是 把按照GPU rank 切分 sequence, 每个GPU 保持各自的Q 矩阵 按照环形的拓扑结构交换 K, V 并且在交换的Q, K, V计算的同时, 只要计算比通信耗时长, 就相当于 zero overhead 地无损可以把以前的 max_sequence_length 提高到 #GPU * max_sequence_length

image.png

Ring Flash Attn

GitHub - zhuzilin/ring-flash-attention: Ring attention implementation with flash attention

Ring Flash Attn 是帮助实现 Context Parallelism的一个库, 通过把Ring Attn 和FlashAttn 结合, 高效地实现在环形交换 K, V 的情况下的 Attn 的计算

具体的原理如下:

递推公式

递推公式

在 Slime上支持CP

我们使用 Ring Flash Attn 在Slime 的FSDP training engine 上尝试支持 CP:

  1. 首先先初始化 CP 相关参数 替换 attn
def setup_context_parallelism(self) -> None:
	  world_size = dist.get_world_size()
	  all_ranks = list(range(world_size))
	  self.cp_group = dist.new_group(ranks=all_ranks, backend="nccl")
	  self.cp_size = world_size
	  self.cp_rank = dist.get_rank(group=self.cp_group)
	  substitute_hf_flash_attn(self.cp_group, heads_k_stride=1)
	  print(f"[Rank {self.cp_rank}] Context Parallelism initialized - CP group size: {self.cp_size}")