Q,K Matrices Quantization

The int8 quantization of the Q and K matrices is located in csrc/fused/fused.cu.

The kernel is implemented as follows:
// Per-token-block symmetric int8 quantization kernel.
//
// Quantizes `input` (half or bfloat16) to int8, emitting one float
// dequantization scale per block of BLOCK_SIZE tokens per (batch, head).
// Optionally subtracts a per-(batch, head) mean vector (sub_mean) and/or
// multiplies by a softmax scale factor (has_sm_scale) before quantizing.
//
// Grid layout (from the blockIdx reads below):
//   blockIdx.x = token-block index, blockIdx.y = head, blockIdx.z = batch.
// Each thread loads num_pack_per_thread packs of 8 elements via float4.
// NOTE(review): this implies blockDim.x ==
//   (BLOCK_SIZE / num_pack_per_thread) * (head_dim / 8) — confirm against
//   the host-side launcher; no in-kernel check enforces it.
// Preconditions (implied, not checked): head_dim is a multiple of 8 and all
// pointers are 16-byte aligned so the float4 loads/stores are legal.
template <
uint32_t head_dim,
uint32_t BLOCK_SIZE,
uint32_t num_pack_per_thread = 1,
bool has_sm_scale = false,
bool sub_mean = false,
typename T>
__global__ void QuantInt8Kernel(
T *__restrict__ input,
T *__restrict__ mean,
int8_t *__restrict__ output,
float *__restrict__ scale,
float sm_scale,
const uint32_t num_tokens,
const uint32_t stride_bz_input, const uint32_t stride_seq_input,
const uint32_t stride_h_input, const uint32_t stride_bz_mean,
const uint32_t stride_h_mean, const uint32_t stride_bz_output,
const uint32_t stride_seq_output, const uint32_t stride_h_output,
const uint32_t stride_bz_scale, const uint32_t stride_h_scale) {
static_assert(std::is_same<T, half>::value ||
std::is_same<T, nv_bfloat16>::value,
"Only half and bfloat16 are supported");
static_assert(num_pack_per_thread > 0,
"The number of pack per thread must be greater than 0");

constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16
// Threads cooperating on one token's head_dim elements.
constexpr uint32_t num_threads_per_token = head_dim / pack_size;

static_assert(num_threads_per_token <= 32,
"The number of threads per token must be less than or equal to "
"warp size");

// Per-thread register staging: raw T values and their float conversions.
T x_val[num_pack_per_thread][8];
T mean_val[8];
float x_val_float[num_pack_per_thread][8];
float mean_val_float[8];

uint32_t bx = blockIdx.x;
uint32_t head_id = blockIdx.y;
uint32_t batch_id = blockIdx.z;
uint32_t thread_id = threadIdx.x;

// First token this thread touches; consecutive num_threads_per_token
// threads share a token, each covering a disjoint pack of 8 elements.
uint32_t thread_base_token =
bx * BLOCK_SIZE + thread_id / num_threads_per_token;
T *input_ptr_base = input + batch_id * stride_bz_input +
head_id * stride_h_input +
thread_base_token * stride_seq_input +
thread_id % num_threads_per_token * pack_size;
// Mean is indexed per (batch, head) only — no per-token stride, so the
// same head_dim-length vector is subtracted from every token.
T *mean_ptr_base = mean + batch_id * stride_bz_mean +
head_id * stride_h_mean +
thread_id % num_threads_per_token * pack_size;
int8_t *output_ptr_base = output + batch_id * stride_bz_output +
head_id * stride_h_output +
thread_base_token * stride_seq_output +
thread_id % num_threads_per_token * pack_size;
// One scale slot per token block (indexed by bx).
float *scale_ptr_base =
scale + batch_id * stride_bz_scale + head_id * stride_h_scale + bx;

// `mean` is only dereferenced when sub_mean is enabled; callers may pass
// an arbitrary pointer otherwise.
if constexpr (sub_mean) {
*(float4 *)(&mean_val[0]) = *(float4 *)(mean_ptr_base);
#pragma unroll
for (uint32_t j = 0; j < 8; j++) {
mean_val_float[j] = convert_to_float(mean_val[j]);
}
}

// Token distance between consecutive packs handled by the same thread.
constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread;

// load the data
for (uint32_t i = 0; i < num_pack_per_thread; i++) {
if (thread_base_token + i * iter_stride < num_tokens) {
// 16-byte vectorized load of 8 consecutive T elements.
*(float4 *)(&x_val[i][0]) =
*(float4 *)(input_ptr_base + i * iter_stride * stride_seq_input);
#pragma unroll
for (uint32_t j = 0; j < 8; j++) {
x_val_float[i][j] = convert_to_float(x_val[i][j]);
}

if constexpr (sub_mean) {
#pragma unroll
for (uint32_t j = 0; j < 8; j++) {
x_val_float[i][j] -= mean_val_float[j];
}
}

if constexpr (has_sm_scale) {
#pragma unroll
for (uint32_t j = 0; j < 8; j++) {
x_val_float[i][j] *= sm_scale;
}
}
} else {
// Out-of-range tokens contribute zeros so they cannot affect the
// block-wide amax below (and are never stored — see the guarded
// write at the end).
#pragma unroll
for (uint32_t j = 0; j < 8; j++) {
x_val_float[i][j] = 0.0f;
}
}
}

float amax_val = 0.0000001f; // prevent from dividing by zero

// Per-thread absolute max over all staged values.
#pragma unroll
for (uint32_t i = 0; i < num_pack_per_thread; i++) {
#pragma unroll
for (uint32_t j = 0; j < 8; j++) {
amax_val = fmaxf(amax_val, fabsf(x_val_float[i][j]));
}
}

// Block-wide max; every thread must reach this call (no divergence above
// affects it). Thread 0 publishes the result and writes the
// dequantization scale (amax / 127) for this token block.
__shared__ float s_amax;
const float block_amax_val = vllm::blockReduceMax(amax_val);
if (thread_id == 0) {
s_amax = block_amax_val;
scale_ptr_base[0] = s_amax / 127.0f;
}

// Make s_amax visible to all threads before it is read below.
__syncthreads();

// Quantization multiplier (inverse of the stored scale).
float tmp_scale = 127.0f / s_amax;

char4 o_val[num_pack_per_thread][2];

// Quantize 8 floats into two char4 packs per iteration.
// NOTE(review): float_to_int8_rn presumably rounds to nearest and
// saturates to [-127, 127] — confirm its definition elsewhere in csrc.
#pragma unroll
for (uint32_t i = 0; i < num_pack_per_thread; i++) {
#pragma unroll
for (uint32_t j = 0; j < 2; j += 1) {
o_val[i][j] =
make_char4(float_to_int8_rn(x_val_float[i][j * 4 + 0] * tmp_scale),
float_to_int8_rn(x_val_float[i][j * 4 + 1] * tmp_scale),
float_to_int8_rn(x_val_float[i][j * 4 + 2] * tmp_scale),
float_to_int8_rn(x_val_float[i][j * 4 + 3] * tmp_scale));
}
}

// int8 result
// Store 8 int8 values (two char4) as one 8-byte float2 per pack; writes
// are guarded by the same bound used for the loads.
#pragma unroll
for (uint32_t i = 0; i < num_pack_per_thread; i++) {

if (thread_base_token + i * iter_stride < num_tokens) {
*reinterpret_cast<float2 *>(output_ptr_base +
i * iter_stride * stride_seq_output) =
*reinterpret_cast<float2 *>(&o_val[i][0]);
}
}
}