Skip to content

Commit

Permalink
avoid early NE bump by setting coef_ema=0 (#3161)
Browse files (browse the repository at this point in the history)
Summary:
Pull Request resolved: #3161

X-link: facebookresearch/FBGEMM#257

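Refactor ensemble_rowwise_adagrad: coef_ema is now 0 until row_counter[idx] exceeds step_start (otherwise it equals momentum), the separate momentum > 0 branch in step_mode 2 is removed, and the slow-to-fast swap is gated only on should_swap.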

Reviewed By: csmiler

Differential Revision: D63238676

fbshipit-source-id: e49491f742aa601cc44a16fd77bc02e573897041
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Sep 24, 2024
1 parent 03537c6 commit 825e1a3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
31 changes: 13 additions & 18 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,27 +1047,22 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
coef_ema = fabs(momentum);
coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0;
if (step_mode == 1) {
// row_counter[idx] records the number of appearances of this row
// row_counter[idx] tracks the number of appearances of this ID
row_counter[idx] += 1.0;
should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema);
should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap);
} else if (step_mode == 2) {
// row_counter[idx] records the step of last ema; prev_iter[idx] records the step of last swap
if (momentum > 0) {
should_ema = floorf(iter*1.0 / step_ema) - floorf(row_counter[idx] / step_ema);
should_swap = floorf(iter*1.0 / step_swap) - floorf(prev_iter[idx] / step_swap);
coef_ema = (should_ema > 0.5) ? powf(coef_ema, should_ema) : coef_ema;
} else {
should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema);
should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap);
coef_ema = (should_ema > 0.5) ? powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema) : coef_ema;
}
should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema);
should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap);
// row_counter[idx] records the step of last ema
if (should_ema > 0.5) {
coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema);
row_counter[idx] = iter*1.0;
}
if (iter*1.0 > step_start && should_swap > 0.5) {
// prev_iter[idx] records the step of last swap
if (should_swap > 0.5) {
prev_iter[idx] = iter*1.0;
}
} else {
Expand All @@ -1089,14 +1084,14 @@ def ensemble_rowwise_adagrad() -> Dict[str, Any]:
if (should_ema > 0.5) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (fabs(momentum) - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (fabs(momentum) - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (fabs(momentum) - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (fabs(momentum) - coef_ema) * multiplier * grad.acc.w;
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w;
m_t.store(&momentum1[idx * D + d]);
}
if (iter*1.0 > step_start && should_swap > 0.5) { // slow-to-fast swap
if (should_swap > 0.5) { // slow-to-fast swap
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
weight_new.acc.x = m_t.acc.x * 1.0;
weight_new.acc.y = m_t.acc.y * 1.0;
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def execute_backward_optimizers_( # noqa C901
1e-4,
1.0,
1.0,
0.0,
-1.0,
StepMode.USE_ITER,
0.8,
)
Expand Down

0 comments on commit 825e1a3

Please sign in to comment.