1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
/**
 * Quantizes a half / bfloat16 tensor to INT8 with one scale per block of
 * tokens.
 *
 * Work decomposition (as established by the indexing below):
 *   - blockIdx.x selects a group of BLOCK_SIZE consecutive tokens,
 *     blockIdx.y selects the head and blockIdx.z selects the batch entry.
 *   - Every token is covered by head_dim / 8 threads, each handling one pack
 *     of 8 contiguous elements; a thread processes num_pack_per_thread tokens
 *     spaced BLOCK_SIZE / num_pack_per_thread apart.
 *   - NOTE(review): this presumably expects blockDim.x ==
 *     (BLOCK_SIZE / num_pack_per_thread) * (head_dim / 8) -- confirm against
 *     the launcher, which is not visible here.
 *
 * For each block the kernel optionally subtracts `mean` (sub_mean), optionally
 * multiplies by `sm_scale` (has_sm_scale), takes the block-wide maximum
 * absolute value (floored at 1e-7), writes `amax / 127` to `scale`, and stores
 * the values multiplied by `127 / amax` and converted with float_to_int8_rn.
 *
 * Preconditions implied by the vectorized accesses:
 *   - every addressed pack in `input` / `mean` is 16-byte aligned and every
 *     addressed pack in `output` is 8-byte aligned;
 *   - the innermost (head_dim) dimension is contiguous in all three tensors;
 *   - head_dim is expected to be a multiple of 8: the thread count per token
 *     is head_dim / 8 with integer division, so trailing elements would not
 *     be processed.
 *
 * @tparam head_dim             Number of elements per token and head.
 * @tparam BLOCK_SIZE           Number of tokens sharing one scale.
 * @tparam num_pack_per_thread  Number of tokens handled by each thread.
 * @tparam has_sm_scale         Multiply inputs by sm_scale before quantizing.
 * @tparam sub_mean             Subtract `mean` from inputs before quantizing.
 * @tparam T                    half or nv_bfloat16.
 *
 * @param input       Source tensor, addressed via the *_input strides.
 * @param mean        Per-(batch, head) vector of head_dim values; only read
 *                    when sub_mean is true.
 * @param output      Destination INT8 tensor, addressed via *_output strides.
 * @param scale       Destination for one float per (batch, head, token block).
 * @param sm_scale    Multiplier applied when has_sm_scale is true.
 * @param num_tokens  Number of valid tokens; tokens at or beyond this index
 *                    contribute zeros to the maximum and are not stored.
 * @param stride_*    Element strides of the respective tensors.
 */
template <uint32_t head_dim, uint32_t BLOCK_SIZE,
          uint32_t num_pack_per_thread = 1, bool has_sm_scale = false,
          bool sub_mean = false, typename T>
__global__ void QuantInt8Kernel(
    T *__restrict__ input, T *__restrict__ mean, int8_t *__restrict__ output,
    float *__restrict__ scale, float sm_scale, const uint32_t num_tokens,
    const uint32_t stride_bz_input, const uint32_t stride_seq_input,
    const uint32_t stride_h_input, const uint32_t stride_bz_mean,
    const uint32_t stride_h_mean, const uint32_t stride_bz_output,
    const uint32_t stride_seq_output, const uint32_t stride_h_output,
    const uint32_t stride_bz_scale, const uint32_t stride_h_scale) {
  static_assert(std::is_same<T, half>::value || std::is_same<T, nv_bfloat16>::value, "Only half and bfloat16 are supported");
  static_assert(num_pack_per_thread > 0, "The number of pack per thread must be greater than 0");

  // Each thread moves 8 elements (16 bytes of half / bfloat16) per access.
  constexpr uint32_t pack_size = 8;
  constexpr uint32_t num_threads_per_token = head_dim / pack_size;
  static_assert(num_threads_per_token <= 32, "The number of threads per token must be less than or equal to " "warp size");

  // The staging arrays are accessed through float4 / float2 pointers below, so
  // they must carry the alignment of those vector types explicitly; the element
  // types alone only guarantee 2-byte (T) and 4-byte (char4) alignment.
  alignas(16) T x_val[num_pack_per_thread][8];
  alignas(16) T mean_val[8];
  float x_val_float[num_pack_per_thread][8];
  float mean_val_float[8];

  const uint32_t bx = blockIdx.x;
  const uint32_t head_id = blockIdx.y;
  const uint32_t batch_id = blockIdx.z;
  const uint32_t thread_id = threadIdx.x;

  // First token handled by this thread and its element offset inside a token.
  const uint32_t thread_base_token =
      bx * BLOCK_SIZE + thread_id / num_threads_per_token;
  const uint32_t pack_offset =
      thread_id % num_threads_per_token * pack_size;

  // Offsets are accumulated in 64 bits: the 32-bit products of an index and a
  // stride would silently wrap for tensors with 2^32 or more elements.
  T *input_ptr_base = input
      + static_cast<uint64_t>(batch_id) * stride_bz_input
      + static_cast<uint64_t>(head_id) * stride_h_input
      + static_cast<uint64_t>(thread_base_token) * stride_seq_input
      + pack_offset;
  T *mean_ptr_base = mean
      + static_cast<uint64_t>(batch_id) * stride_bz_mean
      + static_cast<uint64_t>(head_id) * stride_h_mean
      + pack_offset;
  int8_t *output_ptr_base = output
      + static_cast<uint64_t>(batch_id) * stride_bz_output
      + static_cast<uint64_t>(head_id) * stride_h_output
      + static_cast<uint64_t>(thread_base_token) * stride_seq_output
      + pack_offset;
  float *scale_ptr_base = scale
      + static_cast<uint64_t>(batch_id) * stride_bz_scale
      + static_cast<uint64_t>(head_id) * stride_h_scale
      + bx;

  if constexpr (sub_mean) {
    // One 16-byte load fetches this thread's 8 mean values.
    *reinterpret_cast<float4 *>(&mean_val[0]) =
        *reinterpret_cast<float4 *>(mean_ptr_base);
#pragma unroll
    for (uint32_t j = 0; j < 8; j++) {
      mean_val_float[j] = convert_to_float(mean_val[j]);
    }
  }

  // Token distance between two consecutive packs of the same thread.
  constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread;

  // Load, convert to float and pre-process every pack owned by this thread.
#pragma unroll
  for (uint32_t i = 0; i < num_pack_per_thread; i++) {
    if (thread_base_token + i * iter_stride < num_tokens) {
      const uint64_t token_step = static_cast<uint64_t>(i) * iter_stride;
      *reinterpret_cast<float4 *>(&x_val[i][0]) =
          *reinterpret_cast<float4 *>(input_ptr_base +
                                      token_step * stride_seq_input);
#pragma unroll
      for (uint32_t j = 0; j < 8; j++) {
        x_val_float[i][j] = convert_to_float(x_val[i][j]);
      }

      if constexpr (sub_mean) {
#pragma unroll
        for (uint32_t j = 0; j < 8; j++) {
          x_val_float[i][j] -= mean_val_float[j];
        }
      }

      if constexpr (has_sm_scale) {
#pragma unroll
        for (uint32_t j = 0; j < 8; j++) {
          x_val_float[i][j] *= sm_scale;
        }
      }
    } else {
      // Tokens past the end contribute zeros so they cannot raise the maximum.
#pragma unroll
      for (uint32_t j = 0; j < 8; j++) {
        x_val_float[i][j] = 0.0f;
      }
    }
  }

  // The small positive floor keeps the later division by the maximum finite
  // when every value in the block is zero.
  float amax_val = 0.0000001f;

#pragma unroll
  for (uint32_t i = 0; i < num_pack_per_thread; i++) {
#pragma unroll
    for (uint32_t j = 0; j < 8; j++) {
      amax_val = fmaxf(amax_val, fabsf(x_val_float[i][j]));
    }
  }

  // NOTE(review): vllm::blockReduceMax is defined elsewhere; this code relies
  // on its result being valid in thread 0 -- confirm against the helper.
  __shared__ float s_amax;
  const float block_amax_val = vllm::blockReduceMax(amax_val);
  if (thread_id == 0) {
    s_amax = block_amax_val;
    scale_ptr_base[0] = s_amax / 127.0f;
  }

  // Make thread 0's write to s_amax visible to the whole block.
  __syncthreads();

  float tmp_scale = 127.0f / s_amax;

  // Two char4 per pack hold the 8 quantized values; 8-byte alignment is needed
  // for the float2 store below.
  alignas(8) char4 o_val[num_pack_per_thread][2];

#pragma unroll
  for (uint32_t i = 0; i < num_pack_per_thread; i++) {
#pragma unroll
    for (uint32_t j = 0; j < 2; j += 1) {
      o_val[i][j] = make_char4(
          float_to_int8_rn(x_val_float[i][j * 4 + 0] * tmp_scale),
          float_to_int8_rn(x_val_float[i][j * 4 + 1] * tmp_scale),
          float_to_int8_rn(x_val_float[i][j * 4 + 2] * tmp_scale),
          float_to_int8_rn(x_val_float[i][j * 4 + 3] * tmp_scale));
    }
  }

  // Store each in-range pack with a single 8-byte write.
#pragma unroll
  for (uint32_t i = 0; i < num_pack_per_thread; i++) {
    if (thread_base_token + i * iter_stride < num_tokens) {
      const uint64_t token_step = static_cast<uint64_t>(i) * iter_stride;
      *reinterpret_cast<float2 *>(output_ptr_base +
                                  token_step * stride_seq_output) =
          *reinterpret_cast<float2 *>(&o_val[i][0]);
    }
  }
}