天下苦英伟达久矣!PyTorch免CUDA加速推理,Triton时代要来?
yuyutoo 2025-03-06 21:00 1 浏览 0 评论
机器之心报道
编辑:杜伟、小舟
近日,PyTorch 官方分享了如何实现无 CUDA 计算,对各个内核进行了微基准测试比较,并讨论了未来如何进一步改进 Triton 内核以缩小与 CUDA 的差距。
在做大语言模型(LLM)的训练、微调和推理时,使用英伟达的 GPU 和 CUDA 是常见的做法。在更大的机器学习编程与计算范畴,同样严重依赖 CUDA,使用它加速的机器学习模型可以实现更大的性能提升。
虽然 CUDA 在加速计算领域占据主导地位,并成为英伟达重要的护城河之一。但其他一些工作的出现正在向 CUDA 发起挑战,比如 OpenAI 推出的 Triton,它在可用性、内存开销、AI 编译器堆栈构建等方面具有一定的优势,并持续得到发展。
近日,PyTorch 官宣要做「无英伟达 CUDA 参与的大模型推理」。在谈到为什么要 100% 使用 Triton 进行探索时,PyTorch 表示:「Triton 提供了一条途径,使大模型 能够在不同类型的 GPU 上运行,包括英伟达、AMD、英特尔和其他基于 GPU 的加速器。
此外 Triton 还在 Python 中为 GPU 编程提供了更高的抽象层,使得使用 PyTorch 能够比使用供应商特定的 API 更快地编写高性能内核。」
在 PyTorch 博客中讨论了使用流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)实现 FP16 推理的方法,其中计算是 100% 使用 OpenAI 的 Triton 语言执行的。
对于使用基于 Triton 内核的模型生成单个 token 的时间,PyTorch 能够实现在英伟达 H100 GPU 上 Llama 和 Granite 的 CUDA 内核主导工作流程的 0.76-0.78 倍性能,以及在英伟达 A100 GPU 上的 0.62-0.82 倍。
图 1. 在英伟达 H100 和 A100 上,Llama3-8B 和 Granite-8B 的 Triton 和 CUDA 变体的推理吞吐量比较。设置:批大小 = 2,输入序列长度 = 512,输出序列长度 = 256
也许告别英伟达的时候真要来了。
Transformer 块的组成
PyTorch 团队首先对基于 Transformer 的模型中发生的计算进行细分。下图显示了典型 Transformer 块的「内核(kernel)」。
图 2
Llama3 架构的核心操作总结如下:
- 均方根归一化(RMSNorm)
- 矩阵乘法:Fused QKV
- RoPE
- 注意力
- 矩阵乘法:输出投影
- RMSNorm
- 矩阵乘法:Fused Gate + Up Projection
- 激活函数:SiLU
- 点乘(Element Wise Multiplication)
- 矩阵乘法:Down Projection
这些操作中的每一个都是通过在 GPU 上执行一个(或多个)内核来计算的。虽然每个内核的细节在不同的 Transformer 模型中可能有所不同,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层中使用偏置,与 Llama3 不同。此类更改确实需要对内核进行修改。典型的模型是这些 Transformer 块的堆叠,这些 Transformer 块通过嵌入层连接在一起。
模型推理
典型的模型架构代码与 PyTorch 启动的 python model.py 文件共享。在默认的 PyTorch Eager Execution 模式下,这些内核都是使用 CUDA 执行的。为了实现 100% Triton 进行端到端 Llama3-8B 和 Granite-8B 推理,需要编写和集成手写 Triton 内核以及利用 torch.compile(生成 Triton 操作)。首先,PyTorch 用编译器生成的 Triton 内核替换较小的操作,其次,PyTorch 用手写的 Triton 内核替换更昂贵和复杂的计算(例如矩阵乘法和闪存注意力)。
Torch.compile 自动为 RMSNorm、RoPE、SiLU 和点乘生成 Triton 内核。使用 Nsight Systems 等工具,可以观察到这些生成的内核,它们在矩阵乘法和注意力之间表现为微小的深绿色内核。
图 3. 使用 torch.compile 跟踪 Llama3-8B,显示用于矩阵乘法和闪存注意力的 CUDA 内核。
对于上面的跟踪,PyTorch 团队注意到,在 Llama3-8B 样式模型中,占 E2E 延迟 80% 的两个主要操作是矩阵乘法和注意力内核,并且两者仍然是 CUDA 内核。因此,为了弥补剩余的差距,PyTorch 团队用手写的 Triton 内核替换了 matmul 和注意力内核。
Triton SplitK GEMM 内核
对于线性层中的矩阵乘法,PyTorch 团队编写了一个自定义 FP16 Triton GEMM(通用矩阵 - 矩阵乘法)内核,该内核利用了 SplitK 工作分解。
GEMM 内核调优
为了实现最佳性能,PyTorch 团队使用穷举搜索方法来调整 SplitK GEMM 内核。Granite-8B 和 Llama3-8B 具有如下形状的线性层:
图 4. Granite-8B 和 Llama3-8B 线性层权重矩阵形状。
每个线性层都有不同的权重矩阵形状。因此,为了获得最佳性能,必须针对每个形状轮廓调整 Triton 内核。在对每个线性层进行调整后,PyTorch 能够在 Llama3-8B 和 Granite-8B 上实现相对于未调整的 Triton 内核 1.20 倍的 E2E 加速。
Flash Attention 内核
PyTorch 团队使用不同的配置,对现有 Triton flash attention 内核进行了评估,包括
- AMD Flash
- OpenAI Flash
- Dao AI Lab Flash
- XFormers Flash
- PyTorch FlexAttention
PyTorch 团队分别在 eager 模式和编译模式下评估了每个内核的文本生成质量。下图 5 为不同 Flash Attention 内核的比较。
上图总结了 PyTorch 观察到的开箱即用情况,并预计内核 2 到 5 可以在修改后满足上述标准。不过这也表明,拥有一个可用于基准测试的内核通常只是将它用作端到端生产内核的开始。
PyTorch 团队选择在后续测试中使用 AMD flash attention 内核,它通过 torch.compile 进行编译,并在 eager 和编译模式下产生清晰的输出。
为了满足 torch.compile 与 AMD flash attention 内核的兼容性,PyTorch 团队必须将它定义为 torch 自定义算子。并且封装更复杂的 flash attention 内核遵循以下两个步骤:
一是将函数封装为一个 PyTorch 自定义算子。
二是向该算子添加一个 FakeTensor 内核,并在给定 flash 输入张量的形状(q、k 和 v)时,计算 flash 内核的输出形状。
在将 Triton flash 内核定义为一个自定义 op 后,PyTorch 团队可以成功地对它进行编译以实现端到端运行。
图 6:在交换 Triton matmul 和 Triton flash attention 内核后,使用 torch.compile 的 Llama3-8B 轨迹。
从图中可以看到,在集成 SplitK 矩阵乘法内核后,torch op 封装 flash attention 内核,然后运行 torch.compile,即可实现使用 100% Triton 计算内核的前向传递。
端到端基准测试
PyTorch 团队分别对运行 Granite-8B 和 Llama3-8B 模型的英伟达 H100 和 A100(单 GPU)进行了端到端测试,使用了两种不同的配置来执行基准测试。
其中 Triton 内核配置使用了:
- Triton SplitK GEMM
- AMD Triton Flash Attention
CUDA 内核配置使用了
- cuBLAS GEMM
- cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)
在典型推理设置下,两种 eager 和 torch 编译模式的吞吐量和 inter-token 延迟如下图所示。
图 7:H100 和 A100 上 Granite-8B 和 Llama3-8B 单 token 生成延迟(批大小 = 2,输入序列长度 = 512,输出序列长度 = 256)。
总的来说,在 H100 上,Triton 模型最高可以达到 CUDA 模型性能的 78%;在 A100 上可以达到 82%。这些性能差距是由 matmul 和 flash attention 的内核延迟造成的。
微基准测试
下图 8 为 Triton 和 CUDA 内核延迟比较(英伟达 H100 上运行 Llama3-8B)。输入为一个任意 prompt(批大小 = 1,prompt 序列长度 = 44),以解码延迟时间。
最后结果显示,Triton matmul 内核比 CUDA 慢了 1.2 至 1.4 倍,而 AMD Triton Flash Attention 比 CUDA SDPA 慢了 1.6 倍。
以上结果凸显了需要进一步提升 GEMM 和 Flash Attention 等核心原语内核的性能。最近的一些工作(如 FlashAttention-3、FlexAttention) 已经提出了更好地利用底层硬件和 Triton 的方法,PyTorch 希望在它们的基础上实现更大加速。为了阐明这一点,PyTorch 团队将 FlexAttention 与 SDPA、AMD’s Triton Flash 内核进行了比较。
PyTorch 团队 正努力验证 FlexAttention 的端到端性能。目前,FlexAttention 的初始微基准测试结果表明,在查询向量较小的情况下,有望实现更长的上下文以及解码问题形状。
图 9:英伟达 H100 SXM5 80GB 上 FlexAttention 内核基准测试(批大小 = 1,最大头数 = 32,头维数 = 128)。
未来工作
未来,PyTorch 团队计划探索进一步优化 matmuls 的方法,以便更好地利用硬件,并为基于 Triton 的方法实现更大的加速。
对于 flash attention,PyTorch 团队计划探索 FlexAttention 和 FlashAttention-3 等内核中使用到的技术,以帮助进一步缩小 Triton 与 CUDA 之间的差距。同时还将探索端到端 FP8 LLM 推理。
原文链接:
https://pytorch.org/blog/cuda-free-inference-for-llms/
相关推荐
- 网站建设:从新手到高手
-
现代化网站应用领域非常广泛,从个人形象网站展示、企业商业网站运作、到政府公益等服务网站,各行各业都需要网站建设。大体上可以归结四类:宣传型网站设计、产品型网站制作、电子商务型网站建设、定制型功能网站开...
- JetBrains 推出全新 AI 编程工具 Junie,助力高效开发
-
JetBrains宣布推出名为Junie的全新AI编程工具。这款工具不仅能执行简单的代码生成与检查任务,还能应对编写测试、验证结果等复杂项目,为开发者提供全方位支持。根据SWEBench...
- AI也能写代码!代码生成、代码补全、注释生成、代码翻译轻松搞定
-
清华GLM技术团队打造的多语言代码生成模型CodeGeeX近期更新了新的开源版本「CodeGeeX2-6B」。CodeGeeX2是多语言代码生成模型CodeGeeX的第二代模型,不同于一代CodeG...
- 一键生成前后端代码,一个36k星的企业级低代码平台
-
「企业级低代码平台」前后端分离架构SpringBoot2.x,SpringCloud,AntDesign&Vue,Mybatis,Shiro,JWT。强大的代码生成器让前后端代码一键生成,无需写任...
- Gitee 代码托管实战指南:5 步完成本地项目云端同步(附避坑要点)
-
核心流程拆解:远程仓库的搭建登录Gitee官网(注册账号比较简单,大家自行操作),点击“新建仓库”,建议勾选“初始化仓库”和“设置模板文件”(如.gitignore),避免上传临时文件。...
- jeecg-boot 源码项目-强烈推荐使用
-
JEECGBOOT低代码开发平台...
- JetBrains推出全新AI编程工具Junie,强调以开发者为中心
-
IT之家2月1日消息,JetBrains发文,宣布推出一款名为Junie的全新AI编程工具,官方声称这款AI工具既能执行简单的代码生成与检查等基础任务,也能应对“编写测试、验证结...
- JetBrains旗下WebStorm和Rider现已加入“非商用免费”阵营
-
IT之家10月25日消息,软件开发商JetBrains今日宣布,旗下WebStorm(JavaScript开发工具)和Rider(.NET开发工具)现已加入“非商用免费”阵营。如果...
- 谈谈websocket跨域
-
了解websocketwebsocket是HTML5的新特性,在客户端和服务端提供了一个基于TCP连接的双向通道。...
- websocket调试工具
-
...
- 利用webSocket实现消息的实时推送
-
1.什么是webSocketwebSocket实现实现推送消息WebSocket是HTML5开始提供的一种在单个TCP连接上进行全双工通讯的协议。以前的推送技术使用Ajax轮询,浏览器需...
- 为 Go 开发的 WebSocket 库
-
#记录我的2024#...
- 「Java基础」Springboot+Websocket的实现后端数据实时推送
-
这篇文章主要就是实现这个功能,只演示一个基本的案例。使用的是websocket技术。...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- mybatis plus (70)
- scheduledtask (71)
- css滚动条 (60)
- java学生成绩管理系统 (59)
- 结构体数组 (69)
- databasemetadata (64)
- javastatic (68)
- jsp实用教程 (53)
- fontawesome (57)
- widget开发 (57)
- vb net教程 (62)
- hibernate 教程 (63)
- case语句 (57)
- svn连接 (74)
- directoryindex (69)
- session timeout (58)
- textbox换行 (67)
- extension_dir (64)
- linearlayout (58)
- vba高级教程 (75)
- iframe用法 (58)
- sqlparameter (59)
- trim函数 (59)
- flex布局 (63)
- contextloaderlistener (56)