-
Notifications
You must be signed in to change notification settings - Fork 173
/
mat_transpose.cu
360 lines (335 loc) · 16.5 KB
/
mat_transpose.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <cstdint>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>
// Threads per block used by the 1-D (flat-index) launches below. Despite the
// name this is a block size, not the hardware warp size.
#define WARP_SIZE 256
// Edge of the square thread block used by the 2-D launches
// (WARP_SIZE_S x WARP_SIZE_S), and the shared-memory tile edge of the
// *_shared_* kernels.
#define WARP_SIZE_S 16
// Extra column appended to each shared-memory tile row in the *_bcf_* kernels,
// intended to spread column-wise tile accesses over different banks.
#define PAD 1
// Reinterpret the storage starting at `value` as a vector type. The address
// must be aligned for that type (16 bytes for int4/float4, 4 bytes for
// half2/__nv_bfloat162).
#define INT4(value) (reinterpret_cast<int4 *>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2 *>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162 *>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])
// exp() argument clamps; not referenced by the transpose kernels in this file.
// The F16 values equal ln(65504) (largest half) and ln(2^-14) (smallest normal half).
#define MAX_EXP_F32 88.3762626647949f
#define MIN_EXP_F32 -88.3762626647949f
#define MAX_EXP_F16 __float2half(11.089866488461016f)
#define MIN_EXP_F16 __float2half(-9.704060527839234f)
// -------------------------------------- FP32 --------------------------------------
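// Naming convention (x is row x col, y is col x row, both row-major):
//   col2row: consecutive threads read consecutive elements of x (coalesced
//            loads of x[r][c]) and scatter them to y[c][r] (strided stores).
//   row2col: consecutive threads write consecutive elements of y (coalesced
//            stores of y[c][r]) and gather them from x[r][c] (strided loads).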
// Flat 1-D transpose, col2row flavour: one element per thread.
// Thread `linear` reads x[linear] (row-major row x col, so loads coalesce) and
// stores it at y[c][r] of the row-major (col x row) output.
// Launch with at least row * col threads; surplus threads do nothing.
__global__ void mat_transpose_f32_col2row_kernel(
    float *x, float *y, const int row, const int col) {
  const int linear = blockIdx.x * blockDim.x + threadIdx.x;
  if (linear >= row * col) {
    return;
  }
  const int r = linear / col;
  const int c = linear % col;
  y[c * row + r] = x[linear];
}
// Flat 1-D transpose, row2col flavour: one element per thread.
// Thread `out` owns y[out] of the row-major (col x row) output, so stores
// coalesce, and gathers its value from x[r][c] of the row-major (row x col) input.
// Launch with at least row * col threads; surplus threads do nothing.
__global__ void mat_transpose_f32_row2col_kernel(
    float *x, float *y, const int row, const int col) {
  const int out = blockIdx.x * blockDim.x + threadIdx.x;
  if (out < row * col) {
    const int src_row = out % row;
    const int src_col = out / row;
    y[out] = x[src_row * col + src_col];
  }
}
// Flat 1-D transpose, col2row flavour, 4 elements per thread.
// Thread t owns x[4t .. 4t+3] of the row-major (row x col) input.
// - Load: one float4 when all four elements exist; otherwise scalar loads of
//   the remaining tail.
// - Store: each element goes to y[c][r] of the row-major (col x row) output.
//   When the four elements lie in one x row (always true if col % 4 == 0),
//   the row/column are computed once. Otherwise each element's (r, c) is
//   recomputed, so groups that straddle two rows are not dropped.
// Requirements:
// - Launch with at least ceil(row * col / 4) threads.
// - x must be 16-byte aligned (float4 load).
__global__ void mat_transpose_f32x4_col2row_kernel(
    float *x, float *y, const int row, const int col) {
  const int total = row * col;
  const int base = (blockIdx.x * blockDim.x + threadIdx.x) * 4;  // first x index
  if (base >= total) {
    return;
  }

  float vals[4];
  if (base + 3 < total) {
    const float4 v = reinterpret_cast<const float4 *>(x)[base / 4];
    vals[0] = v.x;
    vals[1] = v.y;
    vals[2] = v.z;
    vals[3] = v.w;
  } else {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      vals[k] = (base + k < total) ? x[base + k] : 0.0f;
    }
  }

  const int r0 = base / col;
  const int c0 = base % col;
  if (c0 + 3 < col) {
    // All four elements share x row r0 (implies base + 3 < total).
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      y[(c0 + k) * row + r0] = vals[k];
    }
  } else {
    // Group wraps into the next x row or runs past the end: map each element.
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      const int e = base + k;
      if (e < total) {
        y[(e % col) * row + e / col] = vals[k];
      }
    }
  }
}
// Flat 1-D transpose, row2col flavour, 4 elements per thread.
// Thread t owns y[4t .. 4t+3] of the row-major (col x row) output.
// - Gather: the values come from x[r][c] of the row-major (row x col) input.
//   When the four outputs come from one x column (r0 + 3 < row), the indices
//   are computed once. Otherwise (row % 4 != 0, group wraps to the next
//   column) each element is mapped individually, so x is never read out of
//   bounds.
// - Store: one float4 when all four outputs exist; otherwise scalar tail
//   stores.
// Requirements:
// - Launch with at least ceil(row * col / 4) threads.
// - y must be 16-byte aligned (float4 store).
__global__ void mat_transpose_f32x4_row2col_kernel(
    float *x, float *y, const int row, const int col) {
  const int total = row * col;
  const int base = (blockIdx.x * blockDim.x + threadIdx.x) * 4;  // first y index
  if (base >= total) {
    return;
  }

  float vals[4];
  const int c0 = base / row;  // y row    == x column
  const int r0 = base % row;  // y column == x row
  if (r0 + 3 < row) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      vals[k] = x[(r0 + k) * col + c0];
    }
  } else {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      const int e = base + k;
      vals[k] = (e < total) ? x[(e % row) * col + e / row] : 0.0f;
    }
  }

  if (base + 3 < total) {
    reinterpret_cast<float4 *>(y)[base / 4] =
        make_float4(vals[0], vals[1], vals[2], vals[3]);
  } else {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      if (base + k < total) {
        y[base + k] = vals[k];
      }
    }
  }
}
// work for row == col
__global__ void mat_transpose_f32_diagonal2d_kernel(
float *x, float *y, int row, int col) {
const int block_y = blockIdx.x;
const int block_x = (blockIdx.x + blockIdx.y) % gridDim.x;
const int global_col = threadIdx.x + blockDim.x * block_x;
const int global_row = threadIdx.y + blockDim.y * block_y;
if (global_col < col && global_row < row) {
y[global_row * col + global_col] = x[global_col * row + global_row];
}
}
// 2-D transpose, col2row flavour: one element per thread.
// threadIdx.x walks the x columns, so loads of the row-major (row x col) x
// coalesce; each value is stored to y[c][r] of the row-major (col x row) y.
// Launch with grid.x * blockDim.x >= col and grid.y * blockDim.y >= row.
__global__ void mat_transpose_f32_col2row2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int c = threadIdx.x + blockIdx.x * blockDim.x;
  const int r = threadIdx.y + blockIdx.y * blockDim.y;
  const bool inside = (c < col) && (r < row);
  if (!inside) {
    return;
  }
  y[c * row + r] = x[r * col + c];
}
// 2-D transpose, row2col flavour: one element per thread.
// Launch layout (same as the col2row2d kernel):
// - grid.x tiles the x columns (ceil(col / blockDim.x)).
// - grid.y tiles the x rows (ceil(row / blockDim.y)).
// Within a block the linear thread id is remapped so consecutive threads walk
// consecutive x rows, i.e. consecutive elements of the row-major (col x row)
// output y. Stores to y coalesce and loads from x are strided. This is the
// opposite of col2row2d, which the previous indexing merely duplicated.
// Works for any block shape and any row/col.
__global__ void mat_transpose_f32_row2col2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int block_w = blockDim.x;
  const int block_h = blockDim.y;
  const int tid = threadIdx.y * block_w + threadIdx.x;
  const int x_row = blockIdx.y * block_h + tid % block_h;  // y column
  const int x_col = blockIdx.x * block_w + tid / block_h;  // y row
  if (x_row < row && x_col < col) {
    y[x_col * row + x_row] = x[x_row * col + x_col];
  }
}
// 2-D transpose, col2row flavour, 4 consecutive x columns per thread.
// Thread (gx, gy) reads x[gy][4*gx .. 4*gx+3] and stores each value to
// y[c][gy] of the row-major (col x row) output.
// - When col % 4 == 0 the four values come from one 16-byte float4 load
//   (x must be 16-byte aligned).
// - Otherwise scalar loads are used and columns past `col` are skipped, so
//   the trailing col % 4 columns are still transposed.
// Launch with grid.x * blockDim.x * 4 >= col and grid.y * blockDim.y >= row.
__global__ void mat_transpose_f32x4_col2row2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int vec_x = blockIdx.x * blockDim.x + threadIdx.x;  // index of 4-column group
  const int r = blockIdx.y * blockDim.y + threadIdx.y;
  const int c0 = vec_x * 4;
  if (r >= row || c0 >= col) {
    return;
  }

  float vals[4];
  if (col % 4 == 0) {
    // Row starts are 16-byte aligned and c0 + 3 < col.
    const float4 v = reinterpret_cast<const float4 *>(x)[(r * col + c0) / 4];
    vals[0] = v.x;
    vals[1] = v.y;
    vals[2] = v.z;
    vals[3] = v.w;
  } else {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      vals[k] = (c0 + k < col) ? x[r * col + c0 + k] : 0.0f;
    }
  }

#pragma unroll
  for (int k = 0; k < 4; ++k) {
    if (c0 + k < col) {
      y[(c0 + k) * row + r] = vals[k];
    }
  }
}
// 2-D transpose, row2col flavour, 4 consecutive x rows per thread.
// Thread (gx, gy) reads x[4*gy .. 4*gy+3][gx] and writes them as four
// consecutive elements of y row gx (row-major (col x row) output).
// - When row % 4 == 0 the store is one 16-byte float4 (y must be 16-byte
//   aligned).
// - Otherwise scalar stores are used and rows past `row` are skipped, so the
//   trailing row % 4 rows are still transposed.
// Launch with grid.x * blockDim.x >= col and grid.y * blockDim.y * 4 >= row.
__global__ void mat_transpose_f32x4_row2col2d_kernel(
    float *x, float *y, const int row, const int col) {
  const int c = blockIdx.x * blockDim.x + threadIdx.x;         // x column == y row
  const int r0 = (blockIdx.y * blockDim.y + threadIdx.y) * 4;  // first of 4 x rows
  if (c >= col || r0 >= row) {
    return;
  }

  float vals[4];
#pragma unroll
  for (int k = 0; k < 4; ++k) {
    vals[k] = (r0 + k < row) ? x[(r0 + k) * col + c] : 0.0f;
  }

  if (row % 4 == 0) {
    // y row starts are 16-byte aligned and r0 + 3 < row.
    reinterpret_cast<float4 *>(y)[(c * row + r0) / 4] =
        make_float4(vals[0], vals[1], vals[2], vals[3]);
  } else {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      if (r0 + k < row) {
        y[c * row + r0 + k] = vals[k];
      }
    }
  }
}
// Shared-memory transpose, col2row flavour, 4 columns per thread.
// Each block stages a WARP_SIZE_S x (4 * WARP_SIZE_S) tile of x in shared
// memory. The tile covers x rows blockIdx.y*blockDim.y .. and x columns
// blockIdx.x*blockDim.x*4 .. .
//   phase 1: thread (tx, ty) loads x row ty, tile columns 4*tx .. 4*tx+3
//            (one float4 when col % 4 == 0, otherwise scalar loads);
//   phase 2: thread (tx, ty) takes tile column tx*4 + ty/STRIDE and tile rows
//            (ty%STRIDE)*4 .. +3, and writes them as 4 consecutive elements of
//            one y row (one float4 when row % 4 == 0).
// Every load and store is bounds-checked against the element it touches, and
// the barrier is outside any branch, so any row/col is handled.
// Requirements:
// - blockDim == (WARP_SIZE_S, WARP_SIZE_S).
// - grid = ceil(col / (4 * WARP_SIZE_S)) x ceil(row / WARP_SIZE_S).
// - Vector paths need x and y 16-byte aligned.
__global__ void mat_transpose_f32x4_shared_col2row2d_kernel(
    float *x, float *y, const int row, const int col){
  constexpr int STRIDE = WARP_SIZE_S / 4;
  __shared__ __align__(16) float tile[WARP_SIZE_S][WARP_SIZE_S * 4];
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
  const int tile_row0 = blockIdx.y * blockDim.y;      // first x row of the tile
  const int tile_col0 = blockIdx.x * blockDim.x * 4;  // first x column of the tile

  // Phase 1: global -> shared.
  const int in_row = tile_row0 + local_y;
  const int in_col = tile_col0 + local_x * 4;
  if (in_row < row && in_col < col) {
    if (col % 4 == 0) {
      // in_col is a multiple of 4, so in_col + 3 < col and the load is aligned.
      reinterpret_cast<float4 *>(&tile[local_y][local_x * 4])[0] =
          reinterpret_cast<const float4 *>(x)[(in_row * col + in_col) / 4];
    } else {
#pragma unroll
      for (int k = 0; k < 4; ++k) {
        if (in_col + k < col) {
          tile[local_y][local_x * 4 + k] = x[in_row * col + in_col + k];
        }
      }
    }
  }
  // Outside any branch: every thread of the block must arrive here.
  __syncthreads();

  // Phase 2: shared -> global.
  const int src_col = local_x * 4 + local_y / STRIDE;  // tile column == x column offset
  const int src_row = (local_y % STRIDE) * 4;          // first of 4 tile rows
  const int out_row = tile_col0 + src_col;             // y row == x column
  const int out_col = tile_row0 + src_row;             // first y column == x row
  if (out_row < col && out_col < row) {
    if (row % 4 == 0) {
      // out_col is a multiple of 4, so out_col + 3 < row and the store is aligned.
      reinterpret_cast<float4 *>(y)[(out_row * row + out_col) / 4] = make_float4(
          tile[src_row][src_col], tile[src_row + 1][src_col],
          tile[src_row + 2][src_col], tile[src_row + 3][src_col]);
    } else {
#pragma unroll
      for (int k = 0; k < 4; ++k) {
        if (out_col + k < row) {
          y[out_row * row + out_col + k] = tile[src_row + k][src_col];
        }
      }
    }
  }
}
// Shared-memory transpose, row2col flavour, 4 rows per thread.
// Each block stages a (4 * WARP_SIZE_S) x WARP_SIZE_S tile of x in shared
// memory. The tile covers x rows blockIdx.y*blockDim.y*4 .. and x columns
// blockIdx.x*blockDim.x .. .
//   phase 1: thread (tx, ty) loads tile column tx, x rows 4*ty .. 4*ty+3;
//   phase 2: thread (tx, ty) reads tile row tx*4 + ty/STRIDE and tile
//            columns (ty%STRIDE)*4 .. +3, and writes each to its y row.
// Every load and store is bounds-checked against the element it touches, and
// the barrier is outside any branch, so any row/col is handled.
// Requirements:
// - blockDim == (WARP_SIZE_S, WARP_SIZE_S).
// - grid = ceil(col / WARP_SIZE_S) x ceil(row / (4 * WARP_SIZE_S)).
__global__ void mat_transpose_f32x4_shared_row2col2d_kernel(
    float *x, float *y, const int row, const int col){
  constexpr int STRIDE = WARP_SIZE_S / 4;
  __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S];
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
  const int tile_col0 = blockIdx.x * blockDim.x;      // first x column of the tile
  const int tile_row0 = blockIdx.y * blockDim.y * 4;  // first x row of the tile

  // Phase 1: global -> shared.
  const int in_col = tile_col0 + local_x;
  const int in_row = tile_row0 + local_y * 4;
  if (in_col < col) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      if (in_row + k < row) {
        tile[local_y * 4 + k][local_x] = x[(in_row + k) * col + in_col];
      }
    }
  }
  // Outside any branch: every thread of the block must arrive here.
  __syncthreads();

  // Phase 2: shared -> global.
  const int src_row = local_x * 4 + local_y / STRIDE;  // tile row == x row offset
  const int src_col = (local_y % STRIDE) * 4;          // first of 4 tile columns
  const int out_col = tile_row0 + src_row;             // y column == x row
  const int out_row = tile_col0 + src_col;             // first y row == x column
  if (out_col < row) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      if (out_row + k < col) {
        y[(out_row + k) * row + out_col] = tile[src_row][src_col + k];
      }
    }
  }
}
// Shared-memory transpose, col2row flavour, 4 columns per thread, with each
// tile row padded by PAD floats. The padding is intended to spread the
// column-wise phase-2 reads over more banks. Because the padding breaks
// 16-byte alignment of tile rows, the tile is written with scalar stores.
// Each block stages a WARP_SIZE_S x (4 * WARP_SIZE_S) tile of x in shared
// memory. The tile covers x rows blockIdx.y*blockDim.y .. and x columns
// blockIdx.x*blockDim.x*4 .. .
//   phase 1: thread (tx, ty) loads x row ty, tile columns 4*tx .. 4*tx+3
//            (one global float4 load when col % 4 == 0, otherwise scalar);
//   phase 2: thread (tx, ty) takes tile column tx*4 + ty/STRIDE and tile rows
//            (ty%STRIDE)*4 .. +3, and writes them as 4 consecutive elements of
//            one y row (one float4 when row % 4 == 0).
// Every load and store is bounds-checked against the element it touches, and
// the barrier is outside any branch, so any row/col is handled.
// Requirements:
// - blockDim == (WARP_SIZE_S, WARP_SIZE_S).
// - grid = ceil(col / (4 * WARP_SIZE_S)) x ceil(row / WARP_SIZE_S).
// - Vector paths need x and y 16-byte aligned.
__global__ void mat_transpose_f32x4_shared_bcf_col2row2d_kernel(
    float *x, float *y, const int row, const int col){
  constexpr int STRIDE = WARP_SIZE_S / 4;
  __shared__ float tile[WARP_SIZE_S][WARP_SIZE_S * 4 + PAD];
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
  const int tile_row0 = blockIdx.y * blockDim.y;      // first x row of the tile
  const int tile_col0 = blockIdx.x * blockDim.x * 4;  // first x column of the tile

  // Phase 1: global -> shared.
  const int in_row = tile_row0 + local_y;
  const int in_col = tile_col0 + local_x * 4;
  if (in_row < row && in_col < col) {
    if (col % 4 == 0) {
      // in_col is a multiple of 4, so in_col + 3 < col and the load is aligned.
      const float4 v = reinterpret_cast<const float4 *>(x)[(in_row * col + in_col) / 4];
      tile[local_y][local_x * 4    ] = v.x;
      tile[local_y][local_x * 4 + 1] = v.y;
      tile[local_y][local_x * 4 + 2] = v.z;
      tile[local_y][local_x * 4 + 3] = v.w;
    } else {
#pragma unroll
      for (int k = 0; k < 4; ++k) {
        if (in_col + k < col) {
          tile[local_y][local_x * 4 + k] = x[in_row * col + in_col + k];
        }
      }
    }
  }
  // Outside any branch: every thread of the block must arrive here.
  __syncthreads();

  // Phase 2: shared -> global.
  const int src_col = local_x * 4 + local_y / STRIDE;  // tile column == x column offset
  const int src_row = (local_y % STRIDE) * 4;          // first of 4 tile rows
  const int out_row = tile_col0 + src_col;             // y row == x column
  const int out_col = tile_row0 + src_row;             // first y column == x row
  if (out_row < col && out_col < row) {
    if (row % 4 == 0) {
      // out_col is a multiple of 4, so out_col + 3 < row and the store is aligned.
      reinterpret_cast<float4 *>(y)[(out_row * row + out_col) / 4] = make_float4(
          tile[src_row][src_col], tile[src_row + 1][src_col],
          tile[src_row + 2][src_col], tile[src_row + 3][src_col]);
    } else {
#pragma unroll
      for (int k = 0; k < 4; ++k) {
        if (out_col + k < row) {
          y[out_row * row + out_col + k] = tile[src_row + k][src_col];
        }
      }
    }
  }
}
// Shared-memory transpose, row2col flavour, 4 rows per thread, with each tile
// row padded by PAD floats. The padding is intended to spread the phase-1
// column-wise tile writes over more banks.
// Each block stages a (4 * WARP_SIZE_S) x WARP_SIZE_S tile of x in shared
// memory. The tile covers x rows blockIdx.y*blockDim.y*4 .. and x columns
// blockIdx.x*blockDim.x .. .
//   phase 1: thread (tx, ty) loads tile column tx, x rows 4*ty .. 4*ty+3;
//   phase 2: thread (tx, ty) reads tile row tx*4 + ty/STRIDE and tile
//            columns (ty%STRIDE)*4 .. +3, and writes each to its y row.
// Every load and store is bounds-checked against the element it touches, and
// the barrier is outside any branch, so any row/col is handled.
// Requirements:
// - blockDim == (WARP_SIZE_S, WARP_SIZE_S).
// - grid = ceil(col / WARP_SIZE_S) x ceil(row / (4 * WARP_SIZE_S)).
__global__ void mat_transpose_f32x4_shared_bcf_row2col2d_kernel(
    float *x, float *y, const int row, const int col){
  constexpr int STRIDE = WARP_SIZE_S / 4;
  __shared__ float tile[WARP_SIZE_S * 4][WARP_SIZE_S + PAD];
  const int local_x = threadIdx.x;
  const int local_y = threadIdx.y;
  const int tile_col0 = blockIdx.x * blockDim.x;      // first x column of the tile
  const int tile_row0 = blockIdx.y * blockDim.y * 4;  // first x row of the tile

  // Phase 1: global -> shared.
  const int in_col = tile_col0 + local_x;
  const int in_row = tile_row0 + local_y * 4;
  if (in_col < col) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      if (in_row + k < row) {
        tile[local_y * 4 + k][local_x] = x[(in_row + k) * col + in_col];
      }
    }
  }
  // Outside any branch: every thread of the block must arrive here.
  __syncthreads();

  // Phase 2: shared -> global.
  const int src_row = local_x * 4 + local_y / STRIDE;  // tile row == x row offset
  const int src_col = (local_y % STRIDE) * 4;          // first of 4 tile columns
  const int out_col = tile_row0 + src_row;             // y column == x row
  const int out_row = tile_col0 + src_col;             // first y row == x column
  if (out_col < row) {
#pragma unroll
    for (int k = 0; k < 4; ++k) {
      if (out_row + k < col) {
        y[(out_row + k) * row + out_col] = tile[src_row][src_col + k];
      }
    }
  }
}
// TODO: may support double buffer pipeline mat transpose ?
// TODO: may support fp16 mat transpose ?
// --------------------- PyTorch bindings for custom kernel -----------------------
// Turns a macro argument into a string literal (used for Python-visible names).
#define STRINGFY(str) #str
// Registers host function `func` with pybind11 under its own name, using the
// name as its docstring too. Expands to a complete `m.def(...);` statement,
// so call sites need no trailing semicolon. Requires a module handle named `m`
// in scope (as provided inside PYBIND11_MODULE).
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
// If tensor T's dtype differs from th_type, prints T's options to stdout and
// throws std::runtime_error("values must be <th_type>"). Expands to a bare
// `if` statement that needs no trailing semicolon.
// NOTE(review): because it is a bare `if`, it is unsafe as the body of an
// unbraced if/else. It cannot be wrapped in do/while(0) without adding
// semicolons at every call site.
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
if (((T).options().dtype() != (th_type))) \
{ \
std::cout << "Tensor Info:" << (T).options() << std::endl; \
throw std::runtime_error("values must be " #th_type); \
}
// Defines host wrapper `void mat_transpose_<tag>(x, y)` for a 1-D-indexed kernel
// `mat_transpose_<tag>_kernel`.
// Arguments:
//   x: contiguous 2-D CUDA tensor of dtype th_type, shape (M, N).
//   y: contiguous CUDA tensor of the same dtype with M*N elements; the kernels
//      write it as row-major (N, M).
//   n_pack: elements handled per thread.
// Launch: ceil(ceil(M*N / n_pack) / WARP_SIZE) blocks of WARP_SIZE threads on
// the default stream. Empty inputs return without launching.
// Vectorized variants (n_pack > 1) require 16-byte aligned data pointers.
// Errors: bad inputs throw std::runtime_error (dtype) or c10::Error
// (TORCH_CHECK); launch failures are reported through TORCH_CHECK.
#define TORCH_BINDING_MAT_TRANSPOSE(tag, th_type, element_type, n_pack) \
void mat_transpose_##tag(torch::Tensor x, torch::Tensor y) \
{ \
  CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
  CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
  TORCH_CHECK(x.dim() == 2, "x must be a 2-D tensor"); \
  TORCH_CHECK(x.is_cuda() && y.is_cuda(), "x and y must be CUDA tensors"); \
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous"); \
  TORCH_CHECK(y.numel() == x.numel(), "y must have the same number of elements as x"); \
  TORCH_CHECK((n_pack) == 1 || \
              (reinterpret_cast<std::uintptr_t>(x.data_ptr()) % 16 == 0 && \
               reinterpret_cast<std::uintptr_t>(y.data_ptr()) % 16 == 0), \
              "vectorized transpose needs 16-byte aligned x and y"); \
  const int M = x.size(0); \
  const int N = x.size(1); \
  if (M == 0 || N == 0) return; \
  const int n_threads = (M * N + (n_pack) - 1) / (n_pack); \
  dim3 block(WARP_SIZE); \
  dim3 grid((n_threads + WARP_SIZE - 1) / WARP_SIZE); \
  mat_transpose_##tag##_kernel<<<grid, block>>>( \
    reinterpret_cast<element_type *>(x.data_ptr()), \
    reinterpret_cast<element_type *>(y.data_ptr()), M, N); \
  const cudaError_t launch_err = cudaGetLastError(); \
  TORCH_CHECK(launch_err == cudaSuccess, "mat_transpose_" #tag " launch failed: ", \
              cudaGetErrorString(launch_err)); \
}
// Defines host wrapper `void mat_transpose_<tag>2d(x, y)` for a 2-D-indexed
// kernel `mat_transpose_<tag>2d_kernel`.
// Arguments:
//   x: contiguous 2-D CUDA tensor of dtype th_type, shape (M, N).
//   y: contiguous CUDA tensor of the same dtype with M*N elements; the kernels
//      write it as row-major (N, M).
//   n_element_row / n_element_col: x rows / columns handled per thread.
// Launch: blocks of WARP_SIZE_S x WARP_SIZE_S threads; each block covers
// (WARP_SIZE_S * n_element_row) x (WARP_SIZE_S * n_element_col) elements of x.
// The grid is the ceiling of (N, M) over that tile, launched on the default
// stream. Empty inputs return without launching.
// Vectorized variants need 16-byte aligned data pointers.
// Errors: bad inputs throw std::runtime_error (dtype) or c10::Error
// (TORCH_CHECK); launch failures are reported through TORCH_CHECK.
#define TORCH_BINDING_MAT_TRANSPOSE2D(tag, th_type, element_type, n_element_row, n_element_col) \
void mat_transpose_##tag##2d(torch::Tensor x, torch::Tensor y) \
{ \
  CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
  CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
  TORCH_CHECK(x.dim() == 2, "x must be a 2-D tensor"); \
  TORCH_CHECK(x.is_cuda() && y.is_cuda(), "x and y must be CUDA tensors"); \
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous"); \
  TORCH_CHECK(y.numel() == x.numel(), "y must have the same number of elements as x"); \
  TORCH_CHECK((n_element_row) * (n_element_col) == 1 || \
              (reinterpret_cast<std::uintptr_t>(x.data_ptr()) % 16 == 0 && \
               reinterpret_cast<std::uintptr_t>(y.data_ptr()) % 16 == 0), \
              "vectorized transpose needs 16-byte aligned x and y"); \
  const int M = x.size(0); \
  const int N = x.size(1); \
  if (M == 0 || N == 0) return; \
  const int tile_cols = WARP_SIZE_S * (n_element_col); \
  const int tile_rows = WARP_SIZE_S * (n_element_row); \
  dim3 block(WARP_SIZE_S, WARP_SIZE_S); \
  dim3 grid((N + tile_cols - 1) / tile_cols, \
            (M + tile_rows - 1) / tile_rows); \
  mat_transpose_##tag##2d_kernel <<<grid, block>>>( \
    reinterpret_cast<element_type *>(x.data_ptr()), \
    reinterpret_cast<element_type *>(y.data_ptr()), M, N); \
  const cudaError_t launch_err = cudaGetLastError(); \
  TORCH_CHECK(launch_err == cudaSuccess, "mat_transpose_" #tag "2d launch failed: ", \
              cudaGetErrorString(launch_err)); \
}
// Host wrappers: flat 1-D launches (args: tag, dtype, element type, elements per thread).
TORCH_BINDING_MAT_TRANSPOSE(f32_col2row, torch::kFloat32, float, 1)
TORCH_BINDING_MAT_TRANSPOSE(f32x4_col2row, torch::kFloat32, float, 4)
TORCH_BINDING_MAT_TRANSPOSE(f32_row2col, torch::kFloat32, float, 1)
TORCH_BINDING_MAT_TRANSPOSE(f32x4_row2col, torch::kFloat32, float, 4)
// Host wrappers: 2-D launches (args: tag, dtype, element type,
// x rows per thread, x columns per thread).
// Plain global-memory kernels.
TORCH_BINDING_MAT_TRANSPOSE2D(f32_col2row, torch::kFloat32, float, 1, 1)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_col2row, torch::kFloat32, float, 1, 4)
TORCH_BINDING_MAT_TRANSPOSE2D(f32_row2col, torch::kFloat32, float, 1, 1)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_row2col, torch::kFloat32, float, 4, 1)
// Diagonal block reordering.
TORCH_BINDING_MAT_TRANSPOSE2D(f32_diagonal, torch::kFloat32, float, 1, 1)
// Shared-memory tiles, unpadded then padded ("bcf").
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_col2row, torch::kFloat32, float, 1, 4)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_row2col, torch::kFloat32, float, 4, 1)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_col2row, torch::kFloat32, float, 1, 4)
TORCH_BINDING_MAT_TRANSPOSE2D(f32x4_shared_bcf_row2col, torch::kFloat32, float, 4, 1)
// TODO: may support double buffer pipeline mat transpose ?
// TODO: may support fp16 mat transpose ?
// Python module: exposes every host wrapper under its C++ name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Flat 1-D launches.
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_col2row)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_row2col)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_col2row)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_row2col)
  // 2-D launches, plain global-memory kernels.
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_col2row2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_row2col2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_col2row2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_row2col2d)
  // Diagonal block reordering.
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32_diagonal2d)
  // Shared-memory tiles, unpadded then padded ("bcf").
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_col2row2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_row2col2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_col2row2d)
  TORCH_BINDING_COMMON_EXTENSION(mat_transpose_f32x4_shared_bcf_row2col2d)
}