diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index 3b6df599d..bc7483025 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -154,8 +154,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void auto weight_row = WeightRow>( const_cast(&weights[idx_j * D_emb]), nullptr, - D, - nullptr); + D); float2 qparams; if (std::is_same::value) { qparams = weight_row.load_qparams(); @@ -174,8 +173,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void auto weight_row = WeightRow>( const_cast(&weights[idx_j * D_emb]), nullptr, - D, - nullptr); + D); float2 qparams; if (std::is_same::value) { qparams = weight_row.load_qparams(); diff --git a/fbgemm_gpu/codegen/embedding_common_code_generator.py b/fbgemm_gpu/codegen/embedding_common_code_generator.py index eb8476ed2..5ed4957af 100644 --- a/fbgemm_gpu/codegen/embedding_common_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_common_code_generator.py @@ -1090,7 +1090,7 @@ def lamb() -> Dict[str, Any]: split_precomputation = """ at::acc_type weight_sum_sq = 0.0; at::acc_type rtw_sum_sq = 0.0; - auto weight_row = WeightRow>(weights, cache_weights, D, nullptr); + auto weight_row = WeightRow>(weights, cache_weights, D); float2 qparams; if (std::is_same::value && !cache_weights) { qparams = weight_row.load_qparams(); @@ -1187,7 +1187,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]: at::acc_type weight_sum_sq = 0.0; at::acc_type rtw_sum_sq = 0.0; - auto weight_row = WeightRow>(weights, cache_weights, D, nullptr); + auto weight_row = WeightRow>(weights, cache_weights, D); float2 qparams; if (std::is_same::value && !cache_weights) { qparams = weight_row.load_qparams(); @@ -1375,7 +1375,7 @@ def lars_sgd() -> Dict[str, Any]: at::acc_type weight_sum_sq = 0.0; at::acc_type grad_sum_sq = 0.0; - auto weight_row = WeightRow>(weights, cache_weights, D, nullptr); + auto weight_row = WeightRow>(weights, cache_weights, D); float2 qparams; if (std::is_same::value && !cache_weights) { qparams = weight_row.load_qparams(); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu index d7164f02e..8bc946ab1 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu @@ -153,8 +153,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( auto weight_row_emb = WeightRow( const_cast(&weights[idx_j * D_emb]), nullptr, - D, - nullptr); + D); [[maybe_unused]] float2 qparams_emb; if (std::is_same::value) { qparams_emb = weight_row_emb.load_qparams(); @@ -166,8 +165,7 @@ batch_index_select_dim0_codegen_forward_small_kernel( auto weight_row_cache = WeightRow( const_cast(&weights[idx_j * D_emb]), const_cast(&lxu_cache_weights[cache_idx_j][0]), - D, - nullptr); + D); Vec4T weight = weight_row_cache.load(d, qparams_cache); weight.store(&output[output_j][d]); } else { diff --git a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh index 5bdfa415c..69ea09516 100644 --- a/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/embedding_optimizer_split_device_kernel_template.cuh @@ -73,15 +73,16 @@ DEVICE_INLINE void split_{{ optimizer }}_table_update_kernel( struct SharedMemory>> weight_update_buffer; Vec4T>* shared_weight_update_row = is_int8 ? weight_update_buffer.getPointer() : nullptr; + + StochasticRoundingRNGState state; auto weight_row_template = WeightRow>( - weights, cache_weights, D, nullptr); - - weight_row_template.set_stochastic_rounding( - stochastic_rounding, - stochastic_rounding_philox_args, - threadIdx.x + run_id * blockDim.x - ); + weights, + cache_weights, + D, + stochastic_rounding ? &state : nullptr, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); float2 qparams_template; if (is_int8 && !cache_weights) { diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index ffa4db66f..2b9860888 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -1334,15 +1334,41 @@ template // TODO: pass in dimension info and calculate qparams for rowwise integer // quantization struct WeightRow { + // Constructor for no stochastic rounding + DEVICE_INLINE WeightRow(emb_t* row, cache_t* cache_row, int dim) + : row_(row), + cache_row_(cache_row), + dim_(dim), + stoc_rounding_state_(nullptr) {} + + // Constructor for stochastic rounding DEVICE_INLINE WeightRow( emb_t* row, cache_t* cache_row, int dim, - StochasticRoundingRNGState* stoc_rounding_state) - : row_(row), - cache_row_(cache_row), - dim_(dim), - stoc_rounding_state_(stoc_rounding_state) {} + StochasticRoundingRNGState* stoc_rounding_state, + const at::PhiloxCudaState* stochastic_rounding_philox_args, + const uint64_t salt_value) + : row_(row), cache_row_(cache_row), dim_(dim) { + // Set the internal stoc_rounding_state_ + stoc_rounding_state_ = stoc_rounding_state; + + if constexpr (!std::is_same_v) { + if (stoc_rounding_state != nullptr) { + const auto stochastic_rounding_seeds = + at::cuda::philox::unpack(*stochastic_rounding_philox_args); + + stochastic_rounding_init( + std::get<0>(stochastic_rounding_seeds) ^ + std::get<1>(stochastic_rounding_seeds), + // The salt value should be different for every *run* and every + // *thread*. + salt_value, + stoc_rounding_state); + } + } + } + emb_t* row_; cache_t* cache_row_; int dim_; @@ -1466,30 +1492,6 @@ struct WeightRow { evict_cache(d, qparams); } } - - DEVICE_INLINE void set_stochastic_rounding( - const bool stochastic_rounding, - const at::PhiloxCudaState stochastic_rounding_philox_args, - const uint64_t salt_value) { - if constexpr (!std::is_same_v) { - if (stochastic_rounding) { - StochasticRoundingRNGState state; - const auto stochastic_rounding_seeds = - at::cuda::philox::unpack(stochastic_rounding_philox_args); - - stochastic_rounding_init( - std::get<0>(stochastic_rounding_seeds) ^ - std::get<1>(stochastic_rounding_seeds), - // The salt value should be different for every *run* and every - // *thread*. - salt_value, - &state); - - // Set the internal stoc_rounding_state_ - stoc_rounding_state_ = &state; - } - } - } }; template diff --git a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu index ba6648c42..d8a913cb4 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu @@ -116,15 +116,13 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } + StochasticRoundingRNGState state; auto weight_row = WeightRow( &weights[weights_offset_current + idx_current * D_emb + 0], &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], D_current, - nullptr); - - weight_row.set_stochastic_rounding( - stochastic_rounding, - stochastic_rounding_philox_args, + stochastic_rounding ? &state : nullptr, + &stochastic_rounding_philox_args, (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x) * kWarpSize + @@ -142,8 +140,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( auto weight_row_emb = WeightRow( &weights[weights_offset_insert + idx_insert * D_emb + 0], nullptr, - D_insert, - nullptr); + D_insert); weight_row_emb.warp_copy_to_cache( &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu index a0f5f715b..1087f64ae 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu @@ -78,6 +78,9 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( BitonicSort>::sort(costs, slots); const int32_t sorted_slot = slots[0]; const int64_t sorted_lru_cost = costs[0]; + const auto stoc_rounding_salt = kWarpSize * + (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + + threadIdx.x); for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { const int32_t insert_slot = shfl_sync(sorted_slot, l); @@ -120,19 +123,14 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( D_emb += kINT8QparamsBytes; } + StochasticRoundingRNGState state; auto weight_row = WeightRow( &weights[weights_offset_current + idx_current * D_emb + 0], &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], D_current, - nullptr); - - weight_row.set_stochastic_rounding( - stochastic_rounding, - stochastic_rounding_philox_args, - (blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x) * - kWarpSize + - l); + stochastic_rounding ? &state : nullptr, + &stochastic_rounding_philox_args, + stoc_rounding_salt + l); weight_row.warp_evict_cache(D_current, blockDim.x, threadIdx.x); } @@ -145,8 +143,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( auto weight_row_emb = WeightRow( &weights[weights_offset_insert + idx_insert * D_emb + 0], nullptr, - D_insert, - nullptr); + D_insert); weight_row_emb.warp_copy_to_cache( &lxu_cache_weights[cache_set * kWarpSize + insert_slot][0], diff --git a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu index cf5c6778b..266bf7ed6 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu @@ -55,15 +55,13 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( if constexpr (std::is_same_v) { D_emb += kINT8QparamsBytes; } + StochasticRoundingRNGState state; auto weight_row = WeightRow>( &weights[weights_offset_current + idx_current * D_emb + 0], &lxu_cache_weights[b][0], D_current, - nullptr); - - weight_row.set_stochastic_rounding( - stochastic_rounding, - stochastic_rounding_philox_args, + stochastic_rounding ? &state : nullptr, + &stochastic_rounding_philox_args, blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x); diff --git a/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu b/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu index a4e6c6bd6..e60840473 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu @@ -179,7 +179,7 @@ __global__ __launch_bounds__(kMaxThreads) void reset_weight_momentum_kernel( auto weight_row_template = WeightRow>( - weights, cache_weights, D, nullptr); + weights, cache_weights, D); // reset momentum1 const int32_t d = (i % chunk4s_per_row) * 4;