Commit

2024-09-18 nightly release (e27c5e1)
pytorchbot committed Sep 18, 2024
1 parent 1738f97 commit 79ef064
Showing 6 changed files with 195 additions and 52 deletions.
@@ -103,6 +103,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
std::optional<int64_t> max_float8_D,
std::optional<int64_t> fp8_exponent_bits,
std::optional<int64_t> fp8_exponent_bias) {
if (offsets.scalar_type() != indices.scalar_type()) {
offsets = offsets.toType(indices.scalar_type());
}
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
std::vector<int64_t> max_D_list{
max_int2_D,
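
The hunk above and the one in embedding_bounds_check_host_cpu.cpp below add the same guard: coerce offsets to the indices dtype before dispatching. The following is a minimal standalone sketch of that pattern using plain ATen; the helper name normalize_offsets_dtype is illustrative and not part of FBGEMM.

#include <ATen/ATen.h>

// Illustrative sketch only; mirrors the guard added to the CPU lookup and
// bounds-check entry points. The helper name is hypothetical.
static at::Tensor normalize_offsets_dtype(
    const at::Tensor& offsets,
    const at::Tensor& indices) {
  // Callers may pass int32 offsets with int64 indices (or vice versa);
  // downstream code expects a single index type, so cast offsets to match.
  if (offsets.scalar_type() != indices.scalar_type()) {
    return offsets.toType(indices.scalar_type());
  }
  return offsets;
}

int main() {
  const auto indices = at::arange(10, at::kLong);
  const auto offsets = at::zeros({3}, at::kInt);
  const auto fixed = normalize_offsets_dtype(offsets, indices);
  TORCH_CHECK(fixed.scalar_type() == at::kLong);
  return 0;
}
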
@@ -62,7 +62,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
{%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %}

#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name_{{ emb_weight_type }} = func_name_{{ emb_weight_type }};
const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}";
#endif

#ifdef X
3 changes: 3 additions & 0 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp
@@ -49,6 +49,9 @@ void bounds_check_indices_cpu(
const std::optional<Tensor>& weights,
const std::optional<Tensor>& B_offsets,
const int64_t max_B) {
if (offsets.scalar_type() != indices.scalar_type()) {
offsets = offsets.toType(indices.scalar_type());
}
const auto vbe = B_offsets.has_value();
if (vbe) {
TENSOR_NDIM_EQUALS(B_offsets.value(), 1);
161 changes: 113 additions & 48 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
@@ -67,17 +67,26 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) {
return c;
}

template <int32_t kWorldSize>
__global__ void one_shot_all_reduce(
int32_t rank,
int32_t world_size,
int32_t flag,
std::array<int32_t*, 8> barriers,
std::array<at::BFloat16*, 8> inputs,
template <int32_t kWorldSize, bool has_acc>
#if defined(USE_ROCM)
__launch_bounds__(512)
#endif
__global__ void one_shot_all_reduce(
int32_t rank,
int32_t world_size,
int32_t flag,
std::array<int32_t*, 8> barriers,
std::array<at::BFloat16*, 8> inputs,
#if defined(USE_ROCM)
at::BFloat16* __restrict__ ar_input,
at::BFloat16* __restrict__ acc,
at::BFloat16* __restrict__ output,
#else
at::BFloat16* ar_input,
at::BFloat16* acc,
at::BFloat16* output,
int32_t N) {
#endif
int32_t N) {
// It is expensive to launch hipMemcpyAsync on ROCm
// Move data copy here. Each block copies part of input data
at::BFloat16* input = inputs[rank];
@@ -143,11 +152,11 @@ __global__ void one_shot_all_reduce(

// Sum the values from the different ranks.
bf16x8 sums;
if (acc) {
if constexpr (has_acc) {
*reinterpret_cast<uint4*>(&sums) =
*reinterpret_cast<const uint4*>(&acc[i]);
} else {
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
*reinterpret_cast<uint4*>(&sums) = uint4{0};
}

#pragma unroll kWorldSize
@@ -336,15 +345,24 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
#endif
}

template <int32_t kWorldSize>
template <int32_t kWorldSize, bool has_acc>
#if defined(USE_ROCM)
__launch_bounds__(512) __global__ void two_shot_all_reduce(
#else
__launch_bounds__(1024) __global__ void two_shot_all_reduce(
#endif
int32_t rank,
int32_t world_size,
int32_t flag,
std::array<int32_t*, 8> barriers,
std::array<at::BFloat16*, 8> inputs,
#if defined(USE_ROCM)
at::BFloat16* __restrict__ acc,
at::BFloat16* __restrict__ output,
#else
at::BFloat16* acc,
at::BFloat16* output,
#endif
int32_t N) {
int32_t N_per_rank = N / kWorldSize;
int32_t N_start = N_per_rank * rank;
@@ -374,13 +392,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
__syncthreads();

at::BFloat16* src_d[kWorldSize];
int dst_rank[kWorldSize];

#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int d_rank = (rank + ii) % kWorldSize;
src_d[ii] = inputs[d_rank];
dst_rank[ii] = d_rank;
}

// Each block accumulates the values from the different GPUs on the same
@@ -395,11 +411,12 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
}

bf16x8 sums;
if (acc) {

if constexpr (has_acc) {
*reinterpret_cast<uint4*>(&sums) =
*reinterpret_cast<const uint4*>(&acc[i + N_start]);
} else {
memset(reinterpret_cast<void*>(&sums), 0, sizeof(sums));
*reinterpret_cast<uint4*>(&sums) = uint4{0};
}

#pragma unroll kWorldSize
@@ -433,11 +450,19 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
// Gather all needed elts from other intra-node ranks
for (size_t i = threadIdx.x * 8 + blockIdx.x * blockDim.x * 8; i < N_per_rank;
i += gridDim.x * blockDim.x * 8) {
uint4 temp[kWorldSize];
#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int d_rank = (rank + ii) % kWorldSize;
int i_r = N_start + i + (d_rank - rank) * N_per_rank;
temp[ii] = reinterpret_cast<const uint4*>(&src_d[ii][i_r])[0];
}

#pragma unroll kWorldSize
for (int ii = 0; ii < kWorldSize; ++ii) {
int i_r = N_start + i + (dst_rank[ii] - rank) * N_per_rank;
*reinterpret_cast<uint4*>(&output[i_r]) =
reinterpret_cast<const uint4*>(&src_d[ii][i_r])[0];
int d_rank = (rank + ii) % kWorldSize;
int i_r = N_start + i + (d_rank - rank) * N_per_rank;
*reinterpret_cast<uint4*>(&output[i_r]) = temp[ii];
}
}
}
@@ -474,7 +499,11 @@ void one_shot_car_allreduce(
constexpr int32_t N_per_thread = 8;
constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp;
TORCH_CHECK(N % N_per_warp == 0);
#if defined(USE_ROCM)
constexpr int32_t kThreadsPerBlock = 512;
#else
constexpr int32_t kThreadsPerBlock = 1024;
#endif
constexpr int32_t kMaxBlocks = 24;

dim3 threads(0, 1, 1);
@@ -494,21 +523,37 @@ void one_shot_car_allreduce(
threads.x = threads_per_block;
}

#define X(kWorldSize) \
if (state->world_size_ == kWorldSize) { \
one_shot_all_reduce<kWorldSize> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
y.data_ptr<at::BFloat16>(), \
z ? z->data_ptr<at::BFloat16>() : nullptr, \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
#define X(kWorldSize) \
if (state->world_size_ == kWorldSize) { \
if (z) { \
one_shot_all_reduce<kWorldSize, true> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
y.data_ptr<at::BFloat16>(), \
z->data_ptr<at::BFloat16>(), \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} else { \
one_shot_all_reduce<kWorldSize, false> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
y.data_ptr<at::BFloat16>(), \
nullptr, \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} \
}

TORCH_CHECK(
@@ -520,7 +565,7 @@ void one_shot_car_allreduce(

#undef X
return;
}
} // namespace fbgemm_gpu

void two_shot_car_allreduce(
at::Tensor y_allreduce,
@@ -565,26 +610,46 @@ void two_shot_car_allreduce(
TORCH_CHECK(N_per_rank % N_per_thread == 0);
auto threads_per_rank = N_per_rank / N_per_thread;

#if defined(USE_ROCM)
constexpr int32_t kThreadsPerBlock = 512;
#else
constexpr int32_t kThreadsPerBlock = 1024;
#endif

constexpr int32_t kMaxBlocks = 24;

auto blocks = std::min<int32_t>(
cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks);

#define X(kWorldSize) \
if (state->world_size_ == kWorldSize) { \
two_shot_all_reduce<kWorldSize> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
z ? z->data_ptr<at::BFloat16>() : nullptr, \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
#define X(kWorldSize) \
if (state->world_size_ == kWorldSize) { \
if (z) { \
two_shot_all_reduce<kWorldSize, true> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
z->data_ptr<at::BFloat16>(), \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} else { \
two_shot_all_reduce<kWorldSize, false> \
<<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
state->rank_, \
state->world_size_, \
state->flag_ * state->world_size_, \
barriers, \
inputs, \
nullptr, \
y_allreduce.data_ptr<at::BFloat16>(), \
N); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
return; \
} \
}

TORCH_CHECK(
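
The car.cu hunks above fold the "is an accumulator present" decision into a template parameter (has_acc) and pick the instantiation on the host, instead of branching on a possibly-null pointer inside the kernel. Below is a simplified sketch of that dispatch pattern; the kernel and launcher names are illustrative, not the FBGEMM ones.

#include <cuda_runtime.h>
#include <cstdio>

// Simplified sketch: whether an accumulator is supplied becomes a
// compile-time template parameter, so the no-accumulator instantiation
// never reads acc[].
template <bool has_acc>
__global__ void add_with_optional_acc(
    const float* __restrict__ input,
    const float* __restrict__ acc, // nullptr when has_acc == false
    float* __restrict__ output,
    int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  float sum;
  if constexpr (has_acc) {
    sum = acc[i]; // compiled only into the has_acc = true instantiation
  } else {
    sum = 0.0f;
  }
  output[i] = sum + input[i];
}

// Host-side dispatch, analogous to the X(kWorldSize) macro choosing
// one_shot_all_reduce<kWorldSize, true/false> based on whether z is set.
void launch_add(const float* input, const float* acc, float* output, int n) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  if (acc != nullptr) {
    add_with_optional_acc<true><<<blocks, threads>>>(input, acc, output, n);
  } else {
    add_with_optional_acc<false><<<blocks, threads>>>(input, nullptr, output, n);
  }
}

int main() {
  constexpr int n = 1024;
  float *input, *acc, *output;
  cudaMallocManaged(&input, n * sizeof(float));
  cudaMallocManaged(&acc, n * sizeof(float));
  cudaMallocManaged(&output, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    input[i] = 1.0f;
    acc[i] = 2.0f;
  }
  launch_add(input, acc, output, n);      // has_acc = true path
  cudaDeviceSynchronize();
  printf("with acc: %f\n", output[0]);    // 3.0
  launch_add(input, nullptr, output, n);  // has_acc = false path
  cudaDeviceSynchronize();
  printf("without acc: %f\n", output[0]); // 1.0
  cudaFree(input);
  cudaFree(acc);
  cudaFree(output);
  return 0;
}
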
74 changes: 73 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h
@@ -9,6 +9,7 @@
#pragma once

#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Deprecated.h>
@@ -472,6 +473,53 @@ template <
using PackedTensorAccessor64 =
GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;

template <typename T>
inline at::ScalarType scalar_type_for() {
#define TYPE_CASE(U, name) \
if constexpr (std::is_same_v<T, U>) { \
return at::ScalarType::name; \
}

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_CASE)

#undef TYPE_CASE

return at::ScalarType::Undefined;
}

template <typename T>
inline void check_scalar_type(
const at::TensorBase& tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
const char* const func_name,
const char* const tensor_name
#endif
) {
const auto expected_type = scalar_type_for<T>();

TORCH_CHECK(
tensor.scalar_type() == expected_type ||
(isQIntType(tensor.scalar_type()) &&
toUnderlying(tensor.scalar_type()) == expected_type),
#ifdef FBGEMM_GPU_MEMCHECK
"[ ",
func_name,
" ]: ",
#endif
"Expected tensor ",
#ifdef FBGEMM_GPU_MEMCHECK
"'",
tensor_name,
"' ",
#endif
"to have scalar type ",
expected_type,
", but found ",
tensor.scalar_type(),
" instead!");
}

} // namespace fbgemm_gpu

#ifdef FBGEMM_GPU_MEMCHECK
@@ -521,10 +569,21 @@ pta::PackedTensorAccessor32<T, N, PtrTraits> make_packed_tensor_accessor32(
#else
const at::Tensor& tensor) {
#endif

TORCH_CHECK(
tensor.numel() <=
static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
"numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");

fbgemm_gpu::check_scalar_type<T>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
ptr_name
#endif
);

#ifdef FBGEMM_GPU_MEMCHECK
return make_generic_packed_tensor_accessor<T, N, PtrTraits, int32_t>(
tensor, ptr_name, func_name);
@@ -542,10 +601,23 @@ pta::PackedTensorAccessor64<T, N, PtrTraits> make_packed_tensor_accessor64(
const at::Tensor& tensor,
const char* const ptr_name,
const char* const func_name) {
#else
const at::Tensor& tensor) {
#endif

fbgemm_gpu::check_scalar_type<T>(
tensor
#ifdef FBGEMM_GPU_MEMCHECK
,
func_name,
ptr_name
#endif
);

#ifdef FBGEMM_GPU_MEMCHECK
return make_generic_packed_tensor_accessor<T, N, PtrTraits, int64_t>(
tensor, ptr_name, func_name);
#else
const at::Tensor& tensor) {
return tensor.packed_accessor64<T, N, PtrTraits>();
#endif
}
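
The tensor_accessor.h hunks above make the packed-accessor factories verify that the requested element type T matches the tensor's dtype. The following is a minimal usage-style sketch of that kind of guard, using the stock c10::CppTypeToScalarType mapping rather than the scalar_type_for<T>() helper the diff defines (which additionally accepts quantized types whose underlying type matches); the function name expect_dtype is illustrative.

#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>

// Minimal usage sketch (not FBGEMM code): the kind of mismatch the new
// check_scalar_type<T>() guard in make_packed_tensor_accessor32/64 catches --
// requesting an accessor whose element type disagrees with the tensor dtype.
template <typename T>
void expect_dtype(const at::Tensor& tensor) {
  // Stock mapping from a C++ type to the corresponding at::ScalarType.
  const auto expected = c10::CppTypeToScalarType<T>::value;
  TORCH_CHECK(
      tensor.scalar_type() == expected,
      "Expected tensor to have scalar type ",
      expected,
      ", but found ",
      tensor.scalar_type(),
      " instead!");
}

int main() {
  const auto t = at::zeros({4}, at::kFloat);
  expect_dtype<float>(t);       // passes
  // expect_dtype<int64_t>(t);  // would throw: Float tensor, int64_t element type
  return 0;
}
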