-
Notifications
You must be signed in to change notification settings - Fork 166
/
dot_product.cu
284 lines (261 loc) · 14.3 KB
/
dot_product.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <climits>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>
// Hard-coded warp width (32 lanes) used by the warp/block reductions below.
#define WARP_SIZE 32
// Vector-access helpers: reinterpret the storage starting at `value` as the
// given vector type and access its first element. The result is an lvalue,
// so these work for both loads and stores. The address of `value` must meet
// the vector type's alignment (16 bytes for int4/float4, 4 bytes for
// half2/__nv_bfloat162); otherwise the access is misaligned.
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
// Same expansion as FLOAT4; named for generic 128-bit loads/stores of any
// element type (e.g. 8 x half).
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// -------------------------------------- FP32 --------------------------------------
// Warp-level float sum via XOR shuffles (butterfly pattern).
// Every lane of the warp must call this: the shuffle mask is the full warp.
// For a power-of-two kWarpSize <= 32, each lane ends up holding the sum over
// its aligned group of kWarpSize lanes; with kWarpSize == 32 that is the
// whole warp, and kWarpSize == 1 returns `val` unchanged.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const float partner = __shfl_xor_sync(0xffffffff, val, offset);
    val = val + partner;
  }
  return val;
}
// Dot product, one float element per thread: *y += sum_i a[i] * b[i].
// Launch: grid(ceil(N / NUM_THREADS)), block(NUM_THREADS). blockDim.x must
// equal NUM_THREADS, because indexing uses the template constant.
// y is a single float updated with atomicAdd (one atomic per block), so it
// has to be zeroed before launch.
// Shared memory: NUM_WARPS floats (static), one partial sum per warp.
template<const int NUM_THREADS = 256>
__global__ void dot_prod_f32_f32_kernel(float* a, float* b, float* y, int N) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float warp_partials[NUM_WARPS];

  const int t = threadIdx.x;
  const int gidx = blockIdx.x * NUM_THREADS + t;
  const int warp_id = t / WARP_SIZE;
  const int lane_id = t % WARP_SIZE;

  // Out-of-range threads contribute 0 but still join the full-warp shuffles.
  float partial = 0.0f;
  if (gidx < N) partial = a[gidx] * b[gidx];

  // Stage 1: reduce inside each warp; lane 0 publishes the warp total.
  partial = warp_reduce_sum_f32<WARP_SIZE>(partial);
  if (lane_id == 0) warp_partials[warp_id] = partial;
  __syncthreads();  // all warp totals visible before warp 0 reads them

  // Stage 2: warp 0 folds the per-warp totals; thread 0 commits the block sum.
  if (warp_id == 0) {
    float block_sum = (lane_id < NUM_WARPS) ? warp_partials[lane_id] : 0.0f;
    block_sum = warp_reduce_sum_f32<NUM_WARPS>(block_sum);
    if (t == 0) atomicAdd(y, block_sum);
  }
}
// Dot Product + Vec4: four floats per thread via 128-bit (float4) loads.
// *y += sum_i a[i] * b[i].
// Launch: block(NUM_THREADS); block x covers elements
// [x * NUM_THREADS * 4, (x + 1) * NUM_THREADS * 4), so the grid needs
// ceil(N / (NUM_THREADS * 4)) blocks. blockDim.x must equal NUM_THREADS.
// a and b must be 16-byte aligned (float4 loads).
// y is a single float updated with atomicAdd and must be zeroed before launch.
// Threads whose 4-element group is only partly in range use guarded scalar
// loads. N therefore need not be a multiple of 4, and no index >= N is read.
// Shared memory: NUM_WARPS floats (static).
template<const int NUM_THREADS = 256/4>
__global__ void dot_prod_f32x4_f32_kernel(float* a, float* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 4;
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];
  float prod = 0.0f;
  if (idx + 3 < N) {
    // Whole group in range: one 128-bit load per operand.
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    prod = reg_a.x * reg_b.x + reg_a.y * reg_b.y
         + reg_a.z * reg_b.z + reg_a.w * reg_b.w;
  } else {
    // Partial tail group (or fully out of range): touch only indices < N.
    for (int i = idx; i < N && i < idx + 4; ++i) {
      prod += a[i] * b[i];
    }
  }
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce (all lanes participate, out-of-range ones with 0).
  prod = warp_reduce_sum_f32<WARP_SIZE>(prod);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = prod;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final block sum.
  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) prod = warp_reduce_sum_f32<NUM_WARPS>(prod);
  if (tid == 0) atomicAdd(y, prod);
}
// -------------------------------------- FP16 --------------------------------------
// Warp Reduce Sum: Half.
// XOR-shuffle (butterfly) sum that both exchanges and accumulates in half
// precision via __hadd. Every lane of the warp must call this: the mask is
// the full warp. For a power-of-two kWarpSize <= 32, each lane ends with the
// sum over its aligned group of kWarpSize lanes.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ half warp_reduce_sum_f16_f16(half val) {
  #pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const half partner = __shfl_xor_sync(0xffffffff, val, offset);
    val = __hadd(val, partner);
  }
  return val;
}
// Warp sum of half inputs with float accumulation.
// The input is widened to float once, then reduced with float XOR shuffles,
// which avoids half-precision rounding and overflow during the reduction.
// Every lane of the warp must call this: the mask is the full warp. For a
// power-of-two kWarpSize <= 32, each lane ends with the float sum over its
// aligned group of kWarpSize lanes.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f16_f32(half val) {
  float acc = __half2float(val);
  #pragma unroll
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const float partner = __shfl_xor_sync(0xffffffff, acc, offset);
    acc = acc + partner;
  }
  return acc;
}
// Dot product of half vectors, one element per thread, float result:
// *y += sum_i a[i] * b[i].
// Each product is formed in half precision (__hmul). The sum is accumulated
// in float, starting with the warp reduction.
// Launch: grid(ceil(N / NUM_THREADS)), block(NUM_THREADS). blockDim.x must
// equal NUM_THREADS.
// y is a single float updated with atomicAdd and must be zeroed before launch.
// Shared memory: NUM_WARPS floats (static).
template<const int NUM_THREADS = 256>
__global__ void dot_prod_f16_f32_kernel(half* a, half* b, float* y, int N) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float warp_partials[NUM_WARPS];

  const int t = threadIdx.x;
  const int gidx = blockIdx.x * NUM_THREADS + t;
  const int warp_id = t / WARP_SIZE;
  const int lane_id = t % WARP_SIZE;

  // Out-of-range threads contribute a half zero but still join the shuffles.
  half elem = __float2half(0.0f);
  if (gidx < N) elem = __hmul(a[gidx], b[gidx]);

  // Stage 1: per-warp sum (widened to float); lane 0 publishes it.
  float partial = warp_reduce_sum_f16_f32<WARP_SIZE>(elem);
  if (lane_id == 0) warp_partials[warp_id] = partial;
  __syncthreads();  // all warp totals visible before warp 0 reads them

  // Stage 2: warp 0 folds the per-warp totals; thread 0 commits the block sum.
  if (warp_id == 0) {
    float block_sum = (lane_id < NUM_WARPS) ? warp_partials[lane_id] : 0.0f;
    block_sum = warp_reduce_sum_f32<NUM_WARPS>(block_sum);
    if (t == 0) atomicAdd(y, block_sum);
  }
}
// Dot product of half vectors, two elements per thread via half2 loads,
// float result: *y += sum_i a[i] * b[i].
// Launch: block(NUM_THREADS); block x covers elements
// [x * NUM_THREADS * 2, (x + 1) * NUM_THREADS * 2), so the grid needs
// ceil(N / (NUM_THREADS * 2)) blocks. blockDim.x must equal NUM_THREADS.
// a and b must be 4-byte aligned (half2 loads).
// The pair's two half products are summed in half, then the reduction runs in
// float. When N is odd, the final thread multiplies a single element with
// scalar loads, and no index >= N is read.
// y is a single float updated with atomicAdd and must be zeroed before launch.
// Shared memory: NUM_WARPS floats (static).
template<const int NUM_THREADS = 256/2>
__global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 2; // 2 half elements per thread
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];
  half prod_f16 = __float2half(0.0f);
  if (idx + 1 < N) {
    // Whole pair in range: one 32-bit load per operand.
    half2 reg_a = HALF2(a[idx]);
    half2 reg_b = HALF2(b[idx]);
    prod_f16 = __hadd(__hmul(reg_a.x, reg_b.x), __hmul(reg_a.y, reg_b.y));
  } else if (idx < N) {
    // Odd N: only the first element of the last pair exists.
    prod_f16 = __hmul(a[idx], b[idx]);
  }
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce (widens to float).
  float prod = warp_reduce_sum_f16_f32<WARP_SIZE>(prod_f16);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = prod;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final block sum.
  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) prod = warp_reduce_sum_f32<NUM_WARPS>(prod);
  if (tid == 0) atomicAdd(y, prod);
}
// Dot product of half vectors, eight elements per thread via one 128-bit load
// per operand, float result: *y += sum_i a[i] * b[i].
// Launch: block(NUM_THREADS); block x covers elements
// [x * NUM_THREADS * 8, (x + 1) * NUM_THREADS * 8), so the grid needs
// ceil(N / (NUM_THREADS * 8)) blocks. blockDim.x must equal NUM_THREADS.
// a and b must be 16-byte aligned (128-bit loads).
// The per-thread sum of 8 products is kept in half, then the reduction runs
// in float. A thread whose pack is only partly in range uses guarded scalar
// loads, and no index >= N is read.
// y is a single float updated with atomicAdd and must be zeroed before launch.
// Shared memory: NUM_WARPS floats (static).
// Half arithmetic uses the __hadd/__hmul intrinsics rather than half
// operators, so this also builds when __CUDA_NO_HALF_OPERATORS__ is defined.
template<const int NUM_THREADS = 256/8>
__global__ void dot_prod_f16x8_pack_f32_kernel(half* a, half* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 8; // 8 half elements per thread
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];
  half prod_f16 = __float2half(0.0f);
  if (idx + 7 < N) {
    // Whole pack in range. alignas(16) makes the 128-bit store into these
    // local half arrays correctly aligned (8 x 16 bits = 128 bits).
    alignas(16) half pack_a[8];
    alignas(16) half pack_b[8];
    LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
    LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits
    #pragma unroll
    for (int i = 0; i < 8; i += 2) {
      half2 v = __hmul2(HALF2(pack_a[i]), HALF2(pack_b[i]));
      prod_f16 = __hadd(prod_f16, __hadd(v.x, v.y));
    }
  } else {
    // Partial tail pack (or fully out of range): touch only indices < N.
    for (int i = idx; i < N && i < idx + 8; ++i) {
      prod_f16 = __hadd(prod_f16, __hmul(a[i], b[i]));
    }
  }
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce (widens to float).
  float prod = warp_reduce_sum_f16_f32<WARP_SIZE>(prod_f16);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = prod;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final block sum.
  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) prod = warp_reduce_sum_f32<NUM_WARPS>(prod);
  if (tid == 0) atomicAdd(y, prod);
}
// --------------------- PyTorch bindings for custom kernel -----------------------
// Turns a macro argument into a string literal (token stringification).
#define STRINGFY(str) #str
// Registers `func` with the pybind11 module object `m` (which must be in
// scope), using its name as both the Python name and the docstring.
// The expansion already ends with ';'.
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
// Throws std::runtime_error("values must be <th_type>") after printing the
// tensor's options when tensor T's dtype differs from th_type.
// Expands to a bare if-statement (not wrapped in do { } while (0)); beware of
// dangling-else if it is ever used inside an unbraced if/else.
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
if(((T).options().dtype() != (th_type))) { \
std::cout << "Tensor Info:" << (T).options() << std::endl; \
throw std::runtime_error("values must be "#th_type); \
}
// Launches dot_prod_<packed_type>_<acc_type>_kernel<NT>.
// Implicitly uses these names from the enclosing scope: `grid` and `block`
// (dim3 launch shape), `a` and `b` (input tensors, reinterpreted as
// element_type*), `prod` (float output tensor) and `N` (element count).
// The launch uses the default stream and is not error-checked here.
// (The "LANUCH" spelling is kept because other macros refer to this name.)
#define LANUCH_DOT_PROD_KERNEL(NT, packed_type, acc_type, element_type) \
dot_prod_##packed_type##_##acc_type##_kernel<(NT)><<<grid, block>>>( \
reinterpret_cast<element_type*>(a.data_ptr()), \
reinterpret_cast<element_type*>(b.data_ptr()), \
prod.data_ptr<float>(), N);
// One-block-per-row launch for a 2-D (S x K) input.
// Declares NT = K / n_elements (integer division), block(NT) and grid(S), and
// needs `S` (row count) to be in scope. It then switches on NT so the kernel's
// NUM_THREADS template argument is a compile-time constant, via
// LANUCH_DOT_PROD_KERNEL (which uses the `grid`/`block` declared here).
// Throws std::runtime_error for NT outside {32, 64, 128, 256, 512, 1024}.
// Declares locals, so it must be expanded inside its own scope.
// NOTE(review): with the kernels in this file, one block covers
// NT * n_elements consecutive elements. If K % n_elements != 0, rows are not
// tiled exactly, so callers should ensure divisibility.
#define DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
const int NT = (K)/(n_elements); \
dim3 block(NT); \
dim3 grid((S)); \
switch (NT) \
{ \
case 32: \
LANUCH_DOT_PROD_KERNEL(32, packed_type, acc_type, element_type) \
break; \
case 64: \
LANUCH_DOT_PROD_KERNEL(64, packed_type, acc_type, element_type) \
break; \
case 128: \
LANUCH_DOT_PROD_KERNEL(128, packed_type, acc_type, element_type) \
break; \
case 256: \
LANUCH_DOT_PROD_KERNEL(256, packed_type, acc_type, element_type) \
break; \
case 512: \
LANUCH_DOT_PROD_KERNEL(512, packed_type, acc_type, element_type) \
break; \
case 1024: \
LANUCH_DOT_PROD_KERNEL(1024, packed_type, acc_type, element_type) \
break; \
default: \
throw std::runtime_error( \
"only support (K)/(n_elements): 32/64/128/256/512/1024"); \
break; \
}
// Generates the host wrapper
//   torch::Tensor dot_prod_<packed_type>_<acc_type>(torch::Tensor a, torch::Tensor b)
// which returns a 1-element float32 tensor on cuda:0 holding sum(a * b) over
// all elements. Inputs are treated as flat buffers, so only numel must match.
//   packed_type  - selects dot_prod_<packed_type>_<acc_type>_kernel
//   acc_type     - accumulator suffix of the kernel name
//   th_type      - torch dtype required for both inputs
//   element_type - C++ element type passed to the kernel (float / half)
//   n_elements   - elements each kernel thread consumes (vector width)
// Throws std::runtime_error on:
//   - dtype mismatch
//   - non-CUDA inputs
//   - numel mismatch
//   - numel > INT_MAX
//   - a data pointer not aligned to the vector width
//   - a kernel-launch error
// Launch strategy:
//   * 2-D input whose row length K satisfies K % n_elements == 0, with
//     K / n_elements a power of two in [32, 1024]: one block per row via
//     DISPATCH_DOT_PROD_KERNEL (needs `S` in scope).
//   * Anything else: 256-thread blocks over the flattened input, with
//     ceil(N / (256 * n_elements)) blocks. Every element is covered even when
//     N is not a multiple of 256 * n_elements.
// Kernels run on the default stream.
// NOTE(review): inputs on a device other than the current one are not handled.
#define TORCH_BINDING_DOT_PROD(packed_type, acc_type, th_type, element_type, n_elements) \
torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor b) {   \
  CHECK_TORCH_TENSOR_DTYPE(a, (th_type))                                                \
  CHECK_TORCH_TENSOR_DTYPE(b, (th_type))                                                \
  if (!a.is_cuda() || !b.is_cuda()) {                                                   \
    throw std::runtime_error("a and b must be CUDA tensors");                           \
  }                                                                                     \
  if (a.numel() != b.numel()) {                                                         \
    throw std::runtime_error("a and b must have the same number of elements");          \
  }                                                                                     \
  if (a.numel() > static_cast<int64_t>(INT_MAX)) {                                      \
    throw std::runtime_error("dot_prod: more than INT_MAX elements is not supported");  \
  }                                                                                     \
  /* The kernels index the raw buffers linearly, so force a dense layout. */            \
  a = a.contiguous();                                                                   \
  b = b.contiguous();                                                                   \
  {                                                                                     \
    /* Vectorized kernels load n_elements values per access. */                        \
    const uintptr_t vec_bytes = sizeof(element_type) * (n_elements);                   \
    if ((reinterpret_cast<uintptr_t>(a.data_ptr()) % vec_bytes) != 0 ||                 \
        (reinterpret_cast<uintptr_t>(b.data_ptr()) % vec_bytes) != 0) {                 \
      throw std::runtime_error(                                                         \
          "a and b data must be aligned to the kernel vector width");                   \
    }                                                                                   \
  }                                                                                     \
  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(                  \
      torch::kCUDA, 0);                                                                 \
  auto prod = torch::zeros({1}, options);                                               \
  const int N = static_cast<int>(a.numel());                                            \
  if (N == 0) {                                                                         \
    return prod; /* empty input: a zero-block launch would be an error */               \
  }                                                                                     \
  bool launched = false;                                                                \
  if (a.dim() == 2) {                                                                   \
    const int S = a.size(0);                                                            \
    const int K = a.size(1);                                                            \
    const int row_threads = K / (n_elements);                                           \
    const bool row_fits = (K % (n_elements)) == 0 &&                                    \
                          row_threads >= 32 && row_threads <= 1024 &&                   \
                          (row_threads & (row_threads - 1)) == 0;                       \
    if (row_fits) {                                                                     \
      DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements)      \
      launched = true;                                                                  \
    }                                                                                   \
  }                                                                                     \
  if (!launched) {                                                                      \
    constexpr int kThreads = 256;                                                       \
    const int64_t per_block = static_cast<int64_t>(kThreads) * (n_elements);           \
    dim3 block(kThreads);                                                               \
    dim3 grid(static_cast<unsigned int>((N + per_block - 1) / per_block));             \
    dot_prod_##packed_type##_##acc_type##_kernel<kThreads><<<grid, block>>>(          \
        reinterpret_cast<element_type*>(a.data_ptr()),                                  \
        reinterpret_cast<element_type*>(b.data_ptr()),                                  \
        prod.data_ptr<float>(), N);                                                     \
  }                                                                                     \
  const cudaError_t launch_err = cudaGetLastError();                                    \
  if (launch_err != cudaSuccess) {                                                      \
    throw std::runtime_error(cudaGetErrorString(launch_err));                           \
  }                                                                                     \
  return prod;                                                                          \
}
// Instantiate one host wrapper per kernel variant. Each line defines
// torch::Tensor dot_prod_<packed_type>_<acc_type>(torch::Tensor a, torch::Tensor b).
// n_elements_per_pack is the number of elements each kernel thread reads.
// packed_type, acc_type, th_type, element_type, n_elements_per_pack
TORCH_BINDING_DOT_PROD(f32, f32, torch::kFloat32, float, 1)
TORCH_BINDING_DOT_PROD(f32x4, f32, torch::kFloat32, float, 4)
TORCH_BINDING_DOT_PROD(f16, f32, torch::kHalf, half, 1)
TORCH_BINDING_DOT_PROD(f16x2, f32, torch::kHalf, half, 2)
TORCH_BINDING_DOT_PROD(f16x8_pack, f32, torch::kHalf, half, 8)
// Python module definition: exposes each generated dot-product wrapper under
// its own name, with that name also used as the function's docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dot_prod_f32_f32", &dot_prod_f32_f32, "dot_prod_f32_f32");
  m.def("dot_prod_f32x4_f32", &dot_prod_f32x4_f32, "dot_prod_f32x4_f32");
  m.def("dot_prod_f16_f32", &dot_prod_f16_f32, "dot_prod_f16_f32");
  m.def("dot_prod_f16x2_f32", &dot_prod_f16x2_f32, "dot_prod_f16x2_f32");
  m.def("dot_prod_f16x8_pack_f32", &dot_prod_f16x8_pack_f32, "dot_prod_f16x8_pack_f32");
}