-
Notifications
You must be signed in to change notification settings - Fork 4
/
scan.cl
185 lines (155 loc) · 5.09 KB
/
scan.cl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Scalar element type used by the kernels in this file.
#define FPTYPE float
// 4-wide vector counterpart of FPTYPE. bottom_scan reinterprets its FPTYPE
// buffers through this type, so the two definitions must be kept in step.
#define FPVECTYPE float4
// Per-work-group sum reduction.
//
// Each work-group adds up one contiguous region of `in` and stores the total
// in isums[group id].
//
//   in    - input values
//   isums - output, one partial sum per work-group
//   n     - number of elements in `in`
//   lmem  - local scratch, indexed by local id (one slot per work-item)
__kernel void
reduce(__global const FPTYPE * in,
       __global FPTYPE * isums,
       const int n,
       __local FPTYPE * lmem)
{
    const size_t group       = get_group_id(0);
    const size_t group_count = get_num_groups(0);
    const size_t width       = get_local_size(0);
    const int tid            = get_local_id(0);

    // Region boundaries are computed in units of four elements so that they
    // coincide with the regions bottom_scan derives from n / 4 vectors.
    const int region_size = ((n / 4) / group_count) * 4;
    const int block_start = group * region_size;

    // The final work-group also absorbs whatever is left over at the tail.
    int block_stop;
    if (group == group_count - 1)
    {
        block_stop = n;
    }
    else
    {
        block_stop = block_start + region_size;
    }

    // Each work-item accumulates every width-th element of the region.
    FPTYPE sum = 0.0f;
    for (int i = block_start + tid; i < block_stop; i += width)
    {
        sum += in[i];
    }

    // Publish the private partial sum to the rest of the work-group.
    lmem[tid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree reduction over local memory with a halving stride. The barrier
    // sits outside the `if` so that every work-item reaches it.
    // Note: with a work-group size that is not a power of two, some slots
    // are never folded into lmem[0].
    unsigned int s = width / 2;
    while (s > 0)
    {
        if (tid < s)
        {
            lmem[tid] += lmem[tid + s];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
        s >>= 1;
    }

    // Work-item 0 holds the group's total.
    if (tid == 0)
    {
        isums[group] = lmem[0];
    }
}
// Work-group-wide scan of one value per work-item, using a work-inefficient
// but highly parallel Kogge-Stone style scan in local memory.
//
//   val       - this work-item's contribution
//   lmem      - local scratch; slots [0, local size) are zeroed and slots
//               [local size, 2 * local size) hold the values being scanned,
//               so it must provide 2 * local size elements
//   exclusive - 1 for an exclusive scan, 0 for an inclusive scan
//
// Contains barriers, so every work-item of the group must call it.
inline FPTYPE scanLocalMem(FPTYPE val, __local FPTYPE* lmem, int exclusive)
{
    const size_t group_width = get_local_size(0);
    const int lane = get_local_id(0);
    const int slot = lane + group_width;

    // The zeroed lower half lets the loop read lmem[slot - offset] without a
    // bounds check: low lanes simply pick up zeros.
    lmem[lane] = 0.0f;
    lmem[slot] = val;
    barrier(CLK_LOCAL_MEM_FENCE);

    // Each step adds the value `offset` slots to the left, doubling the
    // offset every time. The read and the write are separated by a barrier
    // so no work-item overwrites a slot another one still has to read.
    int offset = 1;
    while (offset < group_width)
    {
        const FPTYPE left = lmem[slot - offset];
        barrier(CLK_LOCAL_MEM_FENCE);
        lmem[slot] += left;
        barrier(CLK_LOCAL_MEM_FENCE);
        offset *= 2;
    }

    // Stepping one slot to the left turns the inclusive result into the
    // exclusive one; lane 0 then lands on the last zeroed slot.
    return lmem[slot - exclusive];
}
// In-place exclusive scan of the first n entries of isums.
//
//   isums - values to scan; entry k is replaced by the sum of entries 0..k-1
//   n     - number of valid entries in isums
//   lmem  - local scratch handed to scanLocalMem
//
// Indexing uses the local id only.
// NOTE(review): this presumably expects a launch with a single work-group of
// at least n work-items - confirm against the host code.
__kernel void
top_scan(__global FPTYPE * isums,
         const int n,
         __local FPTYPE * lmem)
{
    const size_t lid = get_local_id(0);
    const int has_entry = lid < n;

    // Work-items without an entry contribute zero. They still take part in
    // the scan because scanLocalMem contains barriers.
    const FPTYPE loaded = has_entry ? isums[lid] : 0.0f;
    const FPTYPE scanned = scanLocalMem(loaded, lmem, 1);

    if (has_entry)
    {
        isums[lid] = scanned;
    }
}
// Bottom-level scan: each work-group scans its own contiguous region of `in`,
// offsets it by the group's seed from `isums`, and writes the result to `out`.
//
//   in    - input values, read as 4-wide vectors
//   isums - per-work-group seed, read at index get_group_id(0)
//           (presumably the output of top_scan - confirm against the host code)
//   out   - output values, written as 4-wide vectors
//   n     - number of scalar elements; the code assumes n is divisible by 4
//   lmem  - local scratch handed to scanLocalMem, which indexes it up to
//           get_local_id(0) + get_local_size(0)
__kernel void
bottom_scan(__global const FPTYPE * in,
__global const FPTYPE * isums,
__global FPTYPE * out,
const int n,
__local FPTYPE * lmem)
{
// Single local slot used to broadcast the running seed to the whole group
__local FPTYPE s_seed;
// Prepare for reading 4-element vectors
// Assume n is divisible by 4
__global FPVECTYPE *in4 = (__global FPVECTYPE*) in;
__global FPVECTYPE *out4 = (__global FPVECTYPE*) out;
int n4 = n / 4; //vector type is 4 wide
// Region bounds below are expressed in vectors, not scalars
int region_size = n4 / get_num_groups(0);
int block_start = get_group_id(0) * region_size;
// Give the last block any extra elements
int block_stop = (get_group_id(0) == get_num_groups(0) - 1) ?
n4 : block_start + region_size;
// Calculate starting index for this thread/work item
int i = block_start + get_local_id(0);
// `window` starts at the same value for every work-item of the group and
// advances identically, so all of them run the loop the same number of
// times and reach the barriers inside it together.
int window = block_start;
// Seed the bottom scan with the results from the top scan (i.e. load the per
// block sums from the previous kernel)
FPTYPE seed = isums[get_group_id(0)];
// Scan multiple elements per thread
while (window < block_stop)
{
FPVECTYPE val_4;
if (i < block_stop) // Make sure we don't read out of bounds
{
val_4 = in4[i];
}
else
{
// Out-of-range work-items contribute zeros but still take part in the scan
val_4.x = 0.0f;
val_4.y = 0.0f;
val_4.z = 0.0f;
val_4.w = 0.0f;
}
// Serial scan in registers
// (inclusive prefix sum of the four components; .w ends up as their total)
val_4.y += val_4.x;
val_4.z += val_4.y;
val_4.w += val_4.z;
// ExScan sums in shared memory
// res = sum of the vector totals of all lower-numbered work-items in this window
FPTYPE res = scanLocalMem(val_4.w, lmem, 1);
// Update and write out to global memory
val_4.x += res + seed;
val_4.y += res + seed;
val_4.z += res + seed;
val_4.w += res + seed;
if (i < block_stop) // Make sure we don't write out of bounds
{
out4[i] = val_4;
}
// Next seed will be the last value
// Last thread puts seed into smem.
if (get_local_id(0) == get_local_size(0)-1) s_seed = val_4.w;
barrier(CLK_LOCAL_MEM_FENCE);
// Broadcast seed to other threads
seed = s_seed;
// Advance window
window += get_local_size(0);
i += get_local_size(0);
}
// Post-pass: each work-item copies out[tt] into out[tt + 1], where
// tt = local id + group id * (n / number of groups), and the work-item with
// tt == 0 stores 0 in out[0].
// NOTE(review): this looks like an attempt to shift the inclusive result
// written above into an exclusive one - confirm the intent.
// NOTE(review): the only barrier before this point uses CLK_LOCAL_MEM_FENCE,
// and work-item tt + 1 reads out[tt + 1] while work-item tt writes it, so
// these global loads/stores do not appear to be ordered against each other or
// against the vector stores of other work-items/work-groups - verify.
// NOTE(review): only get_local_size(0) consecutive elements per work-group
// are touched here, which need not cover the group's whole region - verify.
int group_size = n/get_num_groups(0);
int tt = get_local_id(0) + get_group_id(0) * group_size;
if(tt < n - 1)
out[tt + 1] = out[tt];
if(tt == 0)
out[0] = 0;
}