CUDA 编写 softmax 算子

2026-01-17

LeetGPU 题目链接

Shared Memory + Tree Reduction 算子实现思路

我的实现大致分 $3$ 步:

  1. 对原数组进行 block tiling,在每一个 block 上,用 Tree-Reduction 计算其负责的 $max_i, sum_i$,这里的 ${sum}_i$ 指的是这个 block 内的 safe softmax 之和,即

    $$sum_i=\sum_{k=l_i}^{r_i}e^{x_k-max_i}$$

    把每个 block 的局部结果写回 global memory(tmax, tsum 数组)里 —— shared memory 不能跨 kernel 存活,归约过程本身在 shared memory 中进行

  2. 再用 $1$ 个 block 对保存在 global memory 上的 tmax, tsum 数组进行归约,求出全局的最大值和 sum of safe softmax

  3. 利用计算好的 global max 和 sum,对 input 进行 transform.

safe_merge() 辅助函数

由于数组的长度并不一定是 $2$ 的整数次幂,所以在用 shared memory 进行 tree reduction 的时候,可能会出现元素个数无法完美填满 $2^k$ 个槽位的情况。这个时候,我们需要在填不满的 shared memory array 里填 -INFINITY 作为占位,因此需要编写特殊的逻辑来处理 -INFINITY 参与的 safe softmax 合并。

这里的设计是,cmax, nmax 表示两个 $max$,而 csum, nsum 表示两个 sum of safe softmax. 这里,safe softmax 的更新思路是 online softmax.

// Merge two (max, sum-of-safe-softmax) partials with the online-softmax
// recurrence: (cmax, csum) and (nmax, nsum) in, merged result out via
// (cm, cs).  Inputs are taken by value so the outputs may alias the
// inputs (call sites pass e.g. safe_merge(mx, nmax, sm, nsum, mx, sm)).
// NINF marks an empty partial (padding slot); its sum must be 0.
__device__ void safe_merge(float cmax,
                           float nmax,
                           float csum,
                           float nsum,
                           float &cm,
                           float &cs) {
  if (nmax == NINF) {
    // Right side is empty padding: keep the left partial.
    cm = cmax, cs = csum;
  } else if (cmax == NINF) {
    // Left side is empty: take the right partial as-is.
    cm = nmax, cs = nsum;
  } else if (cmax > nmax) {
    // Rescale the right sum into the left's reference max.
    // Fix: compute from the value parameter csum (the original read the
    // output reference cs and never wrote cm, which only worked when the
    // caller aliased outputs onto inputs).
    cm = cmax;
    cs = csum + nsum * expf(nmax - cmax);
  } else {
    // Rescale the left sum into the right's reference max.
    cm = nmax;
    cs = nsum + csum * expf(cmax - nmax);
  }
}

整体代码

#include <cuda_runtime.h>
#define NINF -INFINITY

// Merge two (max, sum-of-safe-softmax) partials with the online-softmax
// recurrence: (cmax, csum) and (nmax, nsum) in, merged result out via
// (cm, cs).  Inputs are taken by value so the outputs may alias the
// inputs (call sites pass e.g. safe_merge(mx, nmax, sm, nsum, mx, sm)).
// NINF marks an empty partial (padding slot); its sum must be 0.
__device__ void safe_merge(float cmax,
                           float nmax,
                           float csum,
                           float nsum,
                           float &cm,
                           float &cs) {
  if (nmax == NINF) {
    // Right side is empty padding: keep the left partial.
    cm = cmax, cs = csum;
  } else if (cmax == NINF) {
    // Left side is empty: take the right partial as-is.
    cm = nmax, cs = nsum;
  } else if (cmax > nmax) {
    // Rescale the right sum into the left's reference max.
    // Fix: compute from the value parameter csum (the original read the
    // output reference cs and never wrote cm, which only worked when the
    // caller aliased outputs onto inputs).
    cm = cmax;
    cs = csum + nsum * expf(nmax - cmax);
  } else {
    // Rescale the left sum into the right's reference max.
    cm = nmax;
    cs = nsum + csum * expf(cmax - nmax);
  }
}

// Pass 1: each block computes the max and the sum of exp(x - max) over
// the elements it owns (grid-stride loop), then tree-reduces the
// per-thread partials in shared memory.
//
// Launch requirements:
//   - dynamic shared memory = 2 * blockDim.x * sizeof(float)
//     (first half: per-thread maxes, second half: per-thread sums)
//   - blockDim.x must be a power of two (tree reduction)
// Per-block results are written to tmax[blockIdx.x] / tsum[blockIdx.x].
__global__ void reduce_max(const float *__restrict__ input,
                           float *__restrict__ tmax,
                           float *__restrict__ tsum,
                           int n) {
  int tid = threadIdx.x;
  int idx = tid + blockIdx.x * blockDim.x;
  int stride = blockDim.x * gridDim.x;

  extern __shared__ float cache[];
  float *sumcache = cache + blockDim.x;

  // Per-thread online-softmax accumulation over a grid-stride loop.
  // A thread that owns no elements contributes the identity (-inf, 0).
  float lmax = idx < n ? input[idx] : -INFINITY;
  float sum = idx < n ? 1.0f : 0.0f;

  for (idx += stride; idx < n; idx += stride) {
    float nextval = input[idx];
    if (lmax > nextval) {
      sum += expf(nextval - lmax);
    } else {
      sum = 1.0f + sum * expf(lmax - nextval);
      lmax = nextval;
    }
  }
  cache[tid] = lmax;
  sumcache[tid] = sum;
  __syncthreads();

  // Tree reduction.  Fix: the shared-memory reads now sit inside the
  // guard -- the original loaded cache[tid + i] / sumcache[tid + i]
  // unconditionally, and for tid >= i sumcache[tid + i] indexes past the
  // end of the dynamic shared allocation (out-of-bounds read).
  // __syncthreads() stays outside the branch so all threads reach it.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i) {
      float cur = cache[tid], next = cache[tid + i];
      float cursum = sumcache[tid], nextsum = sumcache[tid + i];
      safe_merge(cur, next, cursum, nextsum, cache[tid], sumcache[tid]);
    }
    __syncthreads();
  }

  if (tid == 0)
    tmax[blockIdx.x] = cache[0], tsum[blockIdx.x] = sumcache[0];
}

// Pass 2: a single block merges the per-block partials produced by
// reduce_max into the global max (*maxval) and the global sum of
// exp(x - max) (*sumval).
//
// Launch requirements:
//   - exactly one block
//   - dynamic shared memory = 2 * blockDim.x * sizeof(float)
//   - blockDim.x must be a power of two (tree reduction)
__global__ void reduce_blocks(const float *__restrict__ max,
                              const float *__restrict__ sum,
                              float *maxval,
                              float *sumval,
                              int num) {
  // use one block.
  extern __shared__ float tmax[];
  float *tsum = tmax + blockDim.x;

  int tid = threadIdx.x;

  // Fold every partial this thread owns (grid-stride over num) into one
  // (max, sum) pair; empty slots use the identity (-inf, 0).
  float mx = tid < num ? max[tid] : -INFINITY;
  float sm = tid < num ? sum[tid] : 0.0f;

  int stride = blockDim.x * gridDim.x;
  for (int i = tid + stride; i < num; i += stride)
    safe_merge(mx, max[i], sm, sum[i], mx, sm);

  tmax[tid] = mx, tsum[tid] = sm;
  __syncthreads();

  // Tree reduction in shared memory (reads already guarded here).
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i) {
      float cur = tmax[tid], next = tmax[tid + i];
      float cursum = tsum[tid], nextsum = tsum[tid + i];
      safe_merge(cur, next, cursum, nextsum, tmax[tid], tsum[tid]);
    }
    __syncthreads();
  }

  // Fix: only one thread publishes the result.  The original had every
  // thread store to *maxval / *sumval -- a same-value write race that is
  // benign but flagged by racecheck and wastes global-memory traffic.
  if (tid == 0) {
    *maxval = tmax[0];
    *sumval = tsum[0];
  }
}

// Pass 3: elementwise transform output[i] = exp(input[i] - *max) / *sum
// using the global statistics produced by reduce_blocks.
// Grid-stride loop, so any launch configuration covers all n elements.
// max/sum are read-only device scalars; const-qualified for consistency
// with the other kernels (backward compatible: float* converts to
// const float* at the call site).
__global__ void normalize(const float *__restrict__ input,
                          float *__restrict__ output,
                          const float *max,
                          const float *sum,
                          int n) {
  int idx = threadIdx.x + blockDim.x * blockIdx.x;
  int stride = blockDim.x * gridDim.x;

  // Read the two scalars once per thread, not once per element.
  float mx = *max, sm = *sum;
  for (; idx < n; idx += stride)
    output[idx] = expf(input[idx] - mx) / sm;
}

// Host entry point: softmax of device array input[0..n) into device
// array output, via three kernel launches on the default stream
// (per-block partials -> single-block merge -> elementwise normalize).
void distribute(const float *input, float *output, int n) {
  // Guard: n <= 0 would compute blocksPerGrid == 0 and make the kernel
  // launches fail with an invalid configuration.
  if (n <= 0)
    return;

  int threadsPerBlock = 256;  // power of two, required by the tree reductions
  int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock;

  // Global-memory scratch: one (max, sum) partial per block, plus the
  // two final scalars.
  float *tmax, *tsum;
  size_t bytes_size = blocksPerGrid * sizeof(float);
  cudaMalloc(&tmax, bytes_size);
  cudaMalloc(&tsum, bytes_size);

  float *maxval, *sumval;
  cudaMalloc(&maxval, sizeof(float));
  cudaMalloc(&sumval, sizeof(float));

  // Two floats of dynamic shared memory per thread (max half + sum half).
  size_t smemBytes = threadsPerBlock * sizeof(float) * 2;

  reduce_max<<<blocksPerGrid, threadsPerBlock, smemBytes>>>(input, tmax, tsum, n);
  reduce_blocks<<<1, threadsPerBlock, smemBytes>>>(tmax, tsum, maxval, sumval, blocksPerGrid);
  normalize<<<blocksPerGrid, threadsPerBlock>>>(input, output, maxval, sumval, n);

  // Fix: the original leaked all four allocations.  cudaFree implicitly
  // synchronizes with outstanding work using the allocation, so freeing
  // after the asynchronous launches is safe.
  cudaFree(tmax);
  cudaFree(tsum);
  cudaFree(maxval);
  cudaFree(sumval);
}