Ensure all kernel launches in fbgemm are checked (pytorch#1625)
Summary:
Pull Request resolved: pytorch#1625

Uses `fbcode/caffe2/torch/testing/_internal/check_kernel_launches.py` to perform the check.
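
The enforced pattern is that every `<<<...>>>` kernel launch is immediately followed by `C10_CUDA_KERNEL_LAUNCH_CHECK()` (from `c10/cuda/CUDAException.h`), which checks `cudaGetLastError()` and raises an error right away instead of deferring the failure to a later CUDA call. A minimal sketch of the pattern, with a hypothetical `scale_kernel` standing in for the fbgemm kernels:

```cuda
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>

// Hypothetical kernel used only to illustrate the launch-check pattern.
__global__ void scale_kernel(float* data, float factor, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= factor;
  }
}

void launch_scale(float* d_data, float factor, int n) {
  constexpr int kBlockSize = 256;
  const int num_blocks = (n + kBlockSize - 1) / kBlockSize;
  scale_kernel<<<num_blocks, kBlockSize>>>(d_data, factor, n);
  // Surfaces launch failures (e.g. invalid grid/block configuration)
  // immediately after the launch they belong to.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
```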

Reviewed By: sryap

Differential Revision: D43785568

fbshipit-source-id: 4c2ed8adc25d4bf0294107037c1f2bb17762966f
r-barnes authored and facebook-github-bot committed Jun 20, 2023
1 parent 5cef9fc commit 8f1b877
Showing 10 changed files with 35 additions and 24 deletions.
8 changes: 7 additions & 1 deletion fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu
@@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

#include <c10/cuda/CUDAException.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <curand.h>
@@ -142,7 +144,7 @@ void flush_cache(
cudaMemcpy(d_flush, flush.data(), cache_size, cudaMemcpyHostToDevice);
const unsigned num_blocks = cache_size / 512;
flush_gpu<<<num_blocks, 512>>>(d_flush, d_flush2, do_write);
cudaDeviceSynchronize();
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

int main(int argc, char* argv[]) {
@@ -207,6 +209,7 @@ int main(int argc, char* argv[]) {
auto start = std::chrono::high_resolution_clock::now();
convert_float_to_half_direct<<<num_blocks, block_size>>>(
d_f16_direct_array, d_f32_array, test_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
cudaDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
cudaError_t e = cudaGetLastError();
@@ -223,6 +226,7 @@ int main(int argc, char* argv[]) {
start = std::chrono::high_resolution_clock::now();
convert_float_to_half_bitcarry<<<num_blocks, block_size>>>(
d_f16_bitcarry_array, d_f32_array, test_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
cudaDeviceSynchronize();
end = std::chrono::high_resolution_clock::now();
e = cudaGetLastError();
@@ -239,6 +243,7 @@ int main(int argc, char* argv[]) {
start = std::chrono::high_resolution_clock::now();
convert_float_to_half_shortrand<<<num_blocks, block_size>>>(
d_f16_shortrand_array, d_f32_array, d_random_number, test_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
cudaDeviceSynchronize();
end = std::chrono::high_resolution_clock::now();
e = cudaGetLastError();
@@ -255,6 +260,7 @@ int main(int argc, char* argv[]) {
start = std::chrono::high_resolution_clock::now();
convert_float_to_half_assemblefloat<<<num_blocks, block_size>>>(
d_f16_assemblefloat_array, d_f32_array, d_random_number, test_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
cudaDeviceSynchronize();
end = std::chrono::high_resolution_clock::now();
e = cudaGetLastError();
@@ -314,6 +314,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights
FixedDivisor(total_B / T)
{% endif %}
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
}
{% endfor %}
3 changes: 3 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_template.cu
@@ -349,6 +349,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
{%- endif %} // if not dense
MAKE_PTA(output, output_t, 2, 64)
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
{%- if vbe %}
output = output.reshape({-1});
{%- endif %}
@@ -395,6 +396,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
{%- endif %}
MAKE_PTA(output, output_t, 2, 64)
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
}
{%- endfor %}
@@ -434,6 +436,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else ""
{%- endif %}
MAKE_PTA(output, output_t, 2, 64)
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
{%- if not dense %}
} // if (use_lxu_cache == {{ use_cache }})
6 changes: 4 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/bench_utils.cuh
@@ -7,6 +7,9 @@
*/

#pragma once

#include <c10/cuda/CUDAException.h>

#include <cuda.h>
#include <curand.h>
#include <curand_kernel.h>
@@ -33,10 +36,9 @@ void flush_cache(int cache_size_mb = 40, bool do_write = false) {
CUDA_CHECK(
cudaMemcpy(d_flush, flush.data(), cache_size, cudaMemcpyHostToDevice));
flush_gpu<<<cache_size / 512, 512>>>(d_flush, d_flush2, do_write);
C10_CUDA_KERNEL_LAUNCH_CHECK();
CUDA_CHECK(cudaFree(d_flush));
CUDA_CHECK(cudaFree(d_flush2));
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaGetLastError());
}

void generate_random_table(float* d_f32_table, unsigned size) {
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/metric_ops.cu
@@ -269,7 +269,8 @@ at::Tensor batch_auc(
num_entries, \
last_block_num_entries, \
padded_num_entries_per_block, \
num_blocks);
num_blocks); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "auc_wrapper_1", [&] {
AT_DISPATCH_ALL_TYPES_AND(
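For launches that live inside a multi-line macro, as in `metric_ops.cu` above, the check is appended inside the macro body, so the preceding line keeps its continuation backslash. A hedged sketch with a hypothetical macro:

```cuda
#include <c10/cuda/CUDAException.h>

// Hypothetical macro: the launch check becomes the last statement of the
// macro body, and the line before it keeps its trailing backslash.
#define LAUNCH_AND_CHECK(kernel, grid, block, ...) \
  kernel<<<(grid), (block)>>>(__VA_ARGS__);        \
  C10_CUDA_KERNEL_LAUNCH_CHECK();
```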
20 changes: 10 additions & 10 deletions fbgemm_gpu/src/quantize_ops.cu
@@ -777,8 +777,8 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
nrows,
ncols,
output.data_ptr<std::uint8_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
// range_tensor is used to store the range for each embedding row.
// We save range/255.0f as row scale, and use 255.0f / (range + kEpsilon) to
@@ -816,8 +816,8 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
ncols,
output.data_ptr<std::uint8_t>(),
range_tensor.data_ptr<float>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
{
@@ -836,8 +836,8 @@ Tensor _float_to_fused8bitrowwise_gpu_t(const Tensor& input) {
nrows,
ncols,
output.data_ptr<std::uint8_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
@@ -912,8 +912,8 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
ncols,
output.data_ptr<std::uint8_t>(),
forward);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
// range_tensor is used to store the range for each embedding row.
// We save max_pos/max_val(rowwise) as row scale to quantize
@@ -953,8 +953,8 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
output.data_ptr<std::uint8_t>(),
range_tensor.data_ptr<float>(),
forward);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
{
@@ -974,8 +974,8 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
ncols,
output.data_ptr<std::uint8_t>(),
forward);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
@@ -1108,8 +1108,8 @@ Tensor _fused8bitrowwise_to_float_gpu_t(const Tensor& input) {
nrows,
ncols,
output.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
@@ -1199,8 +1199,8 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
ncols,
output.data_ptr<scalar_t>(),
forward);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
@@ -1431,8 +1431,8 @@ Tensor _float_to_fusednbitrowwise_gpu_t(
nrows,
ncols,
output.data_ptr<std::uint8_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
@@ -1518,8 +1518,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
nrows,
ncols,
output.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return output;
}
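For launches wrapped in `AT_DISPATCH_*` lambdas, as throughout `quantize_ops.cu` above, the check sits inside the lambda, directly after the `<<<...>>>` call, rather than after the lambda's closing `});`. A hedged sketch of that placement, with hypothetical kernel and function names:

```cuda
#include <cstdint>

#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <c10/cuda/CUDAException.h>

// Hypothetical kernel: narrows each element of a 2D row-major input into a
// uint8 output buffer (a placeholder, not a real rowwise quantization).
template <typename scalar_t>
__global__ void rowwise_copy_kernel(
    const scalar_t* input,
    std::uint8_t* output,
    int nrows,
    int ncols) {
  const int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < nrows) {
    for (int col = 0; col < ncols; ++col) {
      output[row * ncols + col] =
          static_cast<std::uint8_t>(input[row * ncols + col]);
    }
  }
}

void rowwise_copy(const at::Tensor& input, at::Tensor& output) {
  const int nrows = input.size(0);
  const int ncols = input.size(1);
  constexpr int kThreads = 256;
  const int blocks = (nrows + kThreads - 1) / kThreads;
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rowwise_copy", [&] {
    rowwise_copy_kernel<scalar_t><<<blocks, kThreads>>>(
        input.data_ptr<scalar_t>(),
        output.data_ptr<std::uint8_t>(),
        nrows,
        ncols);
    // The launch check goes here, inside the dispatch lambda, so it runs
    // immediately after the launch it guards.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}
```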
2 changes: 2 additions & 0 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -3170,6 +3170,8 @@ zipf_cuda(const double a, const int64_t n, const int64_t seed) {
0,
at::cuda::getCurrentCUDAStream()>>>(
a, seed, y.packed_accessor64<long, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
return y;
}
14 changes: 6 additions & 8 deletions fbgemm_gpu/src/split_embeddings_cache_cuda.cu
@@ -246,9 +246,8 @@ DLL_PUBLIC void lxu_cache_flush_cuda(
.packed_accessor64<cache_t, 2, at::RestrictPtrTraits>(),
stochastic_rounding,
rng_engine_inputs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
}

namespace {
@@ -1102,8 +1101,8 @@ void lru_cache_insert_cuda(
gather_cache_stats,
uvm_cache_stats
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace
@@ -1753,8 +1752,8 @@ void lfu_update_counts_cuda(
unique_indices_count
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
lfu_state.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
constexpr int32_t kCacheSetBits = 24;
@@ -2127,8 +2126,8 @@ void lfu_cache_insert_cuda(
.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
stochastic_rounding,
rng_engine_inputs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace
@@ -2394,9 +2393,8 @@ void lfu_cache_insert_byte_cuda(
.packed_accessor64<uint8_t, 2, at::RestrictPtrTraits>(),
lfu_state.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
row_alignment);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace
@@ -2978,6 +2976,6 @@ DLL_PUBLIC void reset_weight_momentum_cuda(
buffer_ids.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
lxu_cache_locations
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
1 change: 0 additions & 1 deletion fbgemm_gpu/src/split_embeddings_utils.cu
@@ -217,7 +217,6 @@ transpose_embedding_input(
} else {
INVOKE_LINEARIZE_INDEX_KERNEL(int64_t, true);
}
{
size_t temp_storage_bytes = 0;
AT_CUDA_CHECK(
1 change: 0 additions & 1 deletion fbgemm_gpu/src/ssd_split_embeddings_cache_cuda.cu
@@ -111,7 +111,6 @@ Tensor masked_index_put_cuda(
} // lambda
);

C10_CUDA_KERNEL_LAUNCH_CHECK();
return self;
}
