Skip to content

Commit

Permalink
decouple ema and adagrad (fbgemm) (pytorch#3180)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#276

Pull Request resolved: pytorch#3180

decouple ema and adagrad (fbgemm)

Differential Revision: D63458836
  • Loading branch information
minhua-chen authored and facebook-github-bot committed Sep 27, 2024
1 parent b0e69ca commit f7a22b6
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 268 deletions.
2 changes: 0 additions & 2 deletions fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ set(GPU_ONLY_OPTIMIZERS
lamb
partial_rowwise_adam
partial_rowwise_lamb
ensemble_rowwise_adagrad
lars_sgd
none
rowwise_adagrad_with_counter)
Expand All @@ -87,7 +86,6 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
set(VBE_OPTIMIZERS
rowwise_adagrad
rowwise_adagrad_with_counter
ensemble_rowwise_adagrad
sgd
dense)

Expand Down
1 change: 0 additions & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ def generate() -> None:
lars_sgd(),
partial_rowwise_adam(),
partial_rowwise_lamb(),
ensemble_rowwise_adagrad(),
rowwise_adagrad(),
approx_rowwise_adagrad(),
rowwise_adagrad_with_weight_decay(),
Expand Down
121 changes: 0 additions & 121 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,127 +1020,6 @@ def adam() -> Dict[str, Any]:
}


def ensemble_rowwise_adagrad() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
"""
split_precomputation += generate_optimized_grad_sum_loop_access(
"""
const float4* grad = &{grad_vec}.acc;
auto gx = grad->x;
auto gy = grad->y;
auto gz = grad->z;
auto gw = grad->w;
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
"""
)
split_precomputation += """
const at::acc_type<cache_t, true> g_avg_square =
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
at::acc_type<cache_t, true> multiplier;
at::acc_type<cache_t, true> coef_ema;
at::acc_type<cache_t, true> should_ema;
at::acc_type<cache_t, true> should_swap;
if (threadIdx.x == 0) {
at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0;
if (step_mode == 1) {
// row_counter[idx] tracks the number of appearances of this ID
row_counter[idx] += 1.0;
should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema);
should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap);
} else if (step_mode == 2) {
should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema);
should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap);
// row_counter[idx] records the step of last ema
if (should_ema > 0.5) {
coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema);
row_counter[idx] = iter*1.0;
}
// prev_iter[idx] records the step of last swap
if (should_swap > 0.5) {
prev_iter[idx] = iter*1.0;
}
} else {
should_ema = 0.0;
should_swap = 0.0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
coef_ema = SHFL_SYNC(coef_ema, 0);
should_ema = SHFL_SYNC(should_ema, 0);
should_swap = SHFL_SYNC(should_swap, 0);
"""

split_weight_update = """
weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x;
weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y;
weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
if (should_ema > 0.5) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w;
m_t.store(&momentum1[idx * D + d]);
}
if (should_swap > 0.5) { // slow-to-fast swap
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
weight_new.acc.x = m_t.acc.x * 1.0;
weight_new.acc.y = m_t.acc.y * 1.0;
weight_new.acc.z = m_t.acc.z * 1.0;
weight_new.acc.w = m_t.acc.w * 1.0;
}
"""

split_weight_update_cpu = "" # TODO

return {
"optimizer": "ensemble_rowwise_adagrad",
"is_prototype_optimizer": True,
"args": OptimizerArgsSet.create(
[
OptimItem(
ArgType.PLACEHOLDER_TENSOR,
"momentum1",
ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
),
OptimItem(
ArgType.PLACEHOLDER_TENSOR,
"momentum2",
ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
),
OptimItem(ArgType.TENSOR, "prev_iter"),
OptimItem(ArgType.TENSOR, "row_counter"),
OptimItem(ArgType.FLOAT, "learning_rate"),
OptimItem(ArgType.FLOAT, "eps"),
OptimItem(ArgType.FLOAT, "step_ema"),
OptimItem(ArgType.FLOAT, "step_swap"),
OptimItem(ArgType.FLOAT, "step_start"),
OptimItem(ArgType.FLOAT, "momentum"),
OptimItem(ArgType.INT, "iter"),
OptimItem(ArgType.INT, "step_mode"),
]
),
"split_precomputation": split_precomputation,
"split_weight_update": split_weight_update,
"split_post_update": "",
"split_weight_update_cpu": split_weight_update_cpu,
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
"has_ssd_support": False,
}


def partial_rowwise_adam() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
Expand Down
4 changes: 0 additions & 4 deletions fbgemm_gpu/codegen/training/python/lookup_args.template
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,6 @@ class OptimizerArgs(NamedTuple):
eps: float
beta1: float
beta2: float
step_ema: float
step_swap: float
step_start: float
step_mode: int
weight_decay: float
weight_decay_mode: int
eta: float
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,6 @@ def invoke(
{%- if "beta2" in args.split_function_arg_names %}
beta2=optimizer_args.beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=optimizer_args.step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=optimizer_args.step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=optimizer_args.step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=optimizer_args.step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=optimizer_args.weight_decay,
{%- endif %}
Expand Down Expand Up @@ -327,18 +315,6 @@ def invoke(
{%- if "beta2" in args.split_function_arg_names %}
beta2=optimizer_args.beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=optimizer_args.step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=optimizer_args.step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=optimizer_args.step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=optimizer_args.step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=optimizer_args.weight_decay,
{%- endif %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
beta2: float = 0.999,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema: float = 10000,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap: float = 10000,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start: float = 0,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode: int = 2,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay: float = 0.0,
{%- endif %}
Expand Down Expand Up @@ -130,18 +118,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
beta2=beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=weight_decay,
{%- endif %}
Expand Down Expand Up @@ -186,7 +162,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
rowwise = False
{% endif %}
{% elif state_tensor == "momentum2" %}
{% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %}
{% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %}
rowwise = True
{% else %}
rowwise = False
Expand Down Expand Up @@ -236,18 +212,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
self.beta2 = beta2
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
self.step_ema = step_ema
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
self.step_swap = step_swap
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
self.step_start = step_start
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
self.step_mode = step_mode
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
self.weight_decay = weight_decay
{%- endif %}
Expand Down
Loading

0 comments on commit f7a22b6

Please sign in to comment.