Skip to content

Commit

Permalink
use float instead of bool for should_ema and should_swap (#3158)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#253

Pull Request resolved: #3158

 Use float instead of bool type for should_ema and should_swap, as bool may not be well supported in __shfl_sync()
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-synopsis

Reviewed By: q10, spcyppt

Differential Revision: D62962745
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Sep 21, 2024
1 parent 2446c5d commit 86cf79f
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,33 +1040,33 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
at::acc_type<cache_t, true> multiplier;
at::acc_type<cache_t, true> coef_ema;
at::acc_type<bool, true> should_ema;
at::acc_type<bool, true> should_swap;
at::acc_type<cache_t, true> should_ema;
at::acc_type<cache_t, true> should_swap;
if (threadIdx.x == 0) {
at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
coef_ema = fabs(momentum);
coef_ema = fabs(momentum);
if (step_mode == 1) {
// row_counter[idx] records the number of appearances of this row
row_counter[idx] += 1.0;
should_ema = ((int64_t)round(fmod(row_counter[idx], step_ema)) == 0);
should_swap = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_swap)) == 0);
should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema);
should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap);
} else if (step_mode == 2) {
// row_counter[idx] records the step of last ema; prev_iter[idx] records the step of last swap
should_ema = ((iter*1.0 - row_counter[idx]) >= step_ema);
should_swap = (iter*1.0 > step_start && (iter*1.0 - prev_iter[idx]) >= step_swap);
if (should_ema) {
coef_ema = (momentum>0) ? powf(fabs(momentum), (iter*1.0 - row_counter[idx])/max(1.0, step_ema)) : fabs(momentum);
should_ema = floorf(iter*1.0 / step_ema) - floorf(row_counter[idx] / step_ema);
should_swap = floorf(iter*1.0 / step_swap) - floorf(prev_iter[idx] / step_swap);
if (should_ema > 0.5) {
coef_ema = (momentum > 0) ? powf(coef_ema, should_ema) : coef_ema;
row_counter[idx] = iter*1.0;
}
if (should_swap) {
if (iter*1.0 > step_start && should_swap > 0.5) {
prev_iter[idx] = iter*1.0;
}
} else {
should_ema = false;
should_swap = false;
should_ema = 0;
should_swap = 0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
Expand All @@ -1081,7 +1081,7 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
if (should_ema) { // slow table ema
if (should_ema > 0.5) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (fabs(momentum) - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (fabs(momentum) - coef_ema) * multiplier * grad.acc.y;
Expand All @@ -1090,12 +1090,12 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
m_t.store(&momentum1[idx * D + d]);
}
if (should_swap) { // slow-to-fast swap
if (iter*1.0 > step_start && should_swap > 0.5) { // slow-to-fast swap
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
weight_new.acc.x = m_t.acc.x;
weight_new.acc.y = m_t.acc.y;
weight_new.acc.z = m_t.acc.z;
weight_new.acc.w = m_t.acc.w;
weight_new.acc.x = m_t.acc.x * 1.0;
weight_new.acc.y = m_t.acc.y * 1.0;
weight_new.acc.z = m_t.acc.z * 1.0;
weight_new.acc.w = m_t.acc.w * 1.0;
}
"""

Expand Down

0 comments on commit 86cf79f

Please sign in to comment.