最近这段时间 Cerebras 相关的编程接口发生了比较大的变动,在 1.4 版本发布之后,我认为是一个比较恰当的时间点去重新整理一下现有的编程接口。在这个版本中,最重要的功能更新是更新了软件 color 补全硬件 color。
1. Cerebras 硬件架构:从晶圆到计算
1.1 晶圆级计算的技术背景
传统 GPU 集群在训练大规模 AI 模型时面临严重的通信瓶颈。当模型参数达到万亿级别,需要数百甚至数千块 GPU 协同工作,芯片间的数据传输延迟往往超过实际计算时间。
Cerebras 的技术方案是将所有计算单元集成在单一晶圆上,从根本上消除芯片间通信开销。
1.2 WSE:晶圆级计算引擎
Cerebras Wafer-Scale Engine (WSE) 是一块完整的硅晶圆,面积达到 46,225 平方毫米。WSE-3 的关键技术指标:
- 900,000 个 AI 核心:相当于约 50 块 NVIDIA H100 的计算核心总和
- 44 GB 片上 SRAM:分布式内存架构,每个核心附近都有本地内存
- 21 PB/s 内存带宽:是 H100 的 7,000 倍
- 功耗 15kW:考虑到计算密度,能效比极高
1.3 PE 架构详解
WSE 上分布着数十万个处理单元(Processing Elements, PE)。每个 PE 都是一个完整的计算单元:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| ┌─────────────────────────┐ │ PE 结构 │ ├─────────────────────────┤ │ ┌─────────────────┐ │ │ │ 计算引擎 (CE) │ │ <- 执行指令 │ └────────┬────────┘ │ │ │ │ │ ┌────────▼────────┐ │ │ │ 路由器 │ │ <- 与邻居通信 │ └────────┬────────┘ │ │ │ │ │ ┌────────▼────────┐ │ │ │ 本地内存 (48KB) │ │ <- 存储数据和代码 │ └─────────────────┘ │ └─────────────────────────┘
|
关键特性:
- 独立执行:每个 PE 有独立的程序计数器,可执行不同代码
- 本地内存:48KB SRAM,访问延迟 1-2 个时钟周期
- 无共享内存:PE 之间不共享内存,必须通过显式通信交换数据
1.4 通信机制:Wavelets 和 Colors
PE 之间的通信是 Cerebras 架构的核心特性。
Wavelets:数据传输单元
Wavelet 是一个 32 位的数据包,可以在单个时钟周期内从一个 PE 发送到相邻的 PE。这是 PE 间通信的基本单位:
1 2 3 4 5 6
| PE[0,0] ──wavelet──> PE[0,1] ──wavelet──> PE[0,2] │ │ │ wavelet wavelet wavelet │ │ │ ▼ ▼ ▼ PE[1,0] <─wavelet─── PE[1,1] <─wavelet─── PE[1,2]
|
Colors:虚拟通信通道
为了避免网络拥塞并实现高效的数据流控制,Cerebras 引入了 Colors 概念——虚拟的通信通道:
1 2 3 4 5 6 7 8 9 10 11 12
| // 定义不同用途的 colors const DATA_COLOR: color = @get_color(0); // 数据传输通道 const CTRL_COLOR: color = @get_color(1); // 控制信号通道 const SYNC_COLOR: color = @get_color(2); // 同步通道
// 配置路由:从西边接收 DATA_COLOR,向东边转发 @set_color_config(x, y, DATA_COLOR, .{ .routes = .{ .rx = .{WEST}, // 接收方向 .tx = .{EAST} // 发送方向 } });
|
Colors 的技术优势:
- 无阻塞通信:不同 color 的数据流相互独立,避免相互干扰
- 可编程路由:每个 PE 可以独立配置不同 color 的转发规则
- 硬件支持:WSE3 支持最多 24 个硬件 colors
实际例子:数据流水线
假设我们要实现一个简单的数据处理流水线:
1 2
| 输入数据 → PE[0] → PE[1] → PE[2] → 输出结果 (预处理) (计算) (后处理)
|
代码实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| // 在 layout.csl 中配置路由 const PIPELINE_COLOR = @get_color(0);
// PE[0]: 接收输入,向东发送 @set_color_config(0, 0, PIPELINE_COLOR, .{ .routes = .{ .rx = .{RAMP}, .tx = .{EAST} } });
// PE[1]: 从西接收,向东发送 @set_color_config(1, 0, PIPELINE_COLOR, .{ .routes = .{ .rx = .{WEST}, .tx = .{EAST} } });
// PE[2]: 从西接收,输出到 RAMP @set_color_config(2, 0, PIPELINE_COLOR, .{ .routes = .{ .rx = .{WEST}, .tx = .{RAMP} } });
|
1.5 I/O 架构:数据如何进出 WSE
WSE 这么大,数据是如何进出的?Cerebras 使用了特殊的 I/O 设计:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| Host CPU/Memory │ ▼ ┌─────────────┐ │ I/O 通道 │ (100 Gbps 以太网 × 16) └─────┬───────┘ │ ▼ ┌─────────────────────────────────┐ │ WSE Fabric │ │ ┌───┬───┬───┬───┬─ ─ ─ │ │ │I/O│I/O│I/O│I/O│ │ <- 前4列预留给I/O │ ├───┼───┼───┼───┼───┬─ ─ ─ │ │ │ │ │ │ │PE │ │ <- 你的计算PE从这里开始 │ └───┴───┴───┴───┴───┴─ ─ ─ │ └─────────────────────────────────┘
|
这就是为什么使用 memcpy 时必须设置 --fabric-offsets=4,1——前 4 列被 I/O 系统占用。
1.6 三代 WSE 演进
Cerebras 已经发布了三代 WSE,每一代都有显著提升:
| 特性对比 |
WSE-1 (2019) |
WSE-2 (2021) |
WSE-3 (2024) |
| 制程 |
16nm |
7nm |
5nm |
| 核心数 |
40万 |
85万 |
90万 |
| 性能提升 |
基准 |
2.1× |
2.25× |
| 内存 |
18GB |
40GB |
44GB |
| 新特性 |
- |
Wafer-Scale Cluster |
MemoryX 外部内存 |
2. 开发环境设置
2.1 安装 Cerebras SDK
1 2 3 4 5 6
| wget https://cerebras.ai/sdk/cerebras-sdk-latest.tar.gz
tar -xzf cerebras-sdk-latest.tar.gz source cerebras-sdk/setup_env.sh
|
2.2 基本工具链
- CSL (Cerebras Software Language): 类C语言,用于编写设备代码
- cslc: CSL 编译器
- SdkRuntime: Python 主机运行时
- cs_python: Cerebras 定制的 Python 解释器
3. 编程模型基础
3.1 CSL 编程模型的三个组成部分
CSL 编程模型由三个核心文件组成,每个文件负责不同的功能层面:
1 2 3 4 5 6 7 8 9
| single-pe-example/ ├── layout.csl # 1. 布局配置:定义PE网格和参数 ├── pe_program.csl # 2. PE程序:实际的计算逻辑 ├── run.py # 3. Host程序:控制执行流程 ├── commands_wse3.sh # 编译脚本 └── out/ # 编译输出目录 ├── out.elf # 可执行文件 ├── out.json # 编译元数据 └── ...
|
三个文件的职责分工:
- layout.csl - 硬件配置层:定义 PE 网格拓扑、分配计算资源、传递编译时参数
- pe_program.csl - 计算内核层:实现具体的计算逻辑,在每个 PE 上独立执行
- run.py - 主机控制层:负责数据准备、设备控制、结果收集和验证
3.2 理解编译参数
让我们详细解析 commands_wse3.sh 中的每个参数:
1 2 3 4 5 6 7 8 9 10
| set -e
cslc ./layout.csl \ --arch wse3 \ --fabric-dims=12,7 \ --fabric-offsets=4,1 \ --memcpy \ --channels=1 \ -o out \ --color-out="out.color"
|
参数详解:
--fabric-dims=12,7:定义可用的 fabric 区域大小
- 仿真环境:可以使用较小的尺寸(如 12×7)以加快编译
- 硬件环境:必须匹配实际系统(CS-3: 762×1176)
- 最小要求:
width + 4, height + 1(因为 memcpy 占用)
--fabric-offsets=4,1:PE 网格的起始位置
- X≥4:前 4 列被 I/O 系统占用
- Y≥1:第 0 行被系统保留
- 这是使用 memcpy 的硬性要求
--memcpy:启用数据传输功能
- 没有这个参数,无法在 Host 和 Device 间传输数据
- 会自动配置必要的 I/O 基础设施
--channels=1:I/O 通道数量
- 更多通道 = 更高带宽
- 最多 16 个通道(对应 16 个 100Gbps 连接)
不同场景的编译配置:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| cslc ./layout.csl \ --arch wse3 \ --fabric-dims=15,8 \ --fabric-offsets=4,1 \ --params=width:4,height:4 \ --memcpy --channels=1 -o out
cslc ./layout.csl \ --arch wse3 \ --fabric-dims=762,1176 \ --fabric-offsets=4,1 \ --params=width:16,height:16 \ --memcpy --channels=16 -o out
|
4. 单 PE 编程示例
4.1 概述
这是一个最简单的 CSL 程序示例,演示了:
- 基本的 CSL 语法和结构
- Host-Device 数据传输
- PE 上的简单计算
- 符号导出机制
4.2 代码详解
layout.csl - Fabric 布局配置
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| // WSE3 Layout Template - Fabric Layout Configuration // 导入 memcpy 模块用于数据传输 // width=1, height=1 表示单个 PE const memcpy = @import_module("<memcpy/get_params>", .{ .width = 1, .height = 1 });
// layout 块定义 PE 布局 layout { // 设置 1x1 的 PE 网格 @set_rectangle(1, 1); // 为坐标 (0,0) 的 PE 设置代码和参数 @set_tile_code(0, 0, "pe_program.csl", .{ // memcpy 配置参数 .memcpy_params = memcpy.get_params(0), // 自定义参数传递给 PE .PARAM1 = 100, // 乘法系数 .PARAM2 = 16 // 加法偏移 }); // 导出符号供 host 访问 // 常见错误:忘记导出符号会导致 "undefined symbol" 错误 @export_name("data", [*]f32, false); // 输入数据数组 @export_name("result", [*]f32, false); // 输出结果数组 @export_name("compute", fn()void); // 计算函数 }
|
pe_program.csl - PE 计算程序
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| // 接收来自 layout 的参数 param memcpy_params: comptime_struct; param PARAM1: i16; param PARAM2: i16;
// 导入 memcpy 模块 const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// 常量定义 const DATA_SIZE = 1024;
// 内存数组定义 var data: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32); var result: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32);
// 创建指针(导出数组必须使用指针) // 常见错误:直接导出数组会报 "cannot export symbol of type '[N]T'" const data_ptr: [*]f32 = &data; const result_ptr: [*]f32 = &result;
// 计算函数 fn compute() void { // 执行计算:result = data * PARAM1 + PARAM2 var i: u32 = 0; while (i < DATA_SIZE) : (i += 1) { result[i] = data[i] * @as(f32, PARAM1) + @as(f32, PARAM2); } // 解除命令流阻塞(必须调用) // 常见错误:忘记调用会导致程序挂起 sys_mod.unblock_cmd_stream(); }
// 编译时导出符号 comptime { @export_symbol(data_ptr, "data"); @export_symbol(result_ptr, "result"); @export_symbol(compute); }
|
run.py - Host 运行程序
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| import argparse import numpy as np from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder
DATA_SIZE = 1024 PARAM1 = 100 PARAM2 = 16
def main(): parser = argparse.ArgumentParser() parser.add_argument('--name', default='out', help='编译输出目录') parser.add_argument('--cmaddr', help='CS 系统地址') args = parser.parse_args() print(f"初始化 {DATA_SIZE} 个数据元素...") input_data = np.random.rand(DATA_SIZE).astype(np.float32) runner = SdkRuntime(args.name, cmaddr=args.cmaddr) data_symbol = runner.get_id('data') result_symbol = runner.get_id('result') print("加载程序到设备...") runner.load() runner.run() print("传输数据到设备...") runner.memcpy_h2d(data_symbol, input_data, 0, 0, 1, 1, DATA_SIZE, streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False) print("执行计算...") runner.launch('compute', nonblock=False) print("获取结果...") result = np.zeros(DATA_SIZE, dtype=np.float32) runner.memcpy_d2h(result, result_symbol, 0, 0, 1, 1, DATA_SIZE, streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False) runner.stop() print("验证结果...") expected = input_data * PARAM1 + PARAM2 if np.allclose(result, expected, rtol=1e-5, atol=1e-5): print("成功!结果正确。") else: print("失败!结果不匹配。") print(f"样本输入: {input_data[:5]}") print(f"实际结果: {result[:5]}") print(f"期望结果: {expected[:5]}") return 0
if __name__ == '__main__': exit(main())
|
4.3 运行示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| ./commands_wse3.sh
cs_python ./run.py --name out
初始化 1024 个数据元素... 加载程序到设备... 传输数据到设备... 执行计算... 获取结果... 验证结果... 成功!结果正确。
|
5. 多 PE 编程示例
5.1 概述
基于 gemv-05-multiple-pes 模式,这个示例演示了:
- 创建一维 PE 数组
- 数据复制到多个 PE
- PE 特定参数配置
- 从多个 PE 收集结果
5.2 代码详解
layout.csl - 多 PE 布局
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| // PE 网格宽度参数 param width: i16;
// 导入 memcpy,配置为 width x 1 的网格 const memcpy = @import_module("<memcpy/get_params>", .{ .width = width, .height = 1 });
layout { // 设置 width x 1 的 PE 网格 @set_rectangle(width, 1); // 为每个 PE 配置代码和参数 for (@range(i16, width)) |x| { @set_tile_code(x, 0, "pe_program.csl", .{ // PE 特定的 memcpy 参数 .memcpy_params = memcpy.get_params(x), // PE 坐标 .px = x, // PE 特定的计算参数 .PARAM1 = 100 + x * 10, .PARAM2 = 16 + x * 2, // 每个 PE 的数据大小 .DATA_SIZE = 256 }); } // 导出分布式符号(第三个参数为 true) // 常见错误:多PE时忘记设置为 true 会导致符号冲突 @export_name("data", [*]f32, true); @export_name("result", [*]f32, true); @export_name("compute", fn()void); }
|
pe_program.csl - PE 计算程序
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| // 参数接收 param memcpy_params: comptime_struct; param px: i16; // PE x 坐标 param DATA_SIZE: i16; // 数据大小 param PARAM1: i16; // 计算参数 1 param PARAM2: i16; // 计算参数 2
// 导入模块 const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// 内存数组 var data: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32); var result: [DATA_SIZE]f32 = @zeros([DATA_SIZE]f32);
// 指针 var data_ptr: [*]f32 = &data; var result_ptr: [*]f32 = &result;
// 计算函数 fn compute() void { // 每个 PE 使用自己的参数进行计算 for (@range(i16, DATA_SIZE)) |i| { result[i] = data[i] * @as(f32, PARAM1) + @as(f32, PARAM2); } sys_mod.unblock_cmd_stream(); }
// 导出符号 comptime { @export_symbol(data_ptr, "data"); @export_symbol(result_ptr, "result"); @export_symbol(compute); }
|
run_multi_pe.py - 多 PE Host 程序
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
| import argparse import json import numpy as np from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder
DATA_SIZE_PER_PE = 256 BASE_PARAM1 = 100 BASE_PARAM2 = 16
def main(): parser = argparse.ArgumentParser() parser.add_argument('--name', default='out') parser.add_argument('--cmaddr') args = parser.parse_args() with open(f"{args.name}/out.json", "r") as f: compile_data = json.load(f) width = int(compile_data['params']['width']) print(f"运行 {width} 个 PE") input_data = np.random.rand(DATA_SIZE_PER_PE).astype(np.float32) runner = SdkRuntime(args.name, cmaddr=args.cmaddr) data_symbol = runner.get_id('data') result_symbol = runner.get_id('result') print("加载程序...") runner.load() runner.run() print(f"传输数据到 {width} 个 PE...") tiled_data = np.tile(input_data, width) runner.memcpy_h2d( data_symbol, tiled_data, 0, 0, width, 1, DATA_SIZE_PER_PE, streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False ) print("执行计算...") runner.launch('compute', nonblock=False) print(f"从 {width} 个 PE 获取结果...") total_size = DATA_SIZE_PER_PE * width result = np.zeros(total_size, dtype=np.float32) runner.memcpy_d2h( result, result_symbol, 0, 0, width, 1, DATA_SIZE_PER_PE, streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False ) runner.stop() print("\n验证结果...") all_passed = True for pe_idx in range(width): pe_param1 = BASE_PARAM1 + pe_idx * 10 pe_param2 = BASE_PARAM2 + pe_idx * 2 expected = input_data * pe_param1 + pe_param2 start_idx = pe_idx * DATA_SIZE_PER_PE end_idx = start_idx + DATA_SIZE_PER_PE pe_result = result[start_idx:end_idx] if np.allclose(pe_result, expected, rtol=1e-5, atol=1e-5): print(f" PE[{pe_idx}]: 通过 (PARAM1={pe_param1}, PARAM2={pe_param2})") else: print(f" PE[{pe_idx}]: 失败") all_passed = False if all_passed: print(f"\n✓ 成功!所有 {width} 个 PE 计算正确。") else: print("\n✗ 失败!部分 PE 计算错误。") return 0 if all_passed else 1
if __name__ == '__main__': exit(main())
|
commands_wse3.sh - 编译脚本
1 2 3 4 5 6 7 8 9 10 11 12
| set -e
WIDTH=4
cslc ./layout.csl \ --arch wse3 \ --fabric-dims=15,7 \ --fabric-offsets=4,1 \ --params=width:${WIDTH} \ --memcpy \ --channels=1 \ -o out
|
5.3 运行示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| ./commands_wse3.sh
cs_python ./run_multi_pe.py --name out
运行 4 个 PE 加载程序... 传输数据到 4 个 PE... 执行计算... 从 4 个 PE 获取结果...
验证结果... PE[0]: 通过 (PARAM1=100, PARAM2=16) PE[1]: 通过 (PARAM1=110, PARAM2=18) PE[2]: 通过 (PARAM1=120, PARAM2=20) PE[3]: 通过 (PARAM1=130, PARAM2=22)
✓ 成功!所有 4 个 PE 计算正确。
|
6. API 参考手册
本章详细介绍第 4 章和第 5 章示例中使用的所有 API。
6.1 CSL 内置函数
6.1.1 模块导入
@import_module
1 2
| @import_module(filename) @import_module(filename, param_binding)
|
导入 CSL 模块。
参数:
filename:模块文件名(comptime string)
param_binding:参数绑定(comptime 匿名结构体)
示例:
1 2 3 4 5
| // 导入标准库 const memcpy = @import_module("<memcpy/get_params>", .{ .width = 1, .height = 1 });
// 导入自定义模块 const math = @import_module("math.csl", .{ .precision = 32 });
|
6.1.2 布局配置
@set_rectangle
1
| @set_rectangle(width, height)
|
设置 PE 网格的矩形尺寸。必须在 layout 块中调用。
参数:
width:网格宽度(comptime 整数)
height:网格高度(comptime 整数)
@set_tile_code
1 2 3
| @set_tile_code(x_coord, y_coord) @set_tile_code(x_coord, y_coord, filename) @set_tile_code(x_coord, y_coord, filename, param_binding)
|
为指定坐标的 PE 设置执行代码。
参数:
x_coord, y_coord:PE 坐标(comptime 整数)
filename:代码文件名(comptime string)
param_binding:参数绑定(comptime 匿名结构体)
6.1.3 符号导出
@export_name
1 2
| @export_name(name, type) @export_name(name, type, is_distributed)
|
在 layout 中声明要导出的符号。
参数:
name:符号名称(string)
type:符号类型
is_distributed:是否为分布式符号(布尔值,默认 false)
示例:
1 2 3 4
| // 单 PE 导出 @export_name("data", [*]f32, false); // 多 PE 分布式导出 @export_name("data", [*]f32, true);
|
@export_symbol
1 2
| @export_symbol(symbol) @export_symbol(symbol, name)
|
在 PE 代码中导出符号供 host 访问。
参数:
symbol:要导出的变量或函数
name:导出名称(可选)
示例:
1 2 3 4
| comptime { @export_symbol(data_ptr, "data"); @export_symbol(compute); // 使用函数名作为导出名 }
|
6.1.4 数据初始化
@zeros
创建零初始化的数组。
参数:
示例:
1
| var data: [1024]f32 = @zeros([1024]f32);
|
@constants
1
| @constants(array_type, value)
|
创建常量初始化的数组。
参数:
array_type:数组类型
value:初始化值
示例:
1
| var x = @constants([N]f32, 1.0); // 所有元素初始化为 1.0
|
6.1.5 类型转换
@as
显式类型转换。
参数:
示例:
1
| result[i] = data[i] * @as(f32, PARAM1); // 将 i16 转换为 f32
|
@bitcast
位级别类型转换(不改变位模式)。
6.1.6 循环和范围
@range
1 2 3
| @range(type, count) @range(type, start, end) @range(type, start, end, step)
|
创建循环范围。
参数:
type:索引类型
count:元素数量(从 0 开始)
start:起始值
end:结束值(不包含)
step:步长
示例:
1 2 3 4 5 6 7 8 9
| // 0 到 width-1 for (@range(i16, width)) |x| { // ... }
// 0 到 9 for (@range(u32, 10)) |i| { // ... }
|
6.1.7 颜色和通信
@get_color
获取通信颜色。
参数:
@set_color_config
1
| @set_color_config(x, y, color, config)
|
配置 PE 的颜色路由。
参数:
x, y:PE 坐标
color:颜色对象
config:路由配置
示例:
1 2 3 4 5 6 7
| const EAST_COLOR = @get_color(1); @set_color_config(x, y, EAST_COLOR, .{ .routes = .{ .rx = .{WEST}, // 从西边接收 .tx = .{EAST} // 向东边发送 } });
|
6.2 编译时功能
6.2.1 参数声明
param
1 2
| param name: type; param name: type = default_value;
|
声明编译时参数,可从 layout 或命令行传入。
示例:
1 2
| param width: i16; param PARAM1: i16 = 100; // 带默认值
|
6.2.2 编译时块
comptime
1 2 3
| comptime { // 编译时执行的代码 }
|
标记在编译时执行的代码块。
6.3 Python Host API
6.3.1 运行时管理
SdkRuntime
1 2 3
| from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime
runner = SdkRuntime(compile_dir, cmaddr=None)
|
创建 SDK 运行时实例。
参数:
compile_dir:编译输出目录
cmaddr:CS 系统地址(可选,格式 “IP:port”)
基本控制方法:
1 2 3
| runner.load() runner.run() runner.stop()
|
6.3.2 符号访问
get_id
1
| symbol_id = runner.get_id(symbol_name)
|
获取导出符号的 ID。
参数:
返回:
6.3.3 数据传输
memcpy_h2d - Host 到 Device
1 2 3 4 5
| runner.memcpy_h2d(symbol, data, px, py, width, height, elem_per_pe, streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)
|
参数:
symbol:目标符号 ID
data:源数据(numpy 数组)
px, py:起始 PE 坐标
width, height:PE 网格尺寸
elem_per_pe:每个 PE 的元素数量
streaming:是否使用流式传输
order:数据布局顺序
data_type:数据类型(16 位或 32 位)
nonblock:是否非阻塞
memcpy_d2h - Device 到 Host
1 2 3 4 5
| runner.memcpy_d2h(result, symbol, px, py, width, height, elem_per_pe, streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)
|
参数:
result:目标数组(numpy 数组)
- 其他参数同
memcpy_h2d
6.3.4 函数调用
launch
1
| runner.launch(function_name, nonblock=False)
|
在设备上执行函数。
参数:
function_name:函数名称(字符串)
nonblock:是否非阻塞执行
6.3.5 枚举类型
MemcpyDataType
1 2 3 4
| from cerebras.sdk.runtime.sdkruntimepybind import MemcpyDataType
MemcpyDataType.MEMCPY_16BIT MemcpyDataType.MEMCPY_32BIT
|
MemcpyOrder
1 2 3 4
| from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder
MemcpyOrder.ROW_MAJOR MemcpyOrder.COL_MAJOR
|
6.4 Memcpy 模块
6.4.1 模块导入
1 2 3 4 5 6
| const memcpy = @import_module("<memcpy/get_params>", .{ .width = width, .height = height, .MEMCPYH2D_1 = color, // 可选:自定义颜色 .MEMCPYD2H_1 = color // 可选:自定义颜色 });
|
6.4.2 获取 PE 参数
1
| memcpy.get_params(pe_id)
|
获取特定 PE 的 memcpy 参数。
参数:
pe_id:PE 索引(对于 2D 网格,使用 y * width + x)
6.4.3 在 PE 中使用
1 2 3 4 5
| param memcpy_params: comptime_struct; const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// 在计算函数末尾必须调用 sys_mod.unblock_cmd_stream();
|
6.5 常见模式总结
单 PE 数据传输
1 2 3 4 5
| runner.memcpy_h2d(symbol, data, 0, 0, 1, 1, size, ...)
runner.memcpy_d2h(result, symbol, 0, 0, 1, 1, size, ...)
|
多 PE 数据传输(复制模式)
1 2 3
| tiled_data = np.tile(data, num_pes) runner.memcpy_h2d(symbol, tiled_data, 0, 0, num_pes, 1, size_per_pe, ...)
|
多 PE 数据传输(分区模式)
1 2 3
| partitioned_data = np.concatenate([pe0_data, pe1_data, ...]) runner.memcpy_h2d(symbol, partitioned_data, 0, 0, num_pes, 1, size_per_pe, ...)
|
7. 关键概念总结
6.1 记住这些要点
三个文件,三个角色
- layout.csl:配置 PE 布局和参数
- pe_program.csl:实现计算逻辑
- run.py:控制执行流程
Fabric 偏移是硬性要求
- 使用 memcpy 必须
--fabric-offsets=4,1
- 前 4 列和第 1 行被系统占用
数组导出必须用指针
- 错误:
@export_symbol(data)
- 正确:
@export_symbol(&data)
别忘了 unblock_cmd_stream()
多 PE 符号要分布式导出
- 单 PE:
@export_name("data", [*]f32, false)
- 多 PE:
@export_name("data", [*]f32, true)
6.2 从这里开始
- 先跑通单 PE 示例:理解基本概念
- 尝试修改计算逻辑:在 compute() 中实现自己的算法
- 扩展到多 PE:体验并行计算的威力
- 探索高级特性:PE 间通信、流式传输等
8. 代码获取
完整的示例代码即将发布到 GitHub:
[TODO: 添加 GitHub 仓库链接]
仓库将包含:
- 本教程的所有示例代码
- 更多高级示例
- 常见应用模板
- 性能优化技巧
参考文献