WaferLLM:分布式 AI 系统的循环与突破

化转环属,各有形势,反覆相求,因事为制。

——《鬼谷子》

说起来,WaferLLM 这个工作确实已经告一段落了。OSDI 录用,但是比起开心,我更多的想法是复杂。但作为怀疑主义者,我始终不相信 LLM 和 Transformer 架构是终极答案,更不认为 scaling law 的曲线可以永远躺平。站在技术拐点上,我们比任何时候都需要保持清醒——躺平就意味着停滞。

正是这样的认知,让我在开篇选择了《鬼谷子》的箴言:

化转环属,各有形势,反覆相求,因事为制。

这句古训揭示的技术演进规律,恰与我们的研究形成对照:技术发展如同环环相扣的齿轮,面对不同阶段的态势变化,需要不断探寻本质规律,根据实际情况调整策略。WaferLLM的设计理念,正暗合了这个「」的概念。

那么 Wafer 是什么呢?

1 Wafer 是什么?

简单来说,晶圆级芯片(Wafer-Scale Chip)是一种将整个晶圆作为单个芯片使用的技术,而不是将晶圆切割成多个小芯片。

纵观芯片面积的演进史,从单核 CPU 到多核 CPU,再到 GPU、TPU 等专用加速器,表面上是计算单元的堆叠,实则是算力需求与物理限制的持续博弈。当摩尔定律逐渐失效,晶体管密度提升放缓时,单个 die 面积受制于光刻掩模版尺寸限制,通常被锁在 400-800 平方毫米的区间。光刻技术复杂度和良品率双重制约下,这一矛盾变得愈发尖锐。

面对单个芯片面积和性能的物理天花板,我们只能转向 off-chip 扩展路径:即通过访问 DRAM 获取更大存储容量,利用 NVLink 等高速互连实现多芯片协同,或借助 InfiniBand 等跨节点网络构建分布式计算集群。然而,这些 off-chip 方案引入了更高的延迟和带宽瓶颈,相比on-chip通信的纳秒级延迟,off-chip 访问往往需要数百纳秒甚至微秒级的时间代价,严重限制了系统的可扩展性。

晶圆级集成技术打破了这种僵局,将芯片面积提升了两个数量级。以 Cerebras WSE-3 为例:有 46,225 平方毫米的超大尺寸,相比传统 GPU 芯片实现 57 倍面积突破,集成 4 万亿晶体管与 90 万 AI 计算核心,对「面积瓶颈」的进行了突破。

1.1 晶圆级技术优势

从我们论文的表 1 可以看出,晶圆级芯片相比传统系统级封装有几个显著优势:

  1. 性能优势:晶圆级芯片可以集成数万亿晶体管,比常见的 GPU 多 100 倍,支持数百万计算核心。同时,更大的芯片面积也提供了 数十 GB 片上内存和 数十 PB/s 内存带宽,是标准 GPU 的 1,000 倍以上。
  2. 集成效率:基于晶圆的芯晶互连提供单位面积 10 倍带宽,比起传统 PCB 连接(如NVIDIA NVLink),每比特能效接近有 100 倍的提升
  3. 成本降低:晶圆级集成可降低制造成本,因为芯片制造成本中 30-50% 与测试和封装单个芯片相关。此外,TSMC等公司正在开发将测试过的芯片集成到单个晶圆的技术,进一步提高良率。

许多人认为晶圆级芯片(wafer-scale chip)的制造成本极高,但实际上,随着芯片制造工艺的进步,这一观点已经不再准确。现代芯片制造技术能够生产面积更小、更加标准化的计算核心,这使得每个核心的制造良率显著提升,从而摊薄了整体成本。

此外,晶圆级集成省去了单颗芯片的分拆、测试与封装环节,而这部分通常占据了芯片制造总成本的30-50%。像TSMC等先进代工厂,也在开发将经过测试的高良率芯片直接集成到单一晶圆上的技术,进一步提升了产出效率和良率。

小核心设计不仅降低了制造成本,还为晶圆级芯片带来了更高的可用性

除此之外,最近有一些新的流片技术,通过在设计中引入低成本的冗余互连(redundant wire design),硬件可以灵活地绕过制造中出现的缺陷区域,实现动态的硬件重映射(hardware remapping)。这意味着即使部分核心或互连存在缺陷,整个晶圆级芯片依然能够高效工作,进一步提升了芯片的整体可用性和可靠性。


简而言之,我们试图复用传统的多核思路,进一步提升了芯片面积。上述的片汤话其实也都来自 Wikipedia 和一些相关很容易查找到的资料。看起来,我们只是用上了更好的芯片,吃到了硬件红利。

但是真的是这样吗?

1.2 Wafer 存在的挑战

但如果说,我们设计出了更大面积的芯片,释放带来了更多的算力,那么代价是什么呢?毕竟没有免费的午餐。突破的背后必然隐藏着新的约束

而代价就是:我们曾经的大部分算法设计将不再适用。

首先,我们先看看我们「曾经」是如何编程和设计模型的。

传统的芯片结构中,不论是 CPU 还是 GPU,对单一加速器,我们的设计的思路更接近 Uniform。设计的时候,我们将计算核心集中放在一起,随着计算核心的距离,我们设置了多级 Cache 用于更高效的访存。

那么,谈完硬件,转向模型。我们引用最简单的 DNN 也就是,MNIST 数字识别为例[1]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from keras.models import Model
from keras.layers import Input, Dense, Dropout
from keras import regularizers
from keras.optimizers import Adam

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("mnist/", one_hot=True)
x_train = mnist.train.images # 训练数据 (55000, 784)
y_train = mnist.train.labels # 训练标签
x_test = mnist.test.images
y_test = mnist.test.images

# DNN网络结构
inputs = Input(shape=(784,))
h1 = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.01))(inputs) # 权重矩阵l2正则化
h1 = Dropout(0.2)(h1)
h2 = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.01))(h1) # 权重矩阵l2正则化
h2 = Dropout(0.2)(h2)
h3 = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.01))(h2) # 权重矩阵l2正则化
h3 = Dropout(0.2)(h3)
outputs = Dense(10, activation='softmax', kernel_regularizer=regularizers.l2(0.01))(h3) # 权重矩阵l2正则化
model = Model(input=inputs, output=outputs)

# 编译模型
opt = Adam(lr=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-08) # epsilon模糊因子
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy']) # 交叉熵损失函数

# 开始训练
model.fit(x=x_train, y=y_train, validation_split=0.1, batch_size=128, epochs=4)
model.save('k_DNN.h5')

我们可能也无数次看过类似这样的模型结构图[2]

我们编程、设计模型,思考的对象是算子模块,实际上几乎不考虑数据放置的问题。不会考虑每一个矩阵,例如刚刚代码中的 h1h2 放到 GPU 的什么位置。我们思考的时候大部分时候,只关心 GPU SRAM 中有什么数据。所以,当出现计算和内存访问比例特殊的模型,例如 Llama 架构的 LLM 模型中的 decode 阶段,我们就不得不依赖 Flash Attention 这样的特殊设计的算子。

那么我们再次拿出 LLM decoding 的代码,这里用 numpy 版本简易的 decoding 模块用来举例,参考 llama3.np 项目[3]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
norm_x = RMSNormFlatten(x, input_layer_norm_weight, args.norm_eps)
bsz, seqlen, _ = norm_x.shape

xq = norm_x @ q_weight
xk = norm_x @ k_weight
xv = norm_x @ v_weight

xq = xq.reshape(bsz, seqlen, n_local_heads, head_dim)
xk = xk.reshape(bsz, seqlen, n_local_kv_heads, head_dim)
xv = xv.reshape(bsz, seqlen, n_local_kv_heads, head_dim)

xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

cache_k[:bsz, start_pos : start_pos + seqlen] = xk
cache_v[:bsz, start_pos : start_pos + seqlen] = xv

ks = cache_k[:bsz, : start_pos + seqlen]
vs = cache_v[:bsz, : start_pos + seqlen]

xk = repeat_kv(ks, n_rep) # (bs, cache_len+seqlen, n_local_heads, head_dim)
xv = repeat_kv(vs, n_rep) # (bs, cache_len+seqlen, n_local_heads, head_dim)
xq = xq.transpose(0, 2, 1, 3) # (bs, n_local_heads, seqlen, head_dim)
xk = xk.transpose(0, 2, 1, 3) # (bs, n_local_heads, cache_len+seqlen, head_dim)
xv = xv.transpose(0, 2, 1, 3) # (bs, n_local_heads, cache_len+seqlen, head_dim)
scores = np.matmul(xq, xk.transpose(0, 1, 3, 2)) / math.sqrt(head_dim)
if mask is not None:
scores = scores + mask[None, None, :, :]

scores = softmax(scores)
output = np.matmul(scores, xv)
output = output.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1)

h1 = output @ o_weight
z = x + h1
norm_z = RMSNormFlatten(z, post_attention_layernorm_weight, args.norm_eps)
z1 = norm_z @ up_weight
z2 = norm_z @ gate_weight
z2 = silu(z2)
z3 = z1 * z2
h2 = z3 @ down_weight
out = z + h2

我们设计模型的过程中,也很少考虑每一个部分的数据摆放。

更具体来说,例如;

  1. xq, xk, xv 三步连续计算的过程中,我们的 norm_x 其实需要尽可能保留在计算单元附近
  2. z3 计算之后,我们直接放在 z1 原本的位置似乎可以带来很多提升

当然,还有更多优化点,这里不做赘述列举。目前的解决方案,想必大家已经猜到,就是 triton。诚然,我们确实可以定制非常高性能的 cuda 算子,手动管理刚刚提到的复杂多层级内存。但是 triton 提供有限但够用的编程抽象,解决了我们最关心,也是对性能影响最重要的一环,即 L1 Cache 的行为。对 LLM 来说,这一步无疑是对性能影响最关键的一环,而 flash attention 系列工作,也是发现了这里产生的性能瓶颈,从而提出了高性能的优化。

在 GPU 上,我们几乎无法应对复杂的内存管理了。那么到了 Wafer 上 LLM 推理面临的主要挑战是芯片内存进一步变得复杂,问题将更加难解。我们先看一下 Cerebras 芯片的结构[4]

目前是市场上最具有代表性的 Wafer Scale Chip: Cerebras 的 Wafer Scale Engine (WSE) 采用了独特的 2D mesh 架构设计,整个芯片由数以万计的处理单元(PE,Processing Element)构成,它们以网格形式紧密排布在单个硅晶圆上。

每个 PE 包含三个关键组件:

  • 一个计算核心(compute kernel)负责执行实际运算
  • 一块小型的本地内存(local memory)用于数据存储
  • 一个路由器(router)处理通信需求。

这些 PE 通过高性能的片上网络(Network-on-Chip)相互连接,形成一个高带宽、低延迟的通信网络,使数据能够在不同处理单元之间高效流动。与之对应的,我们也有了更多的挑战:

  1. 分布式内存结构:我们的内存被打碎成无数小块的片上 Local Memory(对应 GPU 的 L1)。单块空间极小,例如 CS-2 上只有 48 KB。
  2. Mesh NoC:这里我们看到的是 2D Mesh 结构的网络,我们对数据的访问需要进行复杂的遍历。
  3. 数据移动开销:在传统的 GPU 内存中,我们很少关注 transpose、slice、reshape 等操作,而这里我们任何数据的存储和管理都异常复杂。

所以我们也看到,在如此复杂的内存结构中,哪怕是模型权重怎么切,怎么摆,都将成为值得探讨的问题。更不用说更加复杂的 KV Cache 管理了。

不过,我们在 WaferLLM 工作中,对上述各种问题提出了一套完整可行的方案。

2 PLMR 模型

在 Wafer Chip 上,在这种复杂的多层次内存架构中,哪怕是模型权重的切分和放置都成为值得深入探讨的问题,更不用说更加复杂的 KV Cache 管理了。

现有分布式 AI Library,在 Wafer Chip 几乎全部失效。利用 PLMR 模型,我们可以分析为什么现有 AI 系统难以充分利用 Wafer Chip,我们主要将现有工作分成两个类别:

  1. 基于共享内存的系统:如 Ladder,基于对称和均匀的内存访问,无法容忍 Wafer Chip 在访问远程内存时的 1000 倍延迟差异(违背 L 的限制)。此外,这些编译器主要关注计算,而较少优化内存限制和通信限制,这些工作将数据大量复制并,这违背了内存约束要求(违背 M)。最后,这些将各种通信 pattern 简化的操作没有仔细设计路由的 pattern,如果要直接映射在 2D Mesh 上,routing 会异常复杂,严重违背了路由资源的约束。
  2. 分布式内存系统:如T10系统,这一类工作确实保证了对同一芯片上其他核心的内存访问具有恒定跳数。T10 解决了内存约束(M)和路由资源限制(R)。然而,在 Wafer Chip 上,它需要经常进行大规模的 Sync 操作,例如 all reduce 和 all-to-all (违背 L),且仅扩展到数千而非数百万核心(很难适应 P 的拓展要求)。
  3. 集群级分布式系统:如 Megatron-LM,主要针对 GPU/TPU 集群设计,采用粗粒度的数据并行和模型并行策略。虽然能够扩展到大规模集群,但其通信模式依赖于高带宽的 off-chip 互连(如 NVLink、InfiniBand),无法适应 wafer-scale 芯片内部的细粒度并行需求(违背 P)。同时,其假设的均匀内存访问模式与 wafer 上的非均匀延迟特性不符(违背 L)。

故而现有的 off-chip 扩展方法并不能直接适用于 wafer-scale 硬件。这些方法应用到具有大规模 mesh 拓扑片上网络的 wafer-scale 芯片时,往往无法有效发挥硬件潜力,在数百万核心规模下遭遇严重的通信瓶颈。

为 Wafer-scale Chip 设计软件时,我们具体会面临哪些挑战?内存、通信、计算这些问题同时出现的时候,为了更有条理地进行思考,我们提出了 PLMR 模型来回答这一问题。

这一模型体现了从传统共享内存架构向片上大规模分布式内存系统的关键技术转变,通过精确捕获 wafer-scale 芯片的四个核心特征,为用户在设计高性能软件时提供指导和评估依据。

2.1 什么是PLMR模型

PLMR是四个关键硬件属性的首字母缩写:

  1. 海量并行(P - massive Parallelism):Wafer Chip 上有数百万的并行核心,相比之下,大部分 GPU 只有几千,最多几万个 core。在 wafer Chip 的每个核心上,都具有微型的硬件流水,可以在指令级并行完成数据输入、输出、计算和内存访问。
  2. 非均匀的访存延迟(L - highly non-uniform memory access Latency):在 Mesh NoC 中访问其他核的内存,会表现出高度非均匀的延迟。例如,在具有 N * M 的网格中,到最远的 NoC 跳数是 N + M (最左上角到最右下角的 PE 进行通信)。对百万核心片上网格,这可能达到几千跳,所以一个 core 访问 「local memory」和「同一芯片上,最远的 PE 的内存」之间,存在数千倍的延迟差异。因此,尽可能减少长距离通信至关重要。
  3. 有限的本地内存(M - constrained local Memory):过大容量的内存芯片会导致性能和能效下降,所以每个核心的内存都相对较小。因此,我们算法中,所有的 Matrix 和 Tensor,都必须用合适的方法进行切分,让每个 PE 分配能被存下的小数据块。
  4. 受限的路由资源(R - constrained Routing resources):由于芯片面积限制,每个PE核心的路由器只能支持有限数量的路由表项,通常仅为数十个条目。而现有方法往往需要数百万个路由条目,如图中红线所示。一旦路由表项耗尽,通信就必须回退到计算核心进行软件层路由,如图中黄线所示,这会引入显著的延迟开销。因此,面向wafer-scale芯片的软件系统必须采用路由策略,确保通信路径保持在片上网络(NoC)范围内,避免软件层的额外开销。

2.2 PLMR模型能指导我们做什么

基于 PLMR 模型,我们重新思考了整个 LLM 的架构,重新设计 Wafer Chip 上的各种系统。从算子,到内存管理,到最终的推理引擎,都重新进行了精细的设计。

最终,我们设计了适用于符合 PLMR 模型的各种硬件指标的 LLM 推理系统,WaferLLM。具体来说,有如下三个层次:

  1. 模型并行策略:PLMR 引导我们设计了针对 Wafer Chip 的有效 LLM 并行方法,包括 prefill 阶段、decode 阶段和 KV-cache 管理的并行策略。这使WaferLLM能有效扩展到数百万核心(满足P),同时最小化通信成本(满足L)、优化内存使用(满足M)并简化路由模式(满足R)。
  2. 片上分布式算子设计:PLMR 启发我们提出了 MeshGEMM 和 MeshGEMV,首次为 WaferChip 设计了高性能的分布式 GEMM 和 GEMV 算子。这些算子专门针对mesh拓扑网络优化,比传统的 SUMMA 和 CANNON 算法快2-3倍。
  3. 系统部署和推理引擎:基于 PLMR,我们将整个系统部署在 Cerebras WSE-2 上。通过在单个 Wafer Chip 上运行完整 LLM 推理,我们不仅最小化片外通信,还最大化片上带宽利用率。

凭借这些核心贡献,WaferLLM 打破了世界纪录,实现了每个请求超过 2700 tokens/s 的推理速度,相比 Megatron、Ladder 和 T10 等现有方案提升了 100-200 倍。

3 WaferLLM

对于 LLM 在 Wafer Chip 这样的加速器上工作的时候,我们需要充分发挥众多 PE 的并行能力。而在 Prefill 和 Decode 两个阶段,我们都有各种挑战。

3.1 算子优化 - GEMM

首先在 Prefill 阶段,我们大量使用 GEMM,即矩阵-矩阵乘法。传统的矩阵乘法不能满足在 PLMR 的约束条件。

为确定适合 PLMR 模型的可扩展分布式 GEMM,我们定义了以下指标:

  • 每核心路径数:每个核心的路由路径数,更少路径确保符合 R 属性;
  • 关键路径长度:每步传输子矩阵的最长通信路径(上图中的红线),更少跳数符合 L 属性;
  • 每核心内存:每个核心所需内存,更低使用确保 M 属性。

那么,我们来分析现有分布式 GEMM 方法:

  1. 基于 Allgather 的 GEMM:常见于 GPU 和 TPU 集群。每步最长通信路径是一个核心从最远核心收集数据,需要 N 步完成 allgather。每个核心创建 N 个与其行列中邻居的通信路径(违反 R)。每步的 gather 跨越 O(N) 跳的关键路径(违反 L),每个核心由于通信,需要用大量的 Buffer 使用 O(1/N) 内存,远超本地子矩阵的 O(1/N²)(违反 M)。
  2. SUMMA:这是 Cerebras 目前使用的算法。每步最长通信路径是一个核心沿列或行向最远核心广播数据。每个核心创建 N 个通信路径(违反 R),跨越最长路径中 O(N) 跳的关键路径(违反 L)。虽然 SUMMA 改进了内存使用,但仍然是本地分区子矩阵大小的两倍。
  3. Cannon:网格优化的分布式 GEMM 选择,在超级计算机中流行。每步最长通信路径是头核心向尾核心发送数据。每个核心在 2D 环中与两个邻居通信,只需 O(1) 通信路径和最佳 O(1/N²) 内存使用。但它产生 O(N) 跳的关键路径(违反 L)。
  4. MeshGEMM(我们的方法):符合 PLMR 模型的分布式 GEMM。每个核心与相距两跳的两个邻居通信。这种设计实现了每个核心需要 O(1) 通信路径和类似 Cannon 的最佳 O(1/N²) 内存使用。关键是,它将关键路径限制为 2 跳,复杂度为 O(1),使其巧妙地能够解决 L 属性。

我们的设计涉及两个关键步骤:

  1. 使用 GEMM 的循环移位过程确保算法正确性
  2. 证明这个循环上的两跳通信是满足 L 属性所需的最小距离

循环移位使 MeshGEMM 能够满足 M 和 R 属性,方法是将通信限制在两个邻居并最小化内存使用。它确保 GEMM 结果的正确性,遵循与 Cannon 类似的数据搬运方案。

对于通信,我们希望进一步最小化关键路径长度,从而满足 L 属性。我们的关键思想是引入 INTERLEAVE 操作来找到逻辑到物理的映射关系。

INTERLEAVE 算法根据核心的索引值确定其发送和接收的邻居索引:

而这个复杂的伪代码,如果用可视化的方法理解,会更加简单。

上面的动画中,我们首先看到,在 Cannon 算法中,有一步超长的通信距离。

为了规避超长的通信链路,我们想到了环形的空间结构。

在环上,任意两个邻居的距离都是 1。

如果我们将环压平到一维空间上,就能实现了最长链路为 2 的移动方案。

时间复杂度从 O(n) 降低到了 O(1)。

我们基于一维数组的讨论自然扩展到二维网格,我们分别沿着 X 轴和 Y 轴都进行 interleave 操作,在 MeshGEMM 中,任意一次「move」运算都将通信开销限制在 2 hop 的时间。

之后的操作,其实和 Cannon 算法并没有太大区别,我们采用同样的移动和计算交错的过程。MeshGEMM 算法主要步骤如下:

  1. 初始化:考虑 C=A×B。MeshGEMM 将 A 和 B 沿两个维度分割成子块 A_subB_sub,形成 N×N 块,分布在核心上。每个核心接收 A_subB_sub 的一个块。MeshGEMM 然后使用 INTERLEAVE 初始化每个核心的邻居位置。
  2. 对齐:每个核心与邻居对齐,确保分布式系统中的每个核心都以适当的操作数开始矩阵乘法过程。
  3. 计算-移位循环:每个核心执行 N 步通信和计算循环。在每一步中:
    • 计算部分和 C_sub = A_sub × B_sub + C_sub
    • 同时,沿 X 轴移位 A_sub 和沿 Y 轴移位 B_sub,获取下一步计算的新 A'_subB'_sub(如图 7 的③所示)
    • N 步后,返回累积的 C_sub

3.2 算子优化 - GEMV

分布式GEMV的完成时间主要取决于一个 Allgather 操作,该操作从所有选定核心聚合部分结果并将聚合结果广播回所有核心。和 GEMM 类似,我们也对同样的指标进行分析。

符号补充:

  • αper-hop transmission latency):指的是消息在 mesh 网络中,每经过一个核心(core)直接转发时产生的延迟。这个延迟随着跳数(hops)增加而线性增长,是硬件路由器根据预设规则直接转发数据时的基本延迟。
  • βper-routing latency):指的是消息在转发过程中,每遇到一次需要软件解析和重写头部(header parsing and rewriting)时产生的额外延迟。通常 β 大于 α,因为涉及更多的软件处理。

在 mesh 架构中,消息从一个核心传递到另一个远端核心时,总延迟由 α 和 β 共同决定。具体来说,最大内存访问延迟为: $\alpha(N_w + N_h) + \beta r$ 。其中,Nw 和 Nh 分别是 mesh 的宽和高,r 是路径上的路由阶段数(r < Nw + Nh)。

MeshGEMV 是唯一完全符合 PLMR 模型的方法:

  1. Pipeline Allreduce:常用于 TPU 集群系统和 Cerebras。它将路由资源使用限制为每个核心O(1)(满足R)。然而,其最长聚合路径是从尾到头核心,如红线所示,跨越O(N)的关键路径(违反L)。

  2. Ring Allreduce:常用于GPU集群系统,是默认配置。它将路由资源使用限制为O(1)(满足R)。但是,它在关键路径上跨越O(N)跳,(违反L)。

  3. K-Tree Allreduce:我们构建一个平衡的 K 层树(这里 K 是树的层数,而不是树的岔数)从两个方向归约;其最长聚合路径是从头或尾核心到树根核心。关键路径主要项为$O(N^{1/k} \times K)$,可以解决 L。每个根核心的最大通信路径数是 O(K),可以通过调整 K 满足 R 限制。

MeshGEMV 算法主要步骤:

  1. 初始化:考虑 C=A×B,其中 A 是向量。MeshGEMV 将 B 沿两个维度分割成子块 B_sub,形成 N×N 块分布在各计算核心上。对于向量 A,MeshGEMV 沿向量长度进行分割,形成在一个轴上分布的 N 块,并在另一轴上复制 A。每个核心接收 A_subB_sub 的一个块。然后根据 K 树结构确定哪些核心在每个阶段形成一组,以便高效获取聚合结果。
  2. 并行计算:在此阶段,每个核心独立执行本地 GEMV 操作 A_sub × B_sub,计算得到各自的部分和 C_sub
  3. 聚合:聚合步骤主要利用我们设计的 2-Way K-Tree Allreduce 机制,具体包括:
    • 第 1 阶段:每组内部进行组内归约,将结果汇集到各组的根核心,获得 C_sub 的部分和
    • 第 k 阶段:将第(k-1)阶段的结果进一步归约到第 k 阶段每组的根核心
    • 重复 K 次后,完整的 C 可以通过连接所有 K 树根核心的 C_sub 获得
    • (可选)根据是否需要连续的 GEMV 操作,可能会执行从 K 树根核心向下的广播操作

3.3 模型排布

处理完了算子,我们接下来开始排布整个模型,而其中 Prefill 和 Decode 有一些不同的挑战

  • Prefill 期间的多个大型矩阵,需要有效的维度划分,充分利用所有 PE,满足 (P);
  • Decode 使用比 Prefill 更小的矩阵,需要谨慎并行;
  • 该阶段主要依赖GEMV操作,计算不如GEMM密集,导致计算阶段短,与通信之间互相掩藏延迟的能力有限;
  • Prefill 和 Decode 连续进行 GEMM/GEMV 的时候,需要处理矩阵转置。

首先,我们提出了两种不同的切分方案。

Prefill 划分方案:

Decode 划分方案:

Prefill 划分 中:我们沿 PE 阵列的 X 和 Y 轴划分矩阵的两个维度,实现比现有方法更细粒度、百万级并行。上图展示了 Self Attention 和 Feed Forward 的在 Prefill 阶段的划分方式。

Decode 划分 中:当张量维度不足以实现 Decode 所需的高并行度时,我们在 LLM 中,将向量沿数据排列的正交方向进行复制。这种方法提高并行度并确保所有核心间的负载平衡,同时避免额外的通信操作,用冗余存储代替通信。

而为了消除矩阵转置,我们在 Prefill 中,设计了 免转置分布式GEMM。提出免转置的算子,更改通信的方向,使用转置分布式 GEMM(dist-GEMM-T) 计算 Prefill 期间的 Q@K^T,避免矩阵转置这一在 NoC 上代价高昂的操作。

而在 Decode 过程中,由于算子的瓶颈在 Allgather,重新设计算法并不能带来收益,所以我们预先优化模型权重排布,避免矩阵转置。为 Decode 预先优化好模型权重排布,将转置好的矩阵直接读到 PE 阵列上,可以在 MeshGEMV 阶段消除矩阵转置。虽然这在 Prefill 和 Decode 阶段之间,引入了重新排布模型权重的开销。

但凭借 NoC 网络上超强的通信能力,这部分开销比起生成一个 token 的开销,几乎可以忽略不计。

3.4 基于移位的 KV 缓存管理

KV 缓存管理在 PLMR 设备上也没那么简单,需要在分布式核心上存储大量数据,同时遵守本地内存约束(M)并分配KV缓存计算以实现高并行度(P)。

简单来说:我们在 2D Mesh 上实现了自适应的 KV Cache 存储方案。通过动态均衡让片上内存利用更充分。

现有方法的缺陷:传统的KV缓存管理采用「直接拼接」的方法 - 把新数据直接添加到缓存末尾。这种方法在PLMR 设备上会导致严重的负载不均:

  • 只有最后一行的核心在工作
  • 其他核心大部分时间处于空闲状态
  • 造成计算和内存资源的严重浪费

我们的解决方案其实很简单,但是默认的 Cerebras inference 系统之前并没有实现。

基于移位的新方法:不再简单地把新数据加到末尾,而是让所有核心都参与工作:

  • 每个核心都存储一部分KV缓存
  • 新数据来时,所有核心协同工作,把旧数据向上移动
  • 每个核心的负载保持均衡
  • 避免了某些核心过载而其他核心空闲的问题

4 WaferLLM 效果

通过 WaferLLM,我们展示了 Wafer Chip 在 LLM 推理方面的巨大潜力。我们在 Cerebras WSE-2 上进行了全面评估,与多个最先进的系统进行了对比。实验结果表明,WaferLLM 在系统性能、算子优化和能效方面都实现了显著突破。

4.1 端到端 LLM 推理性能

我们首先评估了 WaferLLM 与代表性系统的端到端性能对比,包括分布式内存架构的 T10 系统和共享内存架构的 Ladder 系统。

相比 T10 系统的性能提升

  • 短序列生成任务(输入4096/2048 tokens,输出128 tokens):WaferLLM 平均快 160 倍,最高达到 180 倍
  • 长序列生成任务(输入输出均为4096/2048 tokens):WaferLLM 平均快 36 倍,最高达到 48 倍

虽然 T10 考虑了 PLMR 设备的内存约束(M)和路由资源限制(R),但它无法处理网格 NoC 互连的核心架构,因此无法解决不同跳数距离的问题(L),也无法扩展到百万级核心(P)。

相比 Ladder 系统的性能提升

  • 短序列生成任务:WaferLLM 平均快 625 倍,最高达到 677 倍
  • 长序列生成任务:WaferLLM 平均快 312 倍,最高达到 342 倍

Ladder 系统为共享内存架构设计,无法适应 PLMR 设备的特性,导致无法在百万核心上分割 LLM(P)、产生昂贵的长程 NoC 通信(L)、无法处理本地内存约束(M)和有限的路由资源(R)。

4.2 算子级性能优化

MeshGEMM 性能

  • 相比 Cerebras WSE 默认的 SUMMA 算法快 2-3 倍
  • 相比超级计算机常用的 Cannon 算法也有显著提升
  • 在核心数量扩展时保持 70% 以上的计算效率,而 SUMMA 和 Cannon 效率降至 50% 以下

MeshGEMV 性能

  • 相比 Cerebras 优化的默认 GEMV 实现快 4-8 倍
  • 通过高效的双向 K-tree AllReduce 显著减少通信开销
  • 随着核心数量增加,通信成本仅略微增长,而基线方法出现严重性能退化

可扩展性分析: 我们的算子在不同核心配置下展现出优异的扩展性。对于 GEMM 8K 等大规模矩阵运算,计算变为带宽约束而非延迟约束,增加核心数量能够提升聚合网络带宽,解决性能瓶颈。

4.3 KV 缓存管理效果

基于移位的 KV 缓存管理相比传统的基于连接的方法(如 PagedAttention)实现了巨大改进:

模型 连接方法(PagedAttention) 移位方法(WaferLLM) 提升倍数
LLaMA3-8B 382 tokens 137,548 tokens 360x
LLaMA2-13B 16 tokens 6,168 tokens 385x

这一显著改进源于移位方法实现的平衡核心利用率,解决了连接方法导致的数据分部不均匀的问题。

4.4 与 GPU 的性能对比

我们将 WaferLLM(基于 Cerebras WSE-2)与运行 vLLM 的 NVIDIA A100 进行了公平对比。两者均采用 TSMC 7nm 工艺制造。

GEMV 操作对比

  • 性能提升:MeshGEMV 比 cuBLAS GEMV 快 606 倍
  • 能效提升:能效比 A100 高 22 倍

完整 LLM 推理对比

模型 WaferLLM(WSE-2) vLLM(A100) 性能提升 能效提升
LLaMA3-8B 2,480 tokens/s 78.36 tokens/s 31.6x 1.4x
LLaMA2-13B 1,848 tokens/s 47.86 tokens/s 38.6x 1.7x

推理速度突破

  • LLaMA3-8B:达到 2,700 tokens/sec/req 的 decode 速度
  • QWen2-72B:达到 840 tokens/sec/req 的 decode 速度

需要特别强调的是,无论增加多少块 GPU,传统 GPU 架构都难以实现单请求(per request)更高的推理吞吐。这是由于 GPU 集群受限于离片(off-chip)扩展,模型参数和激活值需要频繁跨芯片通信,导致带宽瓶颈和高延迟,单个请求的推理速度很难提升。

相比之下,WaferLLM 的晶圆级集成架构将全部计算核心和大容量片上内存紧密集成于同一硅片上,彻底消除了多芯片通信瓶颈。这样,即使面对大模型的单请求推理,也能实现远超 GPU 集群的极高吞吐率。例如在 LLaMA3-8B 上,1 batch decode 速度高达 2,700 tokens/sec,是单块 A100 的 30 多倍,即便用多块 GPU 也难以达到这一水平。

4.5 性能分析与局限

虽然 WaferLLM 实现了显著的性能提升,但我们也观察到一些当前的局限性:

  1. 从 GEMV 到完整 LLM 的性能衰减:GEMV 的 22 倍能效优势在完整 LLM 推理中降至 1.7 倍,主要原因包括:
    • WSE-2 核心的本地 SRAM(48KB)有限,阻碍了高效的张量并行
    • 当前 LLM 模型针对 GPU 架构优化,窄层设计限制了在 WSE-2 核心上的层放置
  2. 硬件成熟度影响:WSE-2 作为第二代产品,核心还无法完全重叠内存访问和计算,边缘核心利用率不足,长程 NoC 通信开销仍然存在。
  3. 软件栈限制:Cerebras 当前的软件栈相比 NVIDIA CUDA 的优化程度有限,影响了整体性能发挥。(我们当时发现 rmsnorm 速度非常慢,profile 之后发现是指数运算除法占据大量的时间,甚至远远超过了通信时间,我们目前还是使用手动实现的软件算子,用牛顿迭代法和高阶展开的数值算法进行替换,这里也能看到 NVIDIA 的强大护城河)

尽管存在这些限制,WaferLLM 依然实现了数量级的性能和能效提升。我们预期随着晶圆级 AI 计算的不断成熟和这些限制的逐步解决,性能表现将进一步增强。

5 Wafer 是什么?

说完了论文中枯燥的理论部分。我们再次环形回到本文开篇的问题:什么是 Wafer。

从技术上看,它是晶圆级集成,是更大硬件的一种实现方案;而实际上从系统上看,它是计算、通信和内存访问比例的再平衡。我们过去分布式系统中的设计,将再次活跃在 Wafer 的舞台上。

Wafer 的出现,挑战了我们对计算范式的既有认知。其独特的 PLMR 特性,要求我们在软件层面重新设计算法和系统。MeshGEMM 和 MeshGEMV 只是开始,未来针对稀疏矩阵、卷积和高阶张量等运算的优化,仍有广阔的研究空间。

Wafer 是 AI 加速的工具,更可能是未来计算架构的缩影,是将「分布式」与「片上」融合的潜在新范式。其实可以大胆猜测,AI Chip 竞争走到白热化的当下,我们之前思考过的 Tensor Parallel, Data Parallel, Pipeline Parallel 以至于流行的 Expert Parallel,在新的架构面前将全部平等地回归本质,我们的思路重新回到最本质的「计算、通信、访存的调度」这一最基本也是最本质的系统问题上。

技术的每一步前进,都是「反覆相求」的过程。面对新的「形势」,我们不仅要「因事为制」,也不能忘记「本质」。不断探索计算与通信、硬件与软件、理论与实践之间的平衡点。

所以说,回顾最近几个月自己的 blog,还是感触颇多。在 DeepSeek 出现之后的一段时间我麻木也开摆,很多想做的事情被解决了,而个人力量确实又很弱。一度有一种做啥都没价值的错觉。

从系统本质思考,这个世界上大部分问题的解法都非常牵强。不论是 vllm 还是 sglang,都是在特殊的环境和假设中提出解决问题,他们都是神级的工程奇迹,但是我不认为这些会是我喜欢的系统。他们对系统层面的「计算、通信、访存」中的阻塞和资源浪费,提出了一些解决方案,但似乎完全没有解决。我身上是有一些完美主义和强迫症的,我很清楚地知道,这个世界上能让我满意的系统还不存在。当然,我也从来没有对自己的工作满意过,未来还有很长的路要走。

在完成了 Wafer 这个项目之后,我可能才真正碰到了 MLSystem 的门槛,我的研究生涯才刚刚起步。

所以,System 对我,对我们来说是什么?

是对复杂的硬件结构的合理抽象。

是让一切空闲资源被我们调度有方。

是技术发展之后还能对前人的一切贡献念念不忘。

我们用最扎实的工程支持各种应用的百花齐放,

与诸位一起看 NoC、Wafer Scale、Mesh Topology 在 2025 的当下重新回响。

回到开头,

化转环属,各有形势,反覆相求,因事为制。

这才是我最喜欢的 System。

(BTW, 本项目全部开源,包括模拟器代码、Cerebras WSE-2 上的代码以及详细文档全部公开)

References


WaferLLM:分布式 AI 系统的循环与突破
http://blog.chivier.site/2025-07-09/2025/WaferLLM:分布式-AI-系统的循环与突破/
Author
Chivier Humber
Posted on
July 9, 2025
Licensed under