Skip to content

Commit

Permalink
refactor step_mode in ensemble_rowwise_adagrad (#3137)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3137

X-link: facebookresearch/FBGEMM#230

refactor step_mode in ensemble_rowwise_adagrad

Reviewed By: q10, spcyppt

Differential Revision: D62608418
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Sep 14, 2024
1 parent a90aac1 commit 94fb88d
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,27 +1047,24 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
coef_ema = fabs(momentum);
coef_ema = momentum*1.0;
if (step_mode == 1) {
// row_counter[idx] records the number of appearances of this row
row_counter[idx] += 1.0;
should_ema = ((int64_t)round(fmod(row_counter[idx], step_ema)) == 0);
should_swap = (row_counter[idx] > step_start && (int64_t)round(fmod(row_counter[idx], step_swap)) == 0);
} else if (step_mode == 2) {
} else {
// row_counter[idx] records the step of last ema; prev_iter[idx] records the step of last swap
should_ema = ((iter*1.0 - row_counter[idx]) >= step_ema);
should_swap = (iter*1.0 > step_start && (iter*1.0 - prev_iter[idx]) >= step_swap);
if (should_ema) {
coef_ema = (momentum>0) ? powf(fabs(momentum), (iter*1.0 - row_counter[idx])/max(1.0, step_ema)) : fabs(momentum);
coef_ema = powf(momentum, (iter*1.0 - row_counter[idx])/max(1.0, step_ema));
row_counter[idx] = iter*1.0;
}
if (should_swap) {
prev_iter[idx] = iter*1.0;
}
} else {
should_ema = false;
should_swap = false;
}
}
}
multiplier = SHFL_SYNC(multiplier, 0);
coef_ema = SHFL_SYNC(coef_ema, 0);
Expand All @@ -1083,10 +1080,16 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
if (should_ema) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (fabs(momentum) - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (fabs(momentum) - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (fabs(momentum) - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (fabs(momentum) - coef_ema) * multiplier * grad.acc.w;
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w;
if (step_mode == 2) {
m_t.acc.x = m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w;
}
m_t.store(&momentum1[idx * D + d]);
}
Expand Down

0 comments on commit 94fb88d

Please sign in to comment.