CUDA 编写 softmax 算子
Shared Memory + Tree Reduction 算子实现思路
我的实现大致分 $3$ 步:
对原数组进行 block tiling,在每一个 block 上,用 Tree-Reduction 计算其负责的 $max_i, sum_i$,这里的 ${sum}_i$ 指的是这个 block 内的 safe softmax 之和,即
$$sum_i=\sum_{k=l_i}^{r_i}e^{x_k-max_i}$$
将每个 block 的局部结果 $(max_i, sum_i)$ 写入 global memory 中的 max_cache、sum_cache 数组。
再用 $1$ 个 block 对 max_cache、sum_cache 进行归约,求出全局的最大值和 sum of safe softmax。
最后,利用计算好的 global max 和 sum,对 input 进行 transform.
safe_merge() 辅助函数
由于数组的长度并不一定是 $2$ 的整数次幂,用 shared memory 进行 tree reduction 时,数据可能无法恰好填满 $2^k$ 个位置。这时,我们需要在填不满的 shared memory array 中填入 -INFINITY,因此需要编写特殊的逻辑来处理含 -INFINITY 的归并与 safe softmax 的合并.
这里的设计是,cmax, nmax 表示两个 $max$,而 csum, nsum 表示两个 sum of safe softmax. 这里,safe softmax 的更新思路是 online softmax.
// Merge two partial safe-softmax states with the online-softmax update rule.
//   (cmax, csum): first state's max and sum of exp(x - cmax)
//   (nmax, nsum): second state's max and sum of exp(x - nmax)
// Results are written to cm / cs. An empty state is encoded as
// (max == -INFINITY, sum == 0) and leaves the other state unchanged.
// Fix vs. the naive version: cm is assigned on EVERY path (the old
// cmax > nmax branch left cm untouched and only worked because callers
// aliased cm with the variable holding cmax), and outputs are written
// last from locals so aliasing cm/cs with the inputs is always safe.
__device__ void safe_merge(float cmax,
                           float nmax,
                           float csum,
                           float nsum,
                           float &cm,
                           float &cs) {
    float m, s;
    if (nmax == NINF) {
        // Second state is empty: keep the first.
        m = cmax, s = csum;
    } else if (cmax == NINF) {
        // First state is empty: take the second.
        m = nmax, s = nsum;
    } else if (cmax > nmax) {
        // Rescale the second sum into the first state's reference frame.
        m = cmax;
        s = csum + nsum * expf(nmax - cmax);
    } else {
        // Rescale the first sum into the second state's reference frame.
        m = nmax;
        s = nsum + csum * expf(cmax - nmax);
    }
    cm = m;
    cs = s;
}
整体代码
#include <cuda_runtime.h>
#define NINF -INFINITY
// Merge two partial safe-softmax states with the online-softmax update rule.
//   (cmax, csum): first state's max and sum of exp(x - cmax)
//   (nmax, nsum): second state's max and sum of exp(x - nmax)
// Results are written to cm / cs. An empty state is encoded as
// (max == -INFINITY, sum == 0) and leaves the other state unchanged.
// Fix vs. the naive version: cm is assigned on EVERY path (the old
// cmax > nmax branch left cm untouched and only worked because callers
// aliased cm with the variable holding cmax), and outputs are written
// last from locals so aliasing cm/cs with the inputs is always safe.
__device__ void safe_merge(float cmax,
                           float nmax,
                           float csum,
                           float nsum,
                           float &cm,
                           float &cs) {
    float m, s;
    if (nmax == NINF) {
        // Second state is empty: keep the first.
        m = cmax, s = csum;
    } else if (cmax == NINF) {
        // First state is empty: take the second.
        m = nmax, s = nsum;
    } else if (cmax > nmax) {
        // Rescale the second sum into the first state's reference frame.
        m = cmax;
        s = csum + nsum * expf(nmax - cmax);
    } else {
        // Rescale the first sum into the second state's reference frame.
        m = nmax;
        s = nsum + csum * expf(cmax - nmax);
    }
    cm = m;
    cs = s;
}
// Stage 1: each block computes a partial safe-softmax state over its
// grid-stride slice of input — the slice maximum and sum of exp(x - max) —
// and writes it to tmax[blockIdx.x] / tsum[blockIdx.x] (global memory).
// Preconditions: dynamic shared memory of 2 * blockDim.x * sizeof(float);
// blockDim.x must be a power of two (tree reduction halves each step).
// Fix vs. the naive version: the shared-memory loads are inside the
// `tid < i` guard — the old code read cache[tid+i]/sumcache[tid+i] for
// every thread, and for large tid that indexed past the end of the
// dynamic shared allocation (out-of-bounds read).
__global__ void reduce_max(const float *__restrict__ input,
                           float *__restrict__ tmax,
                           float *__restrict__ tsum,
                           int n) {
    int tid = threadIdx.x;
    int idx = tid + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;
    extern __shared__ float cache[];      // [0, blockDim.x): per-thread maxima
    float *sumcache = cache + blockDim.x; // [blockDim.x, 2*blockDim.x): sums
    // Out-of-range threads start as the empty state (-inf, 0), which
    // safe_merge treats as the identity element.
    float lmax = idx < n ? input[idx] : -INFINITY;
    float sum = idx < n ? 1.0f : 0.0f;
    // Online-softmax accumulation over this thread's grid-stride elements.
    for (idx += stride; idx < n; idx += stride) {
        float nextval = input[idx], nextsum = 1.0f;
        if (lmax > nextval)
            sum = sum + expf(nextval - lmax);
        else {
            sum = nextsum + sum * expf(lmax - nextval);
            lmax = nextval;
        }
    }
    cache[tid] = lmax;
    sumcache[tid] = sum;
    __syncthreads();
    // Tree reduction in shared memory; only active threads touch tid + i.
    for (int i = blockDim.x / 2; i > 0; i >>= 1) {
        if (tid < i) {
            float cur = cache[tid], next = cache[tid + i];
            float cursum = sumcache[tid], nextsum = sumcache[tid + i];
            safe_merge(cur, next, cursum, nextsum, cache[tid], sumcache[tid]);
        }
        __syncthreads();
    }
    if (tid == 0)
        tmax[blockIdx.x] = cache[0], tsum[blockIdx.x] = sumcache[0];
}
// Stage 2: reduce the per-block partial states (max[i], sum[i]) produced
// by reduce_max down to the global maximum (*maxval) and the global sum
// of exp(x - *maxval) (*sumval).
// Preconditions: launched with exactly ONE block; dynamic shared memory
// of 2 * blockDim.x * sizeof(float); blockDim.x must be a power of two.
// Fix vs. the naive version: the final stores are done by thread 0 only —
// previously every thread in the block wrote *maxval / *sumval.
__global__ void reduce_blocks(const float *__restrict__ max,
                              const float *__restrict__ sum,
                              float *maxval,
                              float *sumval,
                              int num) {
    extern __shared__ float tmax[];
    float *tsum = tmax + blockDim.x;
    int tid = threadIdx.x;
    // Sequentially merge this thread's strided share of the partials,
    // starting from the empty state when tid >= num.
    float mx = tid < num ? max[tid] : -INFINITY;
    float sm = tid < num ? sum[tid] : 0.0f;
    int stride = blockDim.x * gridDim.x; // gridDim.x == 1 by contract
    for (int i = tid + stride; i < num; i += stride) {
        float nextmax = max[i], nextsum = sum[i];
        safe_merge(mx, nextmax, sm, nextsum, mx, sm);
    }
    tmax[tid] = mx, tsum[tid] = sm;
    __syncthreads();
    // Tree reduction in shared memory.
    for (int i = blockDim.x / 2; i > 0; i >>= 1) {
        if (tid < i) {
            float cur = tmax[tid], next = tmax[tid + i];
            float cursum = tsum[tid], nextsum = tsum[tid + i];
            safe_merge(cur, next, cursum, nextsum, tmax[tid], tsum[tid]);
        }
        __syncthreads();
    }
    // Single writer publishes the result.
    if (tid == 0) {
        *maxval = tmax[0];
        *sumval = tsum[0];
    }
}
// Final pass: output[i] = exp(input[i] - global_max) / global_sum for all
// i in [0, n), via a grid-stride loop. max and sum are single-element
// device pointers produced by reduce_blocks.
__global__ void normalize(const float *input,
                          float *output,
                          float *max,
                          float *sum,
                          int n) {
    // Hoist the scalars once per thread; they are loop-invariant.
    const float mx = *max;
    const float sm = *sum;
    const int stride = blockDim.x * gridDim.x;
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < n; i += stride) {
        output[i] = expf(input[i] - mx) / sm;
    }
}
// Host driver: computes softmax of the device array input[0..n) into the
// device array output[0..n) in three stages:
//   1. reduce_max    — per-block partial (max, sum) states
//   2. reduce_blocks — one-block merge to the global (max, sum)
//   3. normalize     — elementwise exp(x - max) / sum
// Fixes vs. the naive version: all four temporary device buffers are
// freed (the old code leaked them on every call), and n <= 0 returns
// early instead of launching an invalid zero-block grid.
void distribute(const float *input, float *output, int n) {
    if (n <= 0)
        return; // a zero-block launch is an invalid configuration
    int threadsPerBlock = 256; // power of two, required by the reductions
    int blocksPerGrid = (n + threadsPerBlock - 1) / threadsPerBlock; // ceil-div
    float *tmax, *tsum, *maxval, *sumval;
    cudaMalloc(&tmax, blocksPerGrid * sizeof(float));
    cudaMalloc(&tsum, blocksPerGrid * sizeof(float));
    cudaMalloc(&maxval, sizeof(float));
    cudaMalloc(&sumval, sizeof(float));
    // Both reduction kernels need two blockDim.x-sized float arrays in
    // dynamic shared memory.
    size_t smemBytes = threadsPerBlock * sizeof(float) * 2;
    reduce_max<<<blocksPerGrid, threadsPerBlock, smemBytes>>>(input, tmax, tsum, n);
    reduce_blocks<<<1, threadsPerBlock, smemBytes>>>(tmax, tsum, maxval, sumval, blocksPerGrid);
    normalize<<<blocksPerGrid, threadsPerBlock>>>(input, output, maxval, sumval, n);
    // cudaFree synchronizes with outstanding work touching these buffers
    // before releasing them, so freeing after the launches is safe.
    cudaFree(tmax);
    cudaFree(tsum);
    cudaFree(maxval);
    cudaFree(sumval);
}