Triton编译器在RISC-V上的移植与适配实践

引言

在上一篇系列文章《Triton for RISC-V:AI 算子开发新机遇》中,我们围绕 Triton 与 RISC-V 这一技术组合展开了较为系统的讨论,重点回答了两个核心问题:其一,为什么有必要在 RISC-V 平台上支持 Triton;其二,这一方向在现有技术条件下是否具备现实可行性。

随着生成式 AI 和大模型的快速发展,算子实现与优化在整体性能和开发效率中的重要性愈发凸显。Triton 作为一种面向算子开发的编程与编译体系,通过更贴近硬件执行模型的抽象方式,有效降低了高性能算子开发的门槛。与此同时,RISC-V 架构在开放性、可扩展性以及面向特定应用场景的定制能力方面展现出的优势,也为 AI 计算基础设施提供了更大的设计自由度。二者的结合,为应对当前异构计算体系日益复杂、算子需求高度定制化的趋势,提供了一条具有潜力的技术路径。

在可行性分析方面,从编程模型的角度来看,Triton 所采用的抽象方式具备向 RISC-V 平台延展和适配的基础条件;在编译框架层面,业界已经形成了较为完整的端到端 Triton 到 CPU 的实现路径,同时 LLVM 也提供了成熟的 RISC-V 后端支持。这使得在实际工程中,无需从零开始构建完整的编译体系,而是可以在较大程度上复用已有技术成果,为 Triton 在 RISC-V 平台上的落地提供切实可行的基础。

基于这样的背景,本文将以 Triton 编译器为核心,围绕其在 RISC-V 平台上的移植与适配过程,对相关技术路径和关键环节进行一次相对完整的梳理,帮助读者更清晰地理解 Triton 编译器在 RISC-V 平台落地过程中所面临的实际工程复杂度,并为后续更深入的优化与实践讨论奠定基础。

Triton 在 RISC-V 平台的运行方案

在尝试将 Triton 编译器引入 RISC-V 平台时,首要面对的问题是Triton 的整体编译与运行模式应当如何在这一新架构上落地。

Triton 最初主要面向以 x86 为核心的 host + device 异构计算生态进行设计,其运行模型默认宿主侧具备成熟的 Python 运行环境、完整的 LLVM 编译基础设施,以及相对稳定的系统软件栈。而在当前的 RISC-V 生态中,硬件形态呈现出高度多样化的特点:既包括面向嵌入式和 SoC 场景的低功耗处理器,也逐步出现定位服务器或加速计算的高性能实现,不同平台在算力规模、内存体系以及系统完整性等方面差异显著。在软件生态层面,RISC-V 的 Python 运行环境和编译工具链虽然近年来发展迅速,基础能力已逐步完善,但整体成熟度、生态完整性以及与上层框架的协同程度,仍与 x86 平台存在一定差距。

从工程实践的角度出发,在综合考虑实现复杂度、平台资源条件以及实际可落地性之后,目前可以将 “在 RISC-V 上运行 Triton” 的总体思路概括为两种方向:一种是将编译过程前移,在功能完备的平台上完成 Triton kernel 的生成,RISC-V 侧仅负责加载和执行;另一种则是尽量保留 Triton 原有的 JIT 使用范式,在 RISC-V 设备上直接完成 kernel 的动态编译与运行。

Offline 静态编译方案

Offline 静态编译方案的核心思想,是将 Triton 的编译阶段完全放在 x86 等成熟平台上完成。在这种模式下,开发者依然在熟悉的 x86 环境中编写、调试 Triton kernel,并借助交叉编译工具链,将其编译为面向 RISC-V 架构的目标文件或可执行二进制。最终,RISC-V 设备只承担运行时执行的职责,而不直接参与 Triton kernel 的生成过程。

这一方案最大的优势,在于对目标平台运行环境的要求非常低。RISC-V 侧无需部署完整的 Python 运行时,也不必引入 LLVM 等重量级编译基础设施,其系统形态更接近传统的“预编译 + 直接执行”模式。这使得 Offline 方案在嵌入式系统、专用加速器 SoC 或资源受限的平台上尤为具有吸引力,不仅可以显著降低系统复杂度,也有助于提升整体工程的稳定性和可控性。

但与此同时,这种设计也带来了新的工程挑战。Triton 的 kernel launcher 以及部分运行时逻辑,通常与宿主环境高度耦合,在 RISC-V 平台上复用或重构这些组件,往往需要额外的适配工作。更重要的是,由于编译环境与执行环境相互分离,问题定位、调试效率以及 Triton 原生支持的 autotuning 能力都会受到限制,整体开发体验不可避免地有所下降。

JIT 动态编译方案

与 Offline 静态编译方案相对,JIT 动态编译方案选择尽可能贴近 Triton 的原生使用方式。在这一路径下,Triton 的 Python 前端、编译器以及运行时被完整部署在 RISC-V 设备上,kernel 会在实际运行过程中,根据输入形状和配置参数动态生成、即时编译并执行。这种模式在使用逻辑上,与当前主流的 Triton 工作流保持高度一致。

JIT 方案最直观的优势体现在开发与调试体验上。由于编译和执行发生在同一平台,开发者可以像在 x86 环境中一样快速迭代 kernel 实现,进行参数调优和性能分析。这种一致性不仅降低了使用 Triton 的心智负担,也更有利于充分发挥其在自动调优和动态 specialization 方面的设计优势。

当然,这种灵活性是以更高的平台门槛为代价的。要在 RISC-V 上完整运行 Triton 的 JIT 体系,意味着需要具备较为完善的 Python 生态以及功能健全的 LLVM 编译工具链,这对当前多数 RISC-V 平台而言仍然是一项不小的挑战。同时,JIT 编译本身也会引入额外的运行时开销,对内存容量、存储性能以及计算资源都提出了更高要求。因此,该方案更适合资源相对充裕、定位偏向服务器或高性能计算场景的 RISC-V 系统。

两种方案的取舍与适用场景

从整体工程视角来看,Offline 静态编译与 JIT 动态编译,分别代表了将复杂性前移和将复杂性留在运行时的两种不同取舍。前者将主要工程成本集中在开发和构建阶段,以换取目标平台上的轻量化和稳定性,更适合嵌入式系统或高度定制化的 SoC 场景;后者则通过在运行时引入更高的系统复杂度,换取更强的灵活性和更友好的开发体验,更契合高端 RISC-V 服务器或研究型平台的需求。

因此,在实际工程中,这两种方案并不存在绝对意义上的优劣之分。关键在于目标平台的资源条件、项目周期,以及对灵活性和稳定性各自的侧重程度。理解这两种方案的取舍,也为后续具体实现提供了明确方向:

针对此次实践的RISC-V 目标平台,结合上述工程考量和平台限制,更可行的方案是采用 Offline 静态编译方案。具体流程如下:

Triton Kernel → Triton IR → LLVM IR → RISC-V 汇编 → ELF → 板端执行

在该流程中,Triton 前端生成 IR 后,通过 lowering 转化为 LLVM IR,再由 LLVM 后端编译为 RISC-V 汇编与可执行文件。最终生成的 ELF 可直接在 RISC-V 板端执行,无需依赖 Python 或 LLVM。此方案降低了系统复杂度,适用于嵌入式系统或专用加速平台,并有助于提高系统稳定性与调试效率。

在选择适合的方案后,下一步是配置支持 RISC-V 架构的交叉编译工具链,并确保所有依赖工具齐全,以顺利实现 Offline 静态编译流程,接下来将详细介绍如何使用 ZCC 编译器完成这一配置。

环境配置

在 RISC-V 平台进行 Triton 的静态编译,首先需要准备一套可用的 RISC-V 交叉编译环境。本文以 ZCC(Terapines 提供的基于 LLVM 的编译器)为例。在此基础上,需要确保 CMake、Python 等依赖工具可用,确保整个编译过程的顺利进行。

本文选用 ZCC 作为 RISC-V 交叉编译器,主要原因是:ZCC 基于 LLVM,天然兼容 Triton 生成的 LLVM IR,并且已具备较为成熟的 RVV 自动向量化优化能力。

以 ZCC 为例,需要确保版本支持 RISC-V 后端。交叉编译参数如下:

-mllvm --riscv-disable-rvv-fixedlen=false
-mrvv-vector-bits=128
--target=riscv64-unknown-linux-gnu
-O3
-march=rv64gcv
-mabi=lp64d

rv64gcv 启用 RVV 向量扩展

-mrvv-vector-bits=128 显式指定向量长度

关闭 fixed-length RVV,有利于后端进行更灵活的向量化决策。

后端插件的注册与构建

Python 构建入口(setup.py)

Triton 后端的选择始于 Python 层。当用户执行:

pip install .

时,setup.py 会显式指定需要参与构建的后端列表:

backends = [*BackendInstaller.copy(["nvidia", "amd", "riscv"]), *BackendInstaller.copy_externals()]

至此,Triton 在逻辑层面已经“知道”存在一个名为 riscv 的后端。 这一步的关键点在于:后端列表在构建阶段即被确定,后续所有 C++ 编译、宏定义、符号导出,都会围绕这个列表展开。

后端代码目录结构

接下来需要在 third_party 目录下补齐对应的后端实现:

(triton) ➜  third_party git:(feat/support-new-backend) ✗ tree -L 1
.
├── amd
├── nvidia
└── riscv

RISC-V 后端的基本结构如下:

(triton) ➜  riscv git:(feat/support-new-backend) ✗ tree -l
.
├── backend
│   ├── compiler.py
│   └── driver.py
├── CMakeLists.txt
└── python
    └── triton_riscv.cc

其中:

  • compiler.py:定义编译 pipeline(IR 转换、优化阶段)

  • driver.py:负责 Kernel Launcher 与运行时接口生成

  • triton_riscv.cc:Python 扩展入口(当前为最小实现)

到这里,一个 可被 Triton 识别的后端框架 就已经搭建完成。

triton_riscv.cc

当前阶段,RISC-V 后端不需要额外的 Python API,仅提供空初始化函数即可:

void init_triton_riscv(py::module &&m) {}

CMake编译命令生成

setup.py 中通过 setuptools 将扩展构建流程交给 CMakeBuild

"build_ext": CMakeBuild,

当构建开始时:

  1. setuptools 调用 build_ext

  2. CMakeBuild.build_extension() 被触发

  3. 后端列表被转化为 CMake 变量:

"-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external])
-DTRITON_CODEGEN_BACKENDS=nvidia;amd;riscv

最终生成的 cmake 命令中,TRITON_CODEGEN_BACKENDS 明确指定了需要参与编译的后端集合。这一变量后续会被主CMakeLists.txt文件反复使用。

后端目录的统一注册

  foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
    add_subdirectory(third_party/${CODEGEN_BACKEND})
  endforeach()

这会遍历每个后端(包括 RISC-V),并立即进入子目录并执行其中的 CMakeLists.txt,然后再返回主CMakeLists.txt文件。

RISC-V 后端插件的定义与注册

插件的编译和链接

链接的库是用于 Triton IR 的分析与转换的,可以自定义转换的 pass 以及逻辑。

message(STATUS "=== Supporting RISC-V Backend ===")

if(TRITON_BUILD_PYTHON_MODULE)
    add_triton_plugin(
      TritonRiscv
      ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_riscv.cc
      LINK_LIBS
        TritonAnalysis
        TritonStructuredIR
        TritonToLinalg
        ...
        LLVMRISCVCodeGen
    )
  target_link_libraries(TritonRiscv PRIVATE Python3::Module pybind11::headers)
endif()

add_triton_plugin函数

具体实现如下:

set_property(GLOBAL PROPERTY TRITON_PLUGINS "")
function(add_triton_plugin name)
  set_property(GLOBAL APPEND PROPERTY TRITON_PLUGINS ${name})
  add_triton_object(${name} ${ARGN})
endfunction()

add_triton_plugin 做了两件事:

  1. 将插件名称 TritonRiscv 添加到全局属性 TRITON_PLUGINS,后续可通过 get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) 获取所有插件

  2. 将所有参数(${ARGN})传递给 add_triton_object 执行实际构建

add_triton_object函数

具体实现如下:

# Utilities
function(add_triton_object name)
  cmake_parse_arguments(ARG "" "" "DEPENDS;LINK_LIBS" ${ARGN})
  add_library(${name} OBJECT)
  target_sources(${name}
    PRIVATE ${ARG_UNPARSED_ARGUMENTS}
    INTERFACE $<TARGET_OBJECTS:${name}>
  )

  # add_library(${name} OBJECT ${ARG_UNPARSED_ARGUMENTS})
  if(ARG_DEPENDS)
    add_dependencies(${name} ${ARG_DEPENDS})
  endif()
  if(ARG_LINK_LIBS)
    target_link_libraries(${name} PUBLIC ${ARG_LINK_LIBS})
  endif()
endfunction(add_triton_object)

add_library(TritonRiscv OBJECT) 用来创建 OBJECT 库目标

target_sources(TritonRiscv PRIVATE python/triton_riscv.cc …) 将源文件添加到 OBJECT 库

target_link_libraries(TritonRiscv PUBLIC tritonAnalysis …) 链接依赖库

这意味着每个后端插件本质上只是“一组会被最终合并的目标文件”,并不会生成单独的 .so

所有 add_subdirectory 执行完毕后,继续执行主 CMakeLists.txt 的后续代码。

创建主共享库并链接插件

在所有 add_subdirectory() 执行完成后,顶层 CMakeLists.txt 会统一收集所有插件(包括 TritonRiscv):

  get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
  get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
  set(TRITON_LIBRARIES
    ${triton_libs}
    ${triton_plugins}

并将所有插件 OBJECT 库加入主库:

add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
              ${PYTHON_SRC_PATH}/ir.cc
              ${PYTHON_SRC_PATH}/passes.cc
              ${PYTHON_SRC_PATH}/interpreter.cc
              ${PYTHON_SRC_PATH}/llvm.cc)
  target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES})

最终结果是:TritonRiscv.o 文件与 NVIDIA、AMD 后端一起被静态链接进 libtriton.so

Python 运行时的后端加载机制

用户代码导入triton

import triton
import triton.language as tl

triton/init.py导入compiler模块

from .compiler import compile, CompilationError
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict

compiler.py文件中的compiler()函数(Python 导入时)

from .._C.libtriton import get_cache_invalidating_env_vars, ir

Python 会:

  1. 加载共享库 libtriton.so

  2. 对于模块 libtriton,Python 会查找并调用 PyInit_libtriton()

  3. PyInit_libtriton函数是由 PYBIND11_MODULE(libtriton, m)宏展开生成的函数接口

#define PYBIND11_MODULE(name, variable, ...)                                         
PYBIND11_MODULE_PYINIT(name, (pybind11::detail::get_num_interpreters_seen() += 1), ##__VA_ARGS__)
PYBIND11_MODULE_EXEC(name, variable)

PYBIND11_MODULE_EXEC会将宏展开成

#define PYBIND11_PLUGIN_IMPL(name)
    PYBIND11_PLUGIN_DECL(name)
    extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
PYBIND11_MODULE(libtriton, m) {
  m.doc() = "Python bindings to the C++ Triton API";
  init_triton_stacktrace_hook(m);
  init_triton_env_vars(m);
  init_triton_ir(m.def_submodule("ir"));
  init_triton_passes(m.def_submodule("passes"));
  init_triton_interpreter(m.def_submodule("interpreter"));
  init_triton_llvm(m.def_submodule("llvm"));
  FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE)
}

FOR_EACH_P被展开成:

init_triton_nvidia(m.def_submodule("nvidia"));
init_triton_amd(m.def_submodule("amd"));
init_triton_riscv(m.def_submodule("riscv"));

因此,后端需要添加triton_riscv.cc文件并实现 init_triton_riscv 函数:

void init_triton_riscv(py::module &&m) {}

总的来说就是:

  1. Python setup.py 决定构建哪些后端

  2. CMake 决定编译哪些目录

  3. OBJECT library 决定如何被合并

  4. pybind11 决定如何暴露给 Python

Offline 静态编译以及运行

在后端注册完成后,可在 x86 主机完成 RISC-V Offline 编译:

Triton Kernel → Triton IR → LLVM IR → 汇编 → ELF → 板端执行

compiler.py的修改

与 JIT 编译不同,Offline 模式下 不直接生成可执行二进制,而是止步于 LLVM IR。后续的汇编、链接与部署流程全部交由外部交叉编译工具链完成。

编译 stage 设置

stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
stages["coreir"] = lambda src, metadata: _optimize_coreir(_ttir_to_coreir(src))
stages["llir"] = lambda src, metadata: _optimize_llir(_coreir_to_llir(src, metadata))

IR 转换Pipeline

在支持e2e 的 Triton 项目中,目前有两个可以参考的 CPU 后端实现的开源项目Triton- CPUTriton-shared

Triton-CPU

Triton-CPU 以 vector dialect 作为核心中间表示,在 TTIR 向下转换过程中直接使用 vector dialect 承接算子,使向量化语义能够在较高层级的 IR 中被显式表达,从而有利于面向通用 CPU SIMD 架构生成高效代码。然而,相较于 linalg dialect,vector dialect 对张量整体结构和高层计算语义的保留能力有限,部分张量级信息在转换过程中会逐步丢失。这不仅使其在适配 RISC-V 的矩阵扩展或自定义扩展时面临困难,同时由于 vector dialect 的抽象模型与 RISC-V 向量指令集扩展(RVV)的执行语义并不完全匹配,也使得当前难以充分发挥 RVV 在可变向量长度和数据并行方面的架构优势。Lowering 路线图:

TTIR → scalarize → convert_memory_ops → 直接生成 vector.load/memref.load → LLVM IR
  1. 将 tt.load 等访存指令往 vector 方言翻译;

  2. 没有 bufferize 步骤,直接从 pointer 类型信息构建 memref,在 vector dialect 层面显示控制向量化;

相当于在 LLVM IR 层面直接生成向量指令,而不是依赖后端自动向量化。

Triton-shared

Triton-shared 采用以 linalg 与 memref 为核心的中间表示设计,在 TTIR 向下转换过程中分别承接计算与内存操作,从而较好地保留了张量级计算语义和数据访问结构。这种设计具有较强的通用性,也更符合 MLIR 典型的张量编译路径。然而,Triton-shared 最终生成的是基于 LLVM 的标量指令,向量化能力并未在 IR 层显式体现。对于 RISC-V 后端而言,这意味着需要依赖额外的向量化 pass 或 LLVM 的自动向量化机制,矩阵扩展和自定义扩展同样需要专门适配,使其在充分发挥硬件并行能力方面存在一定局限。Lowering路线图:

tt.load → tensor → bufferize → memref → LLVM IR
  • 使用 --one-shot-bufferize 进行 bufferize 分析;

  • 生成标准 LLVM IR,由后端编译器 LLVM 进行优化和 RVV/SIMD 指令生成;

总结:

目前,两种技术方案在整体完整度和性能优化方面仍处于持续迭代之中,相关实现和优化能力有待在实践中不断完善。基于这一现状,并考虑到 ZCC 在 RVV 场景下已经积累了较为成熟的自动向量化与线程级向量优化能力,本文在当前阶段选择以 Triton-shared 风格的 lowering 流程作为主要实现方向。在后续的 Triton 算子优化中,重点放在自动向量化能力的发挥上,其中具体的向量化转换主要由 ZCC 后端完成,而上层则侧重于为自动向量化分析提供必要的结构和语义信息支持。

这一技术路径的选择并不意味着对更高层次硬件语义表达的排斥。结合 RISC-V 平台的长期发展趋势,尤其是矩阵扩展(IME/VME/AME)的逐步引入,编译体系需要在高层语义表达与后端适配之间保持良好的平衡。Triton-shared 所采用的编译路径在强调通用性和可维护性的同时,也为后续能力演进预留了空间。在必要情况下,可以在 MLIR 层级引入更加显式的向量化或矩阵化表示,以更清晰地刻画计算结构,并为后端提供明确的优化边界,从而为未来面向矩阵类硬件特性的优化奠定基础。

driver.py的修改

driver.py 负责生成 Kernel Launcher,针对RISCV后端 offline 模式,主要是多线程启动编译好的 Triton kernel。

实现参考 triton-cpu 后端的 make_launcher,但针对 RISC-V 平台进行了如下调整:

  • 使用 OpenMP 实现多线程并行

  • 每个 program instance 映射到 (x, y, z) 网格

  • 生成 XXX_launcher.h / XXX_launcher.cpp

XXX_launcher.h

声明 Kernel Launcher 接口:

using kernel_ptr_t = void(*)({kernel_arg_decls}int, int, int, int, int, int);

void {kernel_name}_omp(int gridX, int gridY, int gridZ, {kernel_arg_decls}kernel_ptr_t kernel_ptr);

XXX_launcher.cpp

在 RISC-V CPU 上通过 OpenMP 并行执行 Triton 内核

#pragma omp parallel for schedule(static) num_threads(max_threads.value())
for (uint32_t x = 0; x < gridX; ++x) {
    (*kernel_ptr)(..., x, y, z);
}

这一实现完整复现了 Triton 的 program-level 并行语义。

算子示例与 Kernel 编译

以 softmax 算子为例

import torch
import triton
import triton.language as tl
import os
from triton.backends.riscv.driver import CrossDriver

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride

    row_max = -float('inf')
    for off in range(0, n_cols, BLOCK_SIZE):
        col_offsets = off + tl.arange(0, BLOCK_SIZE)
        row = tl.load(row_start_ptr + col_offsets, mask=col_offsets < n_cols, other=-float('inf'))
        row_max = tl.maximum(row_max, tl.max(row, axis=0))

    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    denominator = 0.0
    for off in range(0, n_cols):
        row = tl.load(row_start_ptr + off)
        row_minus_max = row - row_max
        numerator = tl.exp(row_minus_max)
        denominator += numerator
        tl.store(output_row_start_ptr + off, numerator)

    for off in range(0, n_cols, BLOCK_SIZE):
        col_offsets = off + tl.arange(0, BLOCK_SIZE)
        row = tl.load(output_row_start_ptr + col_offsets, mask=col_offsets < n_cols, other=-float('inf'))
        softmax_output = row / denominator
        tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=col_offsets < n_cols)

def softmax(x, y=None):
    n_rows, n_cols = x.shape
    if y is None:
        y = torch.empty_like(x)
    softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_cols, BLOCK_SIZE=32)
    return y

DEVICE = triton.runtime.driver.active.get_active_torch_device()
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton_cpu = softmax(x)

通过 Triton 编译流程,最终得到:

  • softmax_kernel.llir

  • softmax_kernel_launcher.cpp

  • softmax_kernel_launcher.h

主函数的生成

由于 Torch 数据初始化发生在 x86 主机,Offline 场景下需要在 RISC-V 端 自行分配并初始化数据,因此额外编写 main.cpp

#include "softmax_kernel_launcher.h"
#include <cstdlib>
#include <random>

int main(int argc, char *argv[]) {
  int R = 1024;
  int C = 1024;
  int RUN_COUNT = 10;

  float *input = (float *)malloc(R * C * sizeof(float));
  float *real_out = (float *)malloc(R * C * sizeof(float));

  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> norm_dis(0, 1);
  for (int i = 0; i < R; ++i) {
    for (int j = 0; j < C; ++j) {
      input[i * C + j] = norm_dis(gen);
    }
  }

  for (int i = 0; i < RUN_COUNT; i++) {
    softmax_kernel_omp(R, 1, 1, 0, real_out, 0, input, C, C, C,
                       &softmax_kernel);
  }

  free(input);
  free(real_out);
  return 0;
}

主要支持:

  1. malloc 输入 / 输出 buffer

  2. 初始化随机数据

  3. 调用生成的 launcher 接口

汇编生成 object code

命令行如下:

生成汇编

./bin/zcc -fno-lto --target=riscv64-unknown-linux-gnu -march=rv64gcv -mabi=lp64d -O3 -S -x ir ./build/aux/src/softmax/softmax_kernel.llir -fopenmp=libomp -o ./build/aux/src/softmax//softmax_kernel.s -mllvm --riscv-disable-rvv-fixedlen=false -mrvv-vector-bits=128

链接生成elf

 ./bin/zcc -fno-lto --target=riscv64-unknown-linux-gnu -march=rv64gcv -mabi=lp64d -O3 ./softmax_kernel.s ./softmax_kernel_launcher.cpp ./softmax_kernel.cpp -I ./llvm-project/mlir/include/mlir -I ./riscv64-unknown-linux-gnu/include/c++/14.1.1 -I ./riscv64-unknown-linux-gnu/include/c++/14.1.1/riscv64-unknown-linux-gnu/lib64/lp64d -fopenmp=libgomp -L ./llvm-project/build/lib -lmlir_c_runner_utils -lmlir_float16_utils -lm -latomic -lstdc++ -std=c++17 -DCHECK_ACCURACY -fPIC -mllvm --riscv-disable-rvv-fixedlen=false -mrvv-vector-bits=128 -o ./softmax_kernel.elf

最终生成可在 RISC-V 板卡直接运行的 softmax_kernel.elf

Kernel Launcher 启动流程

将 ELF 拷贝至 RISC-V 平台运行后,对比 Kernel 输出与纯 C 实现结果:

out: Printing first 10 values
Index |  Value A |  Value B | Diff
------|----------|----------|-------
    0 | 0.000970 | 0.000970 | 0.000000
    1 | 0.000650 | 0.000650 | 0.000000
    2 | 0.007885 | 0.007885 | 0.000000
    3 | 0.001161 | 0.001161 | 0.000000
    4 | 0.000885 | 0.000885 | 0.000000
    5 | 0.001572 | 0.001572 | 0.000000
    6 | 0.003447 | 0.003447 | 0.000000
    7 | 0.001103 | 0.001103 | 0.000000
    8 | 0.000340 | 0.000340 | 0.000000
    9 | 0.000124 | 0.000124 | 0.000000
... (showing first 10 of 1400064 values)

两者结果完全一致,验证了Triton到RISC-V编译流程的正确性,此时RISC-V 后端基础适配成功完成。

总结

综上所述,围绕 Triton 编译器在 RISC-V 平台上的移植与适配实践,本文重点探讨并验证了以 Offline 静态编译为核心的整体实现路径。通过将 Triton kernel 的编译过程前移至功能完备的平台,并在 RISC-V 侧仅保留执行能力,这一方案为 Triton 在 RISC-V 上构建起了一条相对完整、可控的端到端算子编译与执行流程,也标志着 Triton 在 RISC-V 架构下具备了实际工程落地的基础条件。

从更宏观的生态视角来看,Offline 静态编译方案的成功实施展示了 RISC-V、Triton 以及 MLIR 等技术体系之间良好的协同潜力。随着后续性能优化工作的持续推进,RISC-V 后端将逐步形成一套既易于维护又具备良好扩展能力的算子库及其配套的算子库编译器,为上层应用和系统演进奠定坚实基础。在 RISC-V、Triton 与 MLIR 共同构成的分层协同模式下,不同技术栈能够更加清晰地发挥各自优势。硬件设计得以集中精力推动体系结构与算力能力的持续演进,Triton 则可以专注于高层算子表达能力的提升以及算子库生态的构建,而编译器与中间表示层则围绕 RISC-V 架构特性开展更具针对性的优化工作。这样的分层解耦协作机制,不仅有效降低了整体系统的复杂度,也为后续生态的拓展与演进提供了更加清晰和可持续的发展路径。

展望未来:在完成基本可用性验证之后,性能优化自然成为下一阶段更具挑战性、同时也更具研究价值的方向。在这一过程中,如何充分结合 RISC-V 的微架构特性,对 Triton 生成代码在指令选择、访存模式以及并行执行策略等方面开展系统化、深层次的优化,将成为影响其在实际计算负载中表现的关键因素。围绕这些核心问题,后续系列文章将继续聚焦 Offline 静态编译方案,系统性地展示 Triton 在 RISC-V 平台上的实际性能表现及其演进过程。