Comprehensive Cerebras Note 1 - Go From A Simple Example

最近这段时间 Cerebras 相关的编程接口发生了比较大的变动,在 1.4 版本发布之后,我认为是一个比较恰当的时间点去重新整理一下现有的编程接口。在这个版本中,最重要的功能更新是更新了软件 color 补全硬件 color。

1. Cerebras 硬件架构:从晶圆到计算

1.1 晶圆级计算的技术背景

传统 GPU 集群在训练大规模 AI 模型时面临严重的通信瓶颈。当模型参数达到万亿级别,需要数百甚至数千块 GPU 协同工作,芯片间的数据传输延迟往往超过实际计算时间。

Cerebras 的技术方案是将所有计算单元集成在单一晶圆上,从根本上消除芯片间通信开销[1]

1.2 WSE:晶圆级计算引擎

Cerebras Wafer-Scale Engine (WSE) 是一块完整的硅晶圆,面积达到 46,225 平方毫米[2]。WSE-3 的关键技术指标:

  • 900,000 个 AI 核心:相当于约 50 块 NVIDIA H100 的计算核心总和
  • 44 GB 片上 SRAM:分布式内存架构,每个核心附近都有本地内存
  • 21 PB/s 内存带宽:是 H100 的 7,000 倍
  • 功耗 15kW:考虑到计算密度,能效比极高

1.3 PE 架构详解

WSE 上分布着数十万个处理单元(Processing Elements, PE)。每个 PE 都是一个完整的计算单元[3]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
┌─────────────────────────┐
│ PE 结构 │
├─────────────────────────┤
│ ┌─────────────────┐ │
│ │ 计算引擎 (CE) │ │ <- 执行指令
│ └────────┬────────┘ │
│ │ │
│ ┌────────▼────────┐ │
│ │ 路由器 │ │ <- 与邻居通信
│ └────────┬────────┘ │
│ │ │
│ ┌────────▼────────┐ │
│ │ 本地内存 (48KB) │ │ <- 存储数据和代码
│ └─────────────────┘ │
└─────────────────────────┘

关键特性:

  • 独立执行:每个 PE 有独立的程序计数器,可执行不同代码
  • 本地内存:48KB SRAM,访问延迟 1-2 个时钟周期
  • 无共享内存:PE 之间不共享内存,必须通过显式通信交换数据

1.4 通信机制:Wavelets 和 Colors

PE 之间的通信是 Cerebras 架构的核心特性[4]

Wavelets:数据传输单元

Wavelet 是一个 32 位的数据包,可以在单个时钟周期内从一个 PE 发送到相邻的 PE。这是 PE 间通信的基本单位:

1
2
3
4
5
6
PE[0,0] ──wavelet──> PE[0,1] ──wavelet──> PE[0,2]
│ │ │
wavelet wavelet wavelet
│ │ │
▼ ▼ ▼
PE[1,0] <─wavelet─── PE[1,1] <─wavelet─── PE[1,2]

Colors:虚拟通信通道

为了避免网络拥塞并实现高效的数据流控制,Cerebras 引入了 Colors 概念——虚拟的通信通道[5]

1
2
3
4
5
6
7
8
9
10
11
12
// 定义不同用途的 colors
const DATA_COLOR: color = @get_color(0); // 数据传输通道
const CTRL_COLOR: color = @get_color(1); // 控制信号通道
const SYNC_COLOR: color = @get_color(2); // 同步通道

// 配置路由:从西边接收 DATA_COLOR,向东边转发
@set_color_config(x, y, DATA_COLOR, .{
.routes = .{
.rx = .{WEST}, // 接收方向
.tx = .{EAST} // 发送方向
}
});

Colors 的技术优势

  • 无阻塞通信:不同 color 的数据流相互独立,避免相互干扰
  • 可编程路由:每个 PE 可以独立配置不同 color 的转发规则
  • 硬件支持:WSE3 支持最多 24 个硬件 colors

实际例子:数据流水线

假设我们要实现一个简单的数据处理流水线:

1
2
输入数据 → PE[0] → PE[1] → PE[2] → 输出结果
(预处理) (计算) (后处理)

代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 在 layout.csl 中配置路由
const PIPELINE_COLOR = @get_color(0);

// PE[0]: 接收输入,向东发送
@set_color_config(0, 0, PIPELINE_COLOR, .{
.routes = .{ .rx = .{RAMP}, .tx = .{EAST} }
});

// PE[1]: 从西接收,向东发送
@set_color_config(1, 0, PIPELINE_COLOR, .{
.routes = .{ .rx = .{WEST}, .tx = .{EAST} }
});

// PE[2]: 从西接收,输出到 RAMP
@set_color_config(2, 0, PIPELINE_COLOR, .{
.routes = .{ .rx = .{WEST}, .tx = .{RAMP} }
});

1.5 I/O 架构:数据如何进出 WSE

WSE 这么大,数据是如何进出的?Cerebras 使用了特殊的 I/O 设计:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Host CPU/Memory


┌─────────────┐
I/O 通道 │ (100 Gbps 以太网 × 16)
└─────┬───────┘


┌─────────────────────────────────┐
WSE Fabric
│ ┌───┬───┬───┬───┬─ ─ ─ │
│ │I/OI/OI/OI/O│ │ <-4列预留给I/O
│ ├───┼───┼───┼───┼───┬─ ─ ─ │
│ │ │ │ │ │PE │ │ <- 你的计算PE从这里开始
│ └───┴───┴───┴───┴───┴─ ─ ─ │
└─────────────────────────────────┘

这就是为什么使用 memcpy 时必须设置 --fabric-offsets=4,1——前 4 列被 I/O 系统占用[6]

1.6 三代 WSE 演进

Cerebras 已经发布了三代 WSE,每一代都有显著提升[7]

特性对比 WSE-1 (2019) WSE-2 (2021) WSE-3 (2024)
制程 16nm 7nm 5nm
核心数 40万 85万 90万
性能提升 基准 2.1× 2.25×
内存 18GB 40GB 44GB
新特性 - Wafer-Scale Cluster MemoryX 外部内存

2. 开发环境设置

2.1 安装 Cerebras SDK

1
2
3
4
5
6
# 下载 SDK
wget https://cerebras.ai/sdk/cerebras-sdk-latest.tar.gz

# 解压并设置环境
tar -xzf cerebras-sdk-latest.tar.gz
source cerebras-sdk/setup_env.sh

2.2 基本工具链

  • CSL (Cerebras Software Language): 类C语言,用于编写设备代码
  • cslc: CSL 编译器
  • SdkRuntime: Python 主机运行时
  • cs_python: Cerebras 定制的 Python 解释器

3. 编程模型基础

3.1 CSL 编程模型的三个组成部分

CSL 编程模型由三个核心文件组成,每个文件负责不同的功能层面:

1
2
3
4
5
6
7
8
9
single-pe-example/
├── layout.csl # 1. 布局配置:定义PE网格和参数
├── pe_program.csl # 2. PE程序:实际的计算逻辑
├── run.py # 3. Host程序:控制执行流程
├── commands_wse3.sh # 编译脚本
└── out/ # 编译输出目录
├── out.elf # 可执行文件
├── out.json # 编译元数据
└── ...

三个文件的职责分工:

  1. layout.csl - 硬件配置层:定义 PE 网格拓扑、分配计算资源、传递编译时参数
  2. pe_program.csl - 计算内核层:实现具体的计算逻辑,在每个 PE 上独立执行
  3. run.py - 主机控制层:负责数据准备、设备控制、结果收集和验证

3.2 理解编译参数

让我们详细解析 commands_wse3.sh 中的每个参数:

1
2
3
4
5
6
7
8
9
10
set -e

cslc ./layout.csl \ # 入口文件
--arch wse3 \ # 目标架构:wse1, wse2, wse3
--fabric-dims=12,7 \ # Fabric尺寸:宽×高
--fabric-offsets=4,1 \ # PE网格在Fabric中的偏移
--memcpy \ # 启用Host-Device数据传输
--channels=1 \ # I/O通道数(1-16)
-o out \ # 输出目录
--color-out="out.color" # Color分配信息输出

参数详解:

--fabric-dims=12,7:定义可用的 fabric 区域大小

  • 仿真环境:可以使用较小的尺寸(如 12×7)以加快编译
  • 硬件环境:必须匹配实际系统(CS-3: 762×1176)
  • 最小要求:width + 4, height + 1(因为 memcpy 占用)

--fabric-offsets=4,1:PE 网格的起始位置

  • X≥4:前 4 列被 I/O 系统占用
  • Y≥1:第 0 行被系统保留
  • 这是使用 memcpy 的硬性要求

--memcpy:启用数据传输功能

  • 没有这个参数,无法在 Host 和 Device 间传输数据
  • 会自动配置必要的 I/O 基础设施

--channels=1:I/O 通道数量

  • 更多通道 = 更高带宽
  • 最多 16 个通道(对应 16 个 100Gbps 连接)

不同场景的编译配置:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 场景1:4×4 PE网格的仿真
cslc ./layout.csl \
--arch wse3 \
--fabric-dims=15,8 \
--fabric-offsets=4,1 \
--params=width:4,height:4 \ # 传递参数给layout.csl
--memcpy --channels=1 -o out

# 场景2:生产环境部署
cslc ./layout.csl \
--arch wse3 \
--fabric-dims=762,1176 \ # CS-3实际尺寸
--fabric-offsets=4,1 \
--params=width:16,height:16 \
--memcpy --channels=16 -o out # 使用全部16个通道

4. 单 PE 编程示例

4.1 概述

这是一个最简单的 CSL 程序示例,演示了:

  • 基本的 CSL 语法和结构
  • Host-Device 数据传输
  • PE 上的简单计算
  • 符号导出机制

4.2 代码详解

layout.csl - Fabric 布局配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// WSE3 Layout Template - Fabric Layout Configuration
// 导入 memcpy 模块用于数据传输
// width=1, height=1 表示单个 PE
const memcpy = @import_module("<memcpy/get_params>", .{ .width = 1, .height = 1 });

// layout 块定义 PE 布局
layout {
// 设置 1x1 的 PE 网格
@set_rectangle(1, 1);

// 为坐标 (0,0) 的 PE 设置代码和参数
@set_tile_code(0, 0, "pe_program.csl", .{
// memcpy 配置参数
.memcpy_params = memcpy.get_params(0),

// 自定义参数传递给 PE
.PARAM1 = 100, // 乘法系数
.PARAM2 = 16 // 加法偏移
});

// 导出符号供 host 访问
// 常见错误:忘记导出符号会导致 "undefined symbol" 错误
@export_name("data", [*]f32, false); // 输入数据数组
@export_name("result", [*]f32, false); // 输出结果数组
@export_name("compute", fn()void); // 计算函数
}

pe_program.csl - PE 计算程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// 接收来自 layout 的参数
param memcpy_params: comptime_struct;
param PARAM1: i16;
param PARAM2: i16;

// 导入 memcpy 模块
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// 常量定义
const DATA_SIZE = 1024;

// 内存数组定义
var data: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32);
var result: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32);

// 创建指针(导出数组必须使用指针)
// 常见错误:直接导出数组会报 "cannot export symbol of type '[N]T'"
const data_ptr: [*]f32 = &data;
const result_ptr: [*]f32 = &result;

// 计算函数
fn compute() void {
// 执行计算:result = data * PARAM1 + PARAM2
var i: u32 = 0;
while (i < DATA_SIZE) : (i += 1) {
result[i] = data[i] * @as(f32, PARAM1) + @as(f32, PARAM2);
}

// 解除命令流阻塞(必须调用)
// 常见错误:忘记调用会导致程序挂起
sys_mod.unblock_cmd_stream();
}

// 编译时导出符号
comptime {
@export_symbol(data_ptr, "data");
@export_symbol(result_ptr, "result");
@export_symbol(compute);
}

run.py - Host 运行程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import numpy as np
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder

# 数据大小(必须与 CSL 代码匹配)
# 常见错误:大小不匹配会导致数据传输失败或结果错误
DATA_SIZE = 1024
PARAM1 = 100
PARAM2 = 16

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='out', help='编译输出目录')
parser.add_argument('--cmaddr', help='CS 系统地址')
args = parser.parse_args()

# 初始化数据
print(f"初始化 {DATA_SIZE} 个数据元素...")
input_data = np.random.rand(DATA_SIZE).astype(np.float32)

# 创建运行时
runner = SdkRuntime(args.name, cmaddr=args.cmaddr)

# 获取符号
data_symbol = runner.get_id('data')
result_symbol = runner.get_id('result')

# 加载并启动
print("加载程序到设备...")
runner.load()
runner.run()

# 传输数据到设备
print("传输数据到设备...")
runner.memcpy_h2d(data_symbol, input_data, 0, 0, 1, 1, DATA_SIZE,
streaming=False,
order=MemcpyOrder.ROW_MAJOR,
data_type=MemcpyDataType.MEMCPY_32BIT,
nonblock=False)

# 执行计算
print("执行计算...")
runner.launch('compute', nonblock=False)

# 获取结果
print("获取结果...")
result = np.zeros(DATA_SIZE, dtype=np.float32)
runner.memcpy_d2h(result, result_symbol, 0, 0, 1, 1, DATA_SIZE,
streaming=False,
order=MemcpyOrder.ROW_MAJOR,
data_type=MemcpyDataType.MEMCPY_32BIT,
nonblock=False)

# 停止执行
runner.stop()

# 验证结果
print("验证结果...")
expected = input_data * PARAM1 + PARAM2
if np.allclose(result, expected, rtol=1e-5, atol=1e-5):
print("成功!结果正确。")
else:
print("失败!结果不匹配。")
# 调试信息
print(f"样本输入: {input_data[:5]}")
print(f"实际结果: {result[:5]}")
print(f"期望结果: {expected[:5]}")

return 0

if __name__ == '__main__':
exit(main())

4.3 运行示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 编译
./commands_wse3.sh

# 运行
cs_python ./run.py --name out

# 输出
初始化 1024 个数据元素...
加载程序到设备...
传输数据到设备...
执行计算...
获取结果...
验证结果...
成功!结果正确。

5. 多 PE 编程示例

5.1 概述

基于 gemv-05-multiple-pes 模式[8],这个示例演示了:

  • 创建一维 PE 数组
  • 数据复制到多个 PE
  • PE 特定参数配置
  • 从多个 PE 收集结果

5.2 代码详解

layout.csl - 多 PE 布局

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// PE 网格宽度参数
param width: i16;

// 导入 memcpy,配置为 width x 1 的网格
const memcpy = @import_module("<memcpy/get_params>", .{
.width = width,
.height = 1
});

layout {
// 设置 width x 1 的 PE 网格
@set_rectangle(width, 1);

// 为每个 PE 配置代码和参数
for (@range(i16, width)) |x| {
@set_tile_code(x, 0, "pe_program.csl", .{
// PE 特定的 memcpy 参数
.memcpy_params = memcpy.get_params(x),
// PE 坐标
.px = x,
// PE 特定的计算参数
.PARAM1 = 100 + x * 10,
.PARAM2 = 16 + x * 2,
// 每个 PE 的数据大小
.DATA_SIZE = 256
});
}

// 导出分布式符号(第三个参数为 true)
// 常见错误:多PE时忘记设置为 true 会导致符号冲突
@export_name("data", [*]f32, true);
@export_name("result", [*]f32, true);
@export_name("compute", fn()void);
}

pe_program.csl - PE 计算程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// 参数接收
param memcpy_params: comptime_struct;
param px: i16; // PE x 坐标
param DATA_SIZE: i16; // 数据大小
param PARAM1: i16; // 计算参数 1
param PARAM2: i16; // 计算参数 2

// 导入模块
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// 内存数组
var data: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32);
var result: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32);

// 指针
var data_ptr: [*]f32 = &data;
var result_ptr: [*]f32 = &result;

// 计算函数
fn compute() void {
// 每个 PE 使用自己的参数进行计算
for (@range(i16, DATA_SIZE)) |i| {
result[i] = data[i] * @as(f32, PARAM1) + @as(f32, PARAM2);
}

sys_mod.unblock_cmd_stream();
}

// 导出符号
comptime {
@export_symbol(data_ptr, "data");
@export_symbol(result_ptr, "result");
@export_symbol(compute);
}

run_multi_pe.py - 多 PE Host 程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import json
import numpy as np
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder

# 常量
DATA_SIZE_PER_PE = 256
BASE_PARAM1 = 100
BASE_PARAM2 = 16

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='out')
parser.add_argument('--cmaddr')
args = parser.parse_args()

# 读取编译元数据获取 PE 数量
with open(f"{args.name}/out.json", "r") as f:
compile_data = json.load(f)

width = int(compile_data['params']['width'])
print(f"运行 {width} 个 PE")

# 创建测试数据
input_data = np.random.rand(DATA_SIZE_PER_PE).astype(np.float32)

# 创建运行时
runner = SdkRuntime(args.name, cmaddr=args.cmaddr)

# 获取符号
data_symbol = runner.get_id('data')
result_symbol = runner.get_id('result')

# 加载并启动
print("加载程序...")
runner.load()
runner.run()

# 使用 np.tile 复制数据到所有 PE
print(f"传输数据到 {width} 个 PE...")
tiled_data = np.tile(input_data, width)

# 常见错误:PE网格维度与实际PE数量不匹配
runner.memcpy_h2d(
data_symbol,
tiled_data,
0, 0, # 起始 PE 坐标
width, 1, # PE 网格维度
DATA_SIZE_PER_PE, # 每个 PE 的元素数
streaming=False,
order=MemcpyOrder.ROW_MAJOR,
data_type=MemcpyDataType.MEMCPY_32BIT,
nonblock=False
)

# 执行计算
print("执行计算...")
runner.launch('compute', nonblock=False)

# 获取所有 PE 的结果
print(f"从 {width} 个 PE 获取结果...")
total_size = DATA_SIZE_PER_PE * width
result = np.zeros(total_size, dtype=np.float32)

runner.memcpy_d2h(
result,
result_symbol,
0, 0,
width, 1,
DATA_SIZE_PER_PE,
streaming=False,
order=MemcpyOrder.ROW_MAJOR,
data_type=MemcpyDataType.MEMCPY_32BIT,
nonblock=False
)

runner.stop()

# 验证每个 PE 的结果
print("\n验证结果...")
all_passed = True

for pe_idx in range(width):
# 获取 PE 特定参数
pe_param1 = BASE_PARAM1 + pe_idx * 10
pe_param2 = BASE_PARAM2 + pe_idx * 2

# 计算期望结果
expected = input_data * pe_param1 + pe_param2

# 提取该 PE 的结果
start_idx = pe_idx * DATA_SIZE_PER_PE
end_idx = start_idx + DATA_SIZE_PER_PE
pe_result = result[start_idx:end_idx]

# 验证
if np.allclose(pe_result, expected, rtol=1e-5, atol=1e-5):
print(f" PE[{pe_idx}]: 通过 (PARAM1={pe_param1}, PARAM2={pe_param2})")
else:
print(f" PE[{pe_idx}]: 失败")
all_passed = False

if all_passed:
print(f"\n✓ 成功!所有 {width} 个 PE 计算正确。")
else:
print("\n✗ 失败!部分 PE 计算错误。")

return 0 if all_passed else 1

if __name__ == '__main__':
exit(main())

commands_wse3.sh - 编译脚本

1
2
3
4
5
6
7
8
9
10
11
12
set -e

WIDTH=4 # PE 数量

cslc ./layout.csl \
--arch wse3 \
--fabric-dims=15,7 \
--fabric-offsets=4,1 \
--params=width:${WIDTH} \
--memcpy \
--channels=1 \
-o out

5.3 运行示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 编译
./commands_wse3.sh

# 运行
cs_python ./run_multi_pe.py --name out

# 输出
运行 4 个 PE
加载程序...
传输数据到 4 个 PE...
执行计算...
从 4 个 PE 获取结果...

验证结果...
PE[0]: 通过 (PARAM1=100, PARAM2=16)
PE[1]: 通过 (PARAM1=110, PARAM2=18)
PE[2]: 通过 (PARAM1=120, PARAM2=20)
PE[3]: 通过 (PARAM1=130, PARAM2=22)

✓ 成功!所有 4 个 PE 计算正确。

6. API 参考手册

本章详细介绍第 4 章和第 5 章示例中使用的所有 API。

6.1 CSL 内置函数

6.1.1 模块导入

@import_module

1
2
@import_module(filename)
@import_module(filename, param_binding)

导入 CSL 模块。

参数:

  • filename:模块文件名(comptime string)
  • param_binding:参数绑定(comptime 匿名结构体)

示例:

1
2
3
4
5
// 导入标准库
const memcpy = @import_module("<memcpy/get_params>", .{ .width = 1, .height = 1 });

// 导入自定义模块
const math = @import_module("math.csl", .{ .precision = 32 });

6.1.2 布局配置

@set_rectangle

1
@set_rectangle(width, height)

设置 PE 网格的矩形尺寸。必须在 layout 块中调用。

参数:

  • width:网格宽度(comptime 整数)
  • height:网格高度(comptime 整数)

@set_tile_code

1
2
3
@set_tile_code(x_coord, y_coord)
@set_tile_code(x_coord, y_coord, filename)
@set_tile_code(x_coord, y_coord, filename, param_binding)

为指定坐标的 PE 设置执行代码。

参数:

  • x_coord, y_coord:PE 坐标(comptime 整数)
  • filename:代码文件名(comptime string)
  • param_binding:参数绑定(comptime 匿名结构体)

6.1.3 符号导出

@export_name

1
2
@export_name(name, type)
@export_name(name, type, is_distributed)

在 layout 中声明要导出的符号。

参数:

  • name:符号名称(string)
  • type:符号类型
  • is_distributed:是否为分布式符号(布尔值,默认 false)

示例:

1
2
3
4
// 单 PE 导出
@export_name("data", [*]f32, false);
// 多 PE 分布式导出
@export_name("data", [*]f32, true);

@export_symbol

1
2
@export_symbol(symbol)
@export_symbol(symbol, name)

在 PE 代码中导出符号供 host 访问。

参数:

  • symbol:要导出的变量或函数
  • name:导出名称(可选)

示例:

1
2
3
4
comptime {
@export_symbol(data_ptr, "data");
@export_symbol(compute); // 使用函数名作为导出名
}

6.1.4 数据初始化

@zeros

1
@zeros(array_type)

创建零初始化的数组。

参数:

  • array_type:数组类型表达式

示例:

1
var data: [1024]f32 = @zeros([1024]f32);

@constants

1
@constants(array_type, value)

创建常量初始化的数组。

参数:

  • array_type:数组类型
  • value:初始化值

示例:

1
var x = @constants([N]f32, 1.0);  // 所有元素初始化为 1.0

6.1.5 类型转换

@as

1
@as(type, value)

显式类型转换。

参数:

  • type:目标类型
  • value:要转换的值

示例:

1
result[i] = data[i] * @as(f32, PARAM1);  // 将 i16 转换为 f32

@bitcast

1
@bitcast(type, value)

位级别类型转换(不改变位模式)。

6.1.6 循环和范围

@range

1
2
3
@range(type, count)
@range(type, start, end)
@range(type, start, end, step)

创建循环范围。

参数:

  • type:索引类型
  • count:元素数量(从 0 开始)
  • start:起始值
  • end:结束值(不包含)
  • step:步长

示例:

1
2
3
4
5
6
7
8
9
// 0 到 width-1
for (@range(i16, width)) |x| {
// ...
}

// 0 到 9
for (@range(u32, 10)) |i| {
// ...
}

6.1.7 颜色和通信

@get_color

1
@get_color(color_id)

获取通信颜色。

参数:

  • color_id:颜色 ID(0-23)

@set_color_config

1
@set_color_config(x, y, color, config)

配置 PE 的颜色路由。

参数:

  • x, y:PE 坐标
  • color:颜色对象
  • config:路由配置

示例:

1
2
3
4
5
6
7
const EAST_COLOR = @get_color(1);
@set_color_config(x, y, EAST_COLOR, .{
.routes = .{
.rx = .{WEST}, // 从西边接收
.tx = .{EAST} // 向东边发送
}
});

6.2 编译时功能

6.2.1 参数声明

param

1
2
param name: type;
param name: type = default_value;

声明编译时参数,可从 layout 或命令行传入。

示例:

1
2
param width: i16;
param PARAM1: i16 = 100; // 带默认值

6.2.2 编译时块

comptime

1
2
3
comptime {
// 编译时执行的代码
}

标记在编译时执行的代码块。

6.3 Python Host API

6.3.1 运行时管理

SdkRuntime

1
2
3
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime

runner = SdkRuntime(compile_dir, cmaddr=None)

创建 SDK 运行时实例。

参数:

  • compile_dir:编译输出目录
  • cmaddr:CS 系统地址(可选,格式 “IP:port”)

基本控制方法

1
2
3
runner.load()      # 加载程序到设备
runner.run() # 启动设备执行
runner.stop() # 停止程序执行

6.3.2 符号访问

get_id

1
symbol_id = runner.get_id(symbol_name)

获取导出符号的 ID。

参数:

  • symbol_name:符号名称(字符串)

返回:

  • 符号 ID(整数)

6.3.3 数据传输

memcpy_h2d - Host 到 Device

1
2
3
4
5
runner.memcpy_h2d(symbol, data, px, py, width, height, elem_per_pe,
streaming=False,
order=MemcpyOrder.ROW_MAJOR,
data_type=MemcpyDataType.MEMCPY_32BIT,
nonblock=False)

参数:

  • symbol:目标符号 ID
  • data:源数据(numpy 数组)
  • px, py:起始 PE 坐标
  • width, height:PE 网格尺寸
  • elem_per_pe:每个 PE 的元素数量
  • streaming:是否使用流式传输
  • order:数据布局顺序
  • data_type:数据类型(16 位或 32 位)
  • nonblock:是否非阻塞

memcpy_d2h - Device 到 Host

1
2
3
4
5
runner.memcpy_d2h(result, symbol, px, py, width, height, elem_per_pe,
streaming=False,
order=MemcpyOrder.ROW_MAJOR,
data_type=MemcpyDataType.MEMCPY_32BIT,
nonblock=False)

参数:

  • result:目标数组(numpy 数组)
  • 其他参数同 memcpy_h2d

6.3.4 函数调用

launch

1
runner.launch(function_name, nonblock=False)

在设备上执行函数。

参数:

  • function_name:函数名称(字符串)
  • nonblock:是否非阻塞执行

6.3.5 枚举类型

MemcpyDataType

1
2
3
4
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyDataType

MemcpyDataType.MEMCPY_16BIT # 16 位数据
MemcpyDataType.MEMCPY_32BIT # 32 位数据

MemcpyOrder

1
2
3
4
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder

MemcpyOrder.ROW_MAJOR # 行优先
MemcpyOrder.COL_MAJOR # 列优先

6.4 Memcpy 模块

6.4.1 模块导入

1
2
3
4
5
6
const memcpy = @import_module("<memcpy/get_params>", .{
.width = width,
.height = height,
.MEMCPYH2D_1 = color, // 可选:自定义颜色
.MEMCPYD2H_1 = color // 可选:自定义颜色
});

6.4.2 获取 PE 参数

1
memcpy.get_params(pe_id)

获取特定 PE 的 memcpy 参数。

参数:

  • pe_id:PE 索引(对于 2D 网格,使用 y * width + x

6.4.3 在 PE 中使用

1
2
3
4
5
param memcpy_params: comptime_struct;
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// 在计算函数末尾必须调用
sys_mod.unblock_cmd_stream();

6.5 常见模式总结

单 PE 数据传输

1
2
3
4
5
# H2D
runner.memcpy_h2d(symbol, data, 0, 0, 1, 1, size, ...)

# D2H
runner.memcpy_d2h(result, symbol, 0, 0, 1, 1, size, ...)

多 PE 数据传输(复制模式)

1
2
3
# 使用 np.tile 复制数据到所有 PE
tiled_data = np.tile(data, num_pes)
runner.memcpy_h2d(symbol, tiled_data, 0, 0, num_pes, 1, size_per_pe, ...)

多 PE 数据传输(分区模式)

1
2
3
# 不同数据到不同 PE
partitioned_data = np.concatenate([pe0_data, pe1_data, ...])
runner.memcpy_h2d(symbol, partitioned_data, 0, 0, num_pes, 1, size_per_pe, ...)

7. 关键概念总结

6.1 记住这些要点

  1. 三个文件,三个角色

    • layout.csl:配置 PE 布局和参数
    • pe_program.csl:实现计算逻辑
    • run.py:控制执行流程
  2. Fabric 偏移是硬性要求

    • 使用 memcpy 必须 --fabric-offsets=4,1
    • 前 4 列和第 1 行被系统占用
  3. 数组导出必须用指针

    • 错误:@export_symbol(data)
    • 正确:@export_symbol(&data)
  4. 别忘了 unblock_cmd_stream()

    • 每个计算函数最后必须调用
    • 否则程序会挂起
  5. 多 PE 符号要分布式导出

    • 单 PE:@export_name("data", [*]f32, false)
    • 多 PE:@export_name("data", [*]f32, true)

6.2 从这里开始

  1. 先跑通单 PE 示例:理解基本概念
  2. 尝试修改计算逻辑:在 compute() 中实现自己的算法
  3. 扩展到多 PE:体验并行计算的威力
  4. 探索高级特性:PE 间通信、流式传输等

8. 代码获取

完整的示例代码即将发布到 GitHub:

[TODO: 添加 GitHub 仓库链接]

仓库将包含:

  • 本教程的所有示例代码
  • 更多高级示例
  • 常见应用模板
  • 性能优化技巧

参考文献


Comprehensive Cerebras Note 1 - Go From A Simple Example
http://blog.chivier.site/2025-07-22/2025/Comprehensive-Cerebras-Note-1---Go-From-A-Simple-Example/
Author
Chivier Humber
Posted on
July 22, 2025
Licensed under