算子融合

算子融合:让 AI 计算少跑”冤枉路”

在深度学习的世界里,一个看似简单的操作,背后可能是无数次的数据搬运。算子融合(Kernel Fusion)技术就像一位聪明的快递员,把原本需要多次往返的包裹合并成一趟送完,大幅提升了 AI 计算的效率。

什么是算子(Kernel)?

在 GPU 计算中,算子(也叫 Kernel)是运行在 GPU 上的一个计算函数。常见的算子包括:

  • 矩阵乘法(MatMul)
  • 激活函数(ReLU、GELU)
  • 归一化(LayerNorm、BatchNorm)
  • 逐元素运算(加法、乘法)

每个算子执行时,通常需要:

  1. 从显存(GPU 内存)读取数据
  2. 执行计算
  3. 把结果写回显存

问题:反复读写显存太慢了!

假设我们要计算:y = ReLU(x + bias)

不融合的做法:

1
2
3
4
5
6
7
8
9
算子1: Add
- 从显存读取 x 和 bias
- 计算 x + bias
- 把结果写回显存(临时存储 temp)

算子2: ReLU
- 从显存读取 temp
- 计算 ReLU(temp)
- 把结果写回显存(最终结果 y)

问题在哪?

  • 中间结果 temp 被写入显存,紧接着又被读出来
  • 显存访问是 GPU 计算的最大瓶颈
  • 这次”往返”完全是浪费时间!

算子融合的解决方案

融合后的做法:

1
2
3
4
5
融合算子: Add_ReLU
- 从显存读取 x 和 bias
- 计算 x + bias
- 直接在寄存器/共享内存中计算 ReLU
- 把最终结果 y 写回显存

效果:

  • 省掉了中间结果的读写
  • 显存访问次数减半
  • 速度大幅提升!

形象类比

不融合: 你要寄三个包裹去三个地址。

  • 去邮局寄第一个,回家
  • 再去邮局寄第二个,回家
  • 再去邮局寄第三个

融合: 你带着三个包裹一趟全寄完。

省的就是那些”回家再出门”的时间。在 GPU 世界,”去邮局”就是访问显存,这个时间成本非常高!

常见的融合模式

1. 激活函数融合

最常见的融合是把激活函数合并到前一个算子:

1
2
MatMul + ReLU → MatMul_ReLU
Conv + BatchNorm + ReLU → Conv_BN_ReLU

2. 归一化融合

1
2
LayerNorm = Mean + Variance + Normalize + Scale + Shift
→ 融合成一个 FusedLayerNorm 算子

3. 注意力机制融合

Flash Attention 就是经典的算子融合案例:

1
2
3
4
5
6
传统 Attention:
Q × K^T → Softmax → × V
(每步都读写显存)

Flash Attention:
融合成一个算子,中间结果保留在 SRAM 中

4. 逐元素操作链融合

1
2
x → Add(bias) → Multiply(scale) → Tanh → Dropout
→ 融合成一个算子

技术实现

手动融合

在 CUDA 中手动实现融合算子:

1
2
3
4
5
6
7
8
__global__ void fused_add_relu(float* x, float* bias, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// 一次读取,一次计算,一次写入
float val = x[idx] + bias[idx];
y[idx] = val > 0 ? val : 0; // ReLU
}
}

自动融合

现代深度学习编译器可以自动进行算子融合:

工具/框架 融合能力
TensorRT 自动融合常见模式
XLA TensorFlow 的编译器
TorchScript/torch.compile PyTorch 的 JIT 编译
Triton 用户友好的自定义 Kernel
ONNX Runtime 图优化器自动融合

PyTorch 示例:

1
2
3
4
5
6
7
8
9
10
11
import torch

## 定义模型
class MyModel(torch.nn.Module):
def forward(self, x, bias):
return torch.relu(x + bias)

model = MyModel()

## 使用 torch.compile 自动优化(包括算子融合)
optimized_model = torch.compile(model)

融合的收益分析

以 Transformer 中的 FFN 层为例:

1
2
3
4
5
传统实现:
Linear1 → 写显存 → 读显存 → GELU → 写显存 → 读显存 → Linear2

融合实现:
Linear1 → GELU → Linear2 (中间结果留在快速存储中)

性能对比:

指标 未融合 融合后
显存访问次数 6次 2次
内存带宽占用
计算效率 30-50% 70-90%
延迟 基准 降低 30-50%

融合的限制

并非所有算子都能融合,需要满足一些条件:

1. 数据依赖

只有串行依赖的算子才能融合。如果两个算子是并行关系,融合反而可能降低效率。

2. 计算/访存比

融合的收益取决于节省的访存时间 vs 增加的复杂度

  • 如果计算量很大、访存很少,融合收益不明显
  • 如果访存多、计算少(如逐元素操作),融合收益很大

3. 共享内存限制

融合后的算子可能需要更多共享内存和寄存器,可能超过 GPU 限制。

实际应用案例

Flash Attention

1
2
3
4
5
6
7
8
## 传统注意力 - 多个独立算子
attn_weights = torch.matmul(Q, K.transpose(-2, -1))
attn_weights = torch.softmax(attn_weights / math.sqrt(d_k), dim=-1)
output = torch.matmul(attn_weights, V)

## Flash Attention - 融合算子
from flash_attn import flash_attn_func
output = flash_attn_func(Q, K, V) # 一个融合 Kernel

效果: 速度提升 2-4 倍,显存降低 5-20 倍。

TensorRT 自动融合

1
2
3
4
5
import tensorrt as trt

## TensorRT 会自动识别并融合:
## Conv + BatchNorm + ReLU → CBR融合算子
## MatMul + Add + GELU → 融合算子

如何判断是否需要融合?

使用性能分析工具:

1
2
3
4
5
6
7
8
## 查看 Kernel 调用情况
nsys profile python my_model.py

## 如果看到大量小 Kernel 连续调用,考虑融合
## Kernel: add_kernel (0.1ms)
## Kernel: relu_kernel (0.1ms)
## Kernel: mul_kernel (0.1ms)
## → 可以融合成一个 Kernel

总结

算子融合是深度学习性能优化的核心技术之一。它的本质是减少显存访问,通过把多个小操作合并成一个大操作,避免中间结果的反复读写。

关键要点:

  1. 显存访问是瓶颈: 计算快,读写慢
  2. 融合减少往返: 数据留在快速存储中处理
  3. 自动融合工具: TensorRT、torch.compile 等
  4. 经典案例: Flash Attention

理解算子融合,你就理解了 GPU 加速的核心奥秘之一。

Kernel Fusion: Making AI Computing Take Fewer “Unnecessary Trips”

In the world of deep learning, a seemingly simple operation might involve countless data transfers behind the scenes. Kernel Fusion technology is like a smart delivery driver who combines packages that would require multiple trips into one delivery, dramatically improving AI computing efficiency.

What is a Kernel?

In GPU computing, a Kernel is a computational function that runs on the GPU. Common kernels include:

  • Matrix multiplication (MatMul)
  • Activation functions (ReLU, GELU)
  • Normalization (LayerNorm, BatchNorm)
  • Element-wise operations (addition, multiplication)

When each kernel executes, it typically needs to:

  1. Read data from GPU memory (VRAM)
  2. Perform computation
  3. Write results back to GPU memory

The Problem: Repeated Memory Access is Too Slow!

Suppose we need to compute: y = ReLU(x + bias)

Without fusion:

1
2
3
4
5
6
7
8
9
Kernel 1: Add
- Read x and bias from GPU memory
- Compute x + bias
- Write result to GPU memory (temporary storage: temp)

Kernel 2: ReLU
- Read temp from GPU memory
- Compute ReLU(temp)
- Write result to GPU memory (final result: y)

Where’s the problem?

  • Intermediate result temp is written to GPU memory, then immediately read back
  • Memory access is the biggest bottleneck in GPU computing
  • This “round trip” is a complete waste of time!

Kernel Fusion’s Solution

After fusion:

1
2
3
4
5
Fused Kernel: Add_ReLU
- Read x and bias from GPU memory
- Compute x + bias
- Compute ReLU directly in registers/shared memory
- Write final result y to GPU memory

Effect:

  • Eliminated intermediate result read/write
  • Memory accesses cut in half
  • Significant speed improvement!

An Analogy

Without fusion: You need to mail three packages to three addresses.

  • Go to post office to send first one, go home
  • Go to post office again to send second one, go home
  • Go to post office again to send third one

With fusion: You bring all three packages and mail them in one trip.

What you save is all that “going home and going out again” time. In the GPU world, “going to the post office” is accessing GPU memory—this time cost is very high!

Common Fusion Patterns

1. Activation Function Fusion

The most common fusion is merging activation functions into the preceding operator:

1
2
MatMul + ReLU → MatMul_ReLU
Conv + BatchNorm + ReLU → Conv_BN_ReLU

2. Normalization Fusion

1
2
LayerNorm = Mean + Variance + Normalize + Scale + Shift
→ Fused into one FusedLayerNorm kernel

3. Attention Mechanism Fusion

Flash Attention is a classic kernel fusion case:

1
2
3
4
5
6
Traditional Attention:
Q × K^T → Softmax → × V
(Each step reads/writes GPU memory)

Flash Attention:
Fused into one kernel, intermediate results stay in SRAM

4. Element-wise Operation Chain Fusion

1
2
x → Add(bias) → Multiply(scale) → Tanh → Dropout
→ Fused into one kernel

Technical Implementation

Manual Fusion

Manually implementing a fused kernel in CUDA:

1
2
3
4
5
6
7
8
__global__ void fused_add_relu(float* x, float* bias, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) {
// One read, one compute, one write
float val = x[idx] + bias[idx];
y[idx] = val > 0 ? val : 0; // ReLU
}
}

Automatic Fusion

Modern deep learning compilers can automatically perform kernel fusion:

Tool/Framework Fusion Capability
TensorRT Automatic fusion of common patterns
XLA TensorFlow’s compiler
TorchScript/torch.compile PyTorch’s JIT compilation
Triton User-friendly custom kernels
ONNX Runtime Graph optimizer automatic fusion

PyTorch example:

1
2
3
4
5
6
7
8
9
10
11
import torch

## Define model
class MyModel(torch.nn.Module):
def forward(self, x, bias):
return torch.relu(x + bias)

model = MyModel()

## Use torch.compile for automatic optimization (including kernel fusion)
optimized_model = torch.compile(model)

Fusion Benefit Analysis

Using Transformer’s FFN layer as an example:

1
2
3
4
5
Traditional implementation:
Linear1 → write memory → read memory → GELU → write memory → read memory → Linear2

Fused implementation:
Linear1 → GELU → Linear2 (intermediate results stay in fast storage)

Performance comparison:

Metric Unfused Fused
Memory accesses 6 times 2 times
Memory bandwidth usage High Low
Compute efficiency 30-50% 70-90%
Latency Baseline 30-50% lower

Fusion Limitations

Not all operators can be fused; some conditions must be met:

1. Data Dependencies

Only serially dependent operators can be fused. If two operators are parallel, fusion might actually reduce efficiency.

2. Compute-to-Memory Ratio

Fusion benefits depend on saved memory access time vs. added complexity:

  • If computation is heavy and memory access is light, fusion benefits are minimal
  • If memory access is heavy and computation is light (like element-wise ops), fusion benefits are significant

3. Shared Memory Limits

Fused kernels might need more shared memory and registers, potentially exceeding GPU limits.

Practical Application Cases

Flash Attention

1
2
3
4
5
6
7
8
## Traditional attention - multiple separate kernels
attn_weights = torch.matmul(Q, K.transpose(-2, -1))
attn_weights = torch.softmax(attn_weights / math.sqrt(d_k), dim=-1)
output = torch.matmul(attn_weights, V)

## Flash Attention - fused kernel
from flash_attn import flash_attn_func
output = flash_attn_func(Q, K, V) # One fused kernel

Effect: 2-4x speed improvement, 5-20x memory reduction.

TensorRT Automatic Fusion

1
2
3
4
5
import tensorrt as trt

## TensorRT automatically identifies and fuses:
## Conv + BatchNorm + ReLU → CBR fused kernel
## MatMul + Add + GELU → Fused kernel

How to Determine if Fusion is Needed?

Use performance analysis tools:

1
2
3
4
5
6
7
8
## View kernel call patterns
nsys profile python my_model.py

## If you see many small kernels called consecutively, consider fusion
## Kernel: add_kernel (0.1ms)
## Kernel: relu_kernel (0.1ms)
## Kernel: mul_kernel (0.1ms)
## → Can be fused into one kernel

Summary

Kernel fusion is one of the core technologies in deep learning performance optimization. Its essence is reducing memory access by combining multiple small operations into one large operation, avoiding repeated read/write of intermediate results.

Key points:

  1. Memory access is the bottleneck: Computing is fast, reading/writing is slow
  2. Fusion reduces round trips: Data stays in fast storage for processing
  3. Automatic fusion tools: TensorRT, torch.compile, etc.
  4. Classic example: Flash Attention

Understand kernel fusion, and you understand one of the core secrets of GPU acceleration.