decouple ema and adagrad (fbgemm) #3180

Closed
wants to merge 1 commit into from
2 changes: 0 additions & 2 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -60,7 +60,6 @@ set(GPU_ONLY_OPTIMIZERS
lamb
partial_rowwise_adam
partial_rowwise_lamb
ensemble_rowwise_adagrad
lars_sgd
none
rowwise_adagrad_with_counter)
@@ -87,7 +86,6 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})
set(VBE_OPTIMIZERS
rowwise_adagrad
rowwise_adagrad_with_counter
ensemble_rowwise_adagrad
sgd
dense)

1 change: 0 additions & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -335,7 +335,6 @@ def generate() -> None:
lars_sgd(),
partial_rowwise_adam(),
partial_rowwise_lamb(),
ensemble_rowwise_adagrad(),
rowwise_adagrad(),
approx_rowwise_adagrad(),
rowwise_adagrad_with_weight_decay(),
121 changes: 0 additions & 121 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1020,127 +1020,6 @@ def adam() -> Dict[str, Any]:
}


def ensemble_rowwise_adagrad() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
"""
split_precomputation += generate_optimized_grad_sum_loop_access(
"""
const float4* grad = &{grad_vec}.acc;
auto gx = grad->x;
auto gy = grad->y;
auto gz = grad->z;
auto gw = grad->w;
g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
"""
)
split_precomputation += """
const at::acc_type<cache_t, true> g_avg_square =
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
at::acc_type<cache_t, true> multiplier;
at::acc_type<cache_t, true> coef_ema;
at::acc_type<cache_t, true> should_ema;
at::acc_type<cache_t, true> should_swap;
if (threadIdx.x == 0) {
at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
momentum2[idx] = new_sum_square_grads;
multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0;
if (step_mode == 1) {
// row_counter[idx] tracks the number of appearances of this ID
row_counter[idx] += 1.0;
should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema);
should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap);
} else if (step_mode == 2) {
should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema);
should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap);
// row_counter[idx] records the step of last ema
if (should_ema > 0.5) {
coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema);
row_counter[idx] = iter*1.0;
}
// prev_iter[idx] records the step of last swap
if (should_swap > 0.5) {
prev_iter[idx] = iter*1.0;
}
} else {
should_ema = 0.0;
should_swap = 0.0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
coef_ema = SHFL_SYNC(coef_ema, 0);
should_ema = SHFL_SYNC(should_ema, 0);
should_swap = SHFL_SYNC(should_swap, 0);
"""

split_weight_update = """
weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x;
weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y;
weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
if (should_ema > 0.5) { // slow table ema
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x;
m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y;
m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z;
m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w;
m_t.store(&momentum1[idx * D + d]);
}
if (should_swap > 0.5) { // slow-to-fast swap
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
weight_new.acc.x = m_t.acc.x * 1.0;
weight_new.acc.y = m_t.acc.y * 1.0;
weight_new.acc.z = m_t.acc.z * 1.0;
weight_new.acc.w = m_t.acc.w * 1.0;
}
"""

split_weight_update_cpu = "" # TODO

return {
"optimizer": "ensemble_rowwise_adagrad",
"is_prototype_optimizer": True,
"args": OptimizerArgsSet.create(
[
OptimItem(
ArgType.PLACEHOLDER_TENSOR,
"momentum1",
ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
),
OptimItem(
ArgType.PLACEHOLDER_TENSOR,
"momentum2",
ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
),
OptimItem(ArgType.TENSOR, "prev_iter"),
OptimItem(ArgType.TENSOR, "row_counter"),
OptimItem(ArgType.FLOAT, "learning_rate"),
OptimItem(ArgType.FLOAT, "eps"),
OptimItem(ArgType.FLOAT, "step_ema"),
OptimItem(ArgType.FLOAT, "step_swap"),
OptimItem(ArgType.FLOAT, "step_start"),
OptimItem(ArgType.FLOAT, "momentum"),
OptimItem(ArgType.INT, "iter"),
OptimItem(ArgType.INT, "step_mode"),
]
),
"split_precomputation": split_precomputation,
"split_weight_update": split_weight_update,
"split_post_update": "",
"split_weight_update_cpu": split_weight_update_cpu,
"has_cpu_support": False,
"has_gpu_support": True,
"has_vbe_support": True,
"has_global_weight_decay_support": False,
"has_ssd_support": False,
}


def partial_rowwise_adam() -> Dict[str, Any]:
split_precomputation = """
at::acc_type<cache_t, true> g_local_sum_square = 0.0;
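For reference, the kernel body deleted above combined a rowwise-Adagrad step with an EMA "slow" copy of each embedding row (momentum1) and a periodic slow-to-fast swap, scheduled either by per-row appearance count (step_mode == 1) or by global iteration (step_mode == 2). Below is a minimal NumPy sketch of that per-row logic; the function name, the state dict, and the iteration argument (corresponding to iter in the kernel) are illustrative only and not part of the FBGEMM API.

import numpy as np

def ensemble_rowwise_adagrad_row_update(
    weight_row,    # fast weights for one embedding row, shape (D,)
    slow_row,      # EMA ("slow") copy of the row, i.e. momentum1[idx]
    grad_row,      # gradient for the row, shape (D,)
    state,         # per-row scalars: {"momentum2", "prev_iter", "row_counter"}
    *, learning_rate, eps, momentum,
    step_ema, step_swap, step_start, step_mode, iteration,
):
    # Hypothetical helper, not part of FBGEMM; mirrors the removed kernel's math.
    D = weight_row.shape[0]

    # Rowwise Adagrad: accumulate the mean squared gradient of the row.
    state["momentum2"] += float(np.dot(grad_row, grad_row)) / D
    multiplier = learning_rate / (np.sqrt(state["momentum2"]) + eps)

    # EMA coefficient stays 0 until the row has been seen more than step_start times.
    coef_ema = momentum if state["row_counter"] > step_start else 0.0

    if step_mode == 1:
        # row_counter tracks the number of appearances of this ID.
        state["row_counter"] += 1.0
        rc = state["row_counter"]
        should_ema = np.floor(rc / step_ema) - np.floor((rc - 1.0) / step_ema)
        should_swap = np.floor(rc / step_swap) - np.floor((rc - 1.0) / step_swap)
    elif step_mode == 2:
        # row_counter records the step of the last EMA, prev_iter the step of the last swap.
        should_ema = np.floor((iteration - state["row_counter"]) / step_ema)
        should_swap = np.floor((iteration - state["prev_iter"]) / step_swap)
        if should_ema > 0.5:
            coef_ema = coef_ema ** ((iteration - state["row_counter"]) / step_ema)
            state["row_counter"] = float(iteration)
        if should_swap > 0.5:
            state["prev_iter"] = float(iteration)
    else:
        should_ema = should_swap = 0.0

    # Fast table: plain rowwise-Adagrad step.
    weight_row = weight_row - multiplier * grad_row

    if should_ema > 0.5:
        # Slow table: EMA of the fast weights plus a gradient correction term.
        slow_row = (
            (1.0 - coef_ema) * weight_row
            + coef_ema * slow_row
            + (momentum - coef_ema) * multiplier * grad_row
        )
    if should_swap > 0.5:
        # Slow-to-fast swap: overwrite the fast row with the slow copy.
        weight_row = slow_row.copy()

    return weight_row, slow_row

In step_mode == 2 the schedule is iteration-based, so a row that is touched only occasionally has its EMA coefficient raised to the number of elapsed EMA periods ((iteration - row_counter) / step_ema) before the slow copy is updated, matching the powf call in the removed kernel.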
4 changes: 0 additions & 4 deletions fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -60,10 +60,6 @@ class OptimizerArgs(NamedTuple):
eps: float
beta1: float
beta2: float
step_ema: float
step_swap: float
step_start: float
step_mode: int
weight_decay: float
weight_decay_mode: int
eta: float
@@ -145,18 +145,6 @@ def invoke(
{%- if "beta2" in args.split_function_arg_names %}
beta2=optimizer_args.beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=optimizer_args.step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=optimizer_args.step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=optimizer_args.step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=optimizer_args.step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=optimizer_args.weight_decay,
{%- endif %}
@@ -327,18 +315,6 @@ def invoke(
{%- if "beta2" in args.split_function_arg_names %}
beta2=optimizer_args.beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=optimizer_args.step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=optimizer_args.step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=optimizer_args.step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=optimizer_args.step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=optimizer_args.weight_decay,
{%- endif %}
@@ -90,18 +90,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
beta2: float = 0.999,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema: float = 10000,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap: float = 10000,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start: float = 0,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode: int = 2,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay: float = 0.0,
{%- endif %}
@@ -130,18 +118,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
beta2=beta2,
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
step_ema=step_ema,
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
step_swap=step_swap,
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
step_start=step_start,
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
step_mode=step_mode,
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
weight_decay=weight_decay,
{%- endif %}
@@ -186,7 +162,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
rowwise = False
{% endif %}
{% elif state_tensor == "momentum2" %}
{% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %}
{% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %}
rowwise = True
{% else %}
rowwise = False
@@ -236,18 +212,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
{%- if "beta2" in args.split_function_arg_names %}
self.beta2 = beta2
{%- endif %}
{%- if "step_ema" in args.split_function_arg_names %}
self.step_ema = step_ema
{%- endif %}
{%- if "step_swap" in args.split_function_arg_names %}
self.step_swap = step_swap
{%- endif %}
{%- if "step_start" in args.split_function_arg_names %}
self.step_start = step_start
{%- endif %}
{%- if "step_mode" in args.split_function_arg_names %}
self.step_mode = step_mode
{%- endif %}
{%- if "weight_decay" in args.split_function_arg_names %}
self.weight_decay = weight_decay
{%- endif %}