diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index 27b4ec884..2ae401ea1 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -60,7 +60,6 @@ set(GPU_ONLY_OPTIMIZERS
     lamb
     partial_rowwise_adam
     partial_rowwise_lamb
-    ensemble_rowwise_adagrad
     lars_sgd
     none
     rowwise_adagrad_with_counter)
@@ -87,7 +86,6 @@ set(GPU_OPTIMIZERS ${COMMON_OPTIMIZERS} ${GPU_ONLY_OPTIMIZERS})

 set(VBE_OPTIMIZERS
     rowwise_adagrad
     rowwise_adagrad_with_counter
-    ensemble_rowwise_adagrad
     sgd
     dense)
diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
index afdcb8b3c..ac37444ff 100644
--- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -335,7 +335,6 @@ def generate() -> None:
         lars_sgd(),
         partial_rowwise_adam(),
         partial_rowwise_lamb(),
-        ensemble_rowwise_adagrad(),
         rowwise_adagrad(),
         approx_rowwise_adagrad(),
         rowwise_adagrad_with_weight_decay(),
diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py
index acf2af31f..15c100ed5 100644
--- a/fbgemm_gpu/codegen/genscript/optimizers.py
+++ b/fbgemm_gpu/codegen/genscript/optimizers.py
@@ -1020,127 +1020,6 @@ def adam() -> Dict[str, Any]:
     }


-def ensemble_rowwise_adagrad() -> Dict[str, Any]:
-    split_precomputation = """
-    at::acc_type<cache_t, true> g_local_sum_square = 0.0;
-    """
-    split_precomputation += generate_optimized_grad_sum_loop_access(
-        """
-        const float4* grad = &{grad_vec}.acc;
-        auto gx = grad->x;
-        auto gy = grad->y;
-        auto gz = grad->z;
-        auto gw = grad->w;
-        g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
-        """
-    )
-    split_precomputation += """
-    const at::acc_type<cache_t, true> g_avg_square =
-        GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
-
-    at::acc_type<cache_t, true> multiplier;
-    at::acc_type<cache_t, true> coef_ema;
-    at::acc_type<cache_t, true> should_ema;
-    at::acc_type<cache_t, true> should_swap;
-    if (threadIdx.x == 0) {
-        at::acc_type<cache_t, true> new_sum_square_grads = momentum2[idx] + g_avg_square;
-        momentum2[idx] = new_sum_square_grads;
-        multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
-
-        coef_ema = (row_counter[idx] > step_start) ? (momentum*1.0) : 0.0;
-        if (step_mode == 1) {
-            // row_counter[idx] tracks the number of appearances of this ID
-            row_counter[idx] += 1.0;
-            should_ema = floorf(row_counter[idx] / step_ema) - floorf((row_counter[idx]-1.0) / step_ema);
-            should_swap = floorf(row_counter[idx] / step_swap) - floorf((row_counter[idx]-1.0) / step_swap);
-        } else if (step_mode == 2) {
-            should_ema = floorf((iter*1.0 - row_counter[idx]) / step_ema);
-            should_swap = floorf((iter*1.0 - prev_iter[idx]) / step_swap);
-            // row_counter[idx] records the step of last ema
-            if (should_ema > 0.5) {
-                coef_ema = powf(coef_ema, (iter*1.0 - row_counter[idx]) / step_ema);
-                row_counter[idx] = iter*1.0;
-            }
-            // prev_iter[idx] records the step of last swap
-            if (should_swap > 0.5) {
-                prev_iter[idx] = iter*1.0;
-            }
-        } else {
-            should_ema = 0.0;
-            should_swap = 0.0;
-        }
-    }
-    multiplier = SHFL_SYNC(multiplier, 0);
-    coef_ema = SHFL_SYNC(coef_ema, 0);
-    should_ema = SHFL_SYNC(should_ema, 0);
-    should_swap = SHFL_SYNC(should_swap, 0);
-    """
-
-    split_weight_update = """
-      weight_new.acc.x = weight_new.acc.x - multiplier * grad.acc.x;
-      weight_new.acc.y = weight_new.acc.y - multiplier * grad.acc.y;
-      weight_new.acc.z = weight_new.acc.z - multiplier * grad.acc.z;
-      weight_new.acc.w = weight_new.acc.w - multiplier * grad.acc.w;
-
-      if (should_ema > 0.5) { // slow table ema
-        Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
-        m_t.acc.x = (1.0 - coef_ema) * weight_new.acc.x + coef_ema * m_t.acc.x + (momentum - coef_ema) * multiplier * grad.acc.x;
-        m_t.acc.y = (1.0 - coef_ema) * weight_new.acc.y + coef_ema * m_t.acc.y + (momentum - coef_ema) * multiplier * grad.acc.y;
-        m_t.acc.z = (1.0 - coef_ema) * weight_new.acc.z + coef_ema * m_t.acc.z + (momentum - coef_ema) * multiplier * grad.acc.z;
-        m_t.acc.w = (1.0 - coef_ema) * weight_new.acc.w + coef_ema * m_t.acc.w + (momentum - coef_ema) * multiplier * grad.acc.w;
-        m_t.store(&momentum1[idx * D + d]);
-      }
-
-      if (should_swap > 0.5) { // slow-to-fast swap
-        Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
-        weight_new.acc.x = m_t.acc.x * 1.0;
-        weight_new.acc.y = m_t.acc.y * 1.0;
-        weight_new.acc.z = m_t.acc.z * 1.0;
-        weight_new.acc.w = m_t.acc.w * 1.0;
-      }
-    """
-
-    split_weight_update_cpu = ""  # TODO
-
-    return {
-        "optimizer": "ensemble_rowwise_adagrad",
-        "is_prototype_optimizer": True,
-        "args": OptimizerArgsSet.create(
-            [
-                OptimItem(
-                    ArgType.PLACEHOLDER_TENSOR,
-                    "momentum1",
-                    ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
-                ),
-                OptimItem(
-                    ArgType.PLACEHOLDER_TENSOR,
-                    "momentum2",
-                    ph_tys=[ArgType.FLOAT_TENSOR, ArgType.BFLOAT16_TENSOR],
-                ),
-                OptimItem(ArgType.TENSOR, "prev_iter"),
-                OptimItem(ArgType.TENSOR, "row_counter"),
-                OptimItem(ArgType.FLOAT, "learning_rate"),
-                OptimItem(ArgType.FLOAT, "eps"),
-                OptimItem(ArgType.FLOAT, "step_ema"),
-                OptimItem(ArgType.FLOAT, "step_swap"),
-                OptimItem(ArgType.FLOAT, "step_start"),
-                OptimItem(ArgType.FLOAT, "momentum"),
-                OptimItem(ArgType.INT, "iter"),
-                OptimItem(ArgType.INT, "step_mode"),
-            ]
-        ),
-        "split_precomputation": split_precomputation,
-        "split_weight_update": split_weight_update,
-        "split_post_update": "",
-        "split_weight_update_cpu": split_weight_update_cpu,
-        "has_cpu_support": False,
-        "has_gpu_support": True,
-        "has_vbe_support": True,
-        "has_global_weight_decay_support": False,
-        "has_ssd_support": False,
-    }
-
-
 def partial_rowwise_adam() -> Dict[str, Any]:
     split_precomputation = """
     at::acc_type<cache_t, true> g_local_sum_square = 0.0;
diff --git a/fbgemm_gpu/codegen/training/python/lookup_args.template b/fbgemm_gpu/codegen/training/python/lookup_args.template
index 54fa11177..357aad622 100644
--- a/fbgemm_gpu/codegen/training/python/lookup_args.template
+++ b/fbgemm_gpu/codegen/training/python/lookup_args.template
@@ -60,10 +60,6 @@ class OptimizerArgs(NamedTuple):
     eps: float
     beta1: float
     beta2: float
-    step_ema: float
-    step_swap: float
-    step_start: float
-    step_mode: int
     weight_decay: float
     weight_decay_mode: int
     eta: float
diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
index e03a879cb..2f14b27de 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
@@ -145,18 +145,6 @@ def invoke(
         {%- if "beta2" in args.split_function_arg_names %}
         beta2=optimizer_args.beta2,
         {%- endif %}
-        {%- if "step_ema" in args.split_function_arg_names %}
-        step_ema=optimizer_args.step_ema,
-        {%- endif %}
-        {%- if "step_swap" in args.split_function_arg_names %}
-        step_swap=optimizer_args.step_swap,
-        {%- endif %}
-        {%- if "step_start" in args.split_function_arg_names %}
-        step_start=optimizer_args.step_start,
-        {%- endif %}
-        {%- if "step_mode" in args.split_function_arg_names %}
-        step_mode=optimizer_args.step_mode,
-        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         weight_decay=optimizer_args.weight_decay,
         {%- endif %}
@@ -327,18 +315,6 @@ def invoke(
         {%- if "beta2" in args.split_function_arg_names %}
         beta2=optimizer_args.beta2,
         {%- endif %}
-        {%- if "step_ema" in args.split_function_arg_names %}
-        step_ema=optimizer_args.step_ema,
-        {%- endif %}
-        {%- if "step_swap" in args.split_function_arg_names %}
-        step_swap=optimizer_args.step_swap,
-        {%- endif %}
-        {%- if "step_start" in args.split_function_arg_names %}
-        step_start=optimizer_args.step_start,
-        {%- endif %}
-        {%- if "step_mode" in args.split_function_arg_names %}
-        step_mode=optimizer_args.step_mode,
-        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         weight_decay=optimizer_args.weight_decay,
         {%- endif %}
diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
index b9be5cd4c..6c2380e7c 100644
--- a/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
+++ b/fbgemm_gpu/codegen/training/python/split_embedding_optimizer_codegen.template
@@ -90,18 +90,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
         {%- if "beta2" in args.split_function_arg_names %}
         beta2: float = 0.999,
         {%- endif %}
-        {%- if "step_ema" in args.split_function_arg_names %}
-        step_ema: float = 10000,
-        {%- endif %}
-        {%- if "step_swap" in args.split_function_arg_names %}
-        step_swap: float = 10000,
-        {%- endif %}
-        {%- if "step_start" in args.split_function_arg_names %}
-        step_start: float = 0,
-        {%- endif %}
-        {%- if "step_mode" in args.split_function_arg_names %}
-        step_mode: int = 2,
-        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         weight_decay: float = 0.0,
         {%- endif %}
@@ -130,18 +118,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
         {%- if "beta2" in args.split_function_arg_names %}
         beta2=beta2,
         {%- endif %}
-        {%- if "step_ema" in args.split_function_arg_names %}
-        step_ema=step_ema,
-        {%- endif %}
-        {%- if "step_swap" in args.split_function_arg_names %}
-        step_swap=step_swap,
-        {%- endif %}
-        {%- if "step_start" in args.split_function_arg_names %}
-        step_start=step_start,
-        {%- endif %}
-        {%- if "step_mode" in args.split_function_arg_names %}
-        step_mode=step_mode,
-        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         weight_decay=weight_decay,
         {%- endif %}
@@ -186,7 +162,7 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
                 rowwise = False
                 {% endif %}
             {% elif state_tensor == "momentum2" %}
-                {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb", "ensemble_rowwise_adagrad"] %}
+                {% if optimizer in ["partial_rowwise_adam", "partial_rowwise_lamb"] %}
                 rowwise = True
                 {% else %}
                 rowwise = False
@@ -236,18 +212,6 @@ class SplitEmbedding{{ optimizer_class_name }}(Optimizer):
         {%- if "beta2" in args.split_function_arg_names %}
         self.beta2 = beta2
         {%- endif %}
-        {%- if "step_ema" in args.split_function_arg_names %}
-        self.step_ema = step_ema
-        {%- endif %}
-        {%- if "step_swap" in args.split_function_arg_names %}
-        self.step_swap = step_swap
-        {%- endif %}
-        {%- if "step_start" in args.split_function_arg_names %}
-        self.step_start = step_start
-        {%- endif %}
-        {%- if "step_mode" in args.split_function_arg_names %}
-        self.step_mode = step_mode
-        {%- endif %}
         {%- if "weight_decay" in args.split_function_arg_names %}
         self.weight_decay = weight_decay
         {%- endif %}
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 554bd0b00..9b202c65c 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -145,6 +145,14 @@ class GlobalWeightDecayDefinition:
     lower_bound: float = 0.0


+@dataclass(frozen=True)
+class EnsembleModeDefinition:
+    step_ema: float = 10000
+    step_swap: float = 10000
+    step_start: float = 0
+    step_mode: StepMode = StepMode.USE_ITER
+
+
 # Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
 class UVMCacheStatsIndex(enum.IntEnum):
     num_calls = 0
@@ -473,14 +481,8 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

         beta2 (float = 0.999): The beta2 value used by LAMB and ADAM

-        step_ema (float = 10000): Used by ENSEMBLE_ROWWISE_ADAGRAD
-
-        step_swap (float = 10000): Used by ENSEMBLE_ROWWISE_ADAGRAD
-
-        step_start (float = 0.0): Used by ENSEMBLE_ROWWISE_ADAGRAD
-
-        step_mode: (StepMode = StepMode.USE_ITER): Used by
-            ENSEMBLE_ROWWISE_ADAGRAD
+        ensemble_mode (Optional[EnsembleModeDefinition] = None):
+            Used by Ensemble Rowwise Adagrad

         counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None):
             Used by Rowwise Adagrad
@@ -598,10 +600,7 @@ def __init__(  # noqa C901
         eta: float = 0.001,
         beta1: float = 0.9,
         beta2: float = 0.999,
-        step_ema: float = 10000,
-        step_swap: float = 10000,
-        step_start: float = 0,
-        step_mode: StepMode = StepMode.USE_ITER,
+        ensemble_mode: Optional[EnsembleModeDefinition] = None,
         counter_based_regularization: Optional[
             CounterBasedRegularizationDefinition
         ] = None,
@@ -920,6 +919,9 @@ def __init__(  # noqa C901
             self.gwd_start_iter: int = global_weight_decay.start_iter
             self.gwd_lower_bound: float = global_weight_decay.lower_bound

+        self.ensemble_mode: EnsembleModeDefinition = (
+            ensemble_mode if ensemble_mode is not None else EnsembleModeDefinition()
+        )
         if counter_based_regularization is None:
             counter_based_regularization = CounterBasedRegularizationDefinition()
         if cowclip_regularization is None:
@@ -957,10 +959,6 @@ def __init__(  # noqa C901
             eps=eps,
             beta1=beta1,
             beta2=beta2,
-            step_ema=step_ema,
-            step_swap=step_swap,
-            step_start=step_start,
-            step_mode=step_mode.value,
             weight_decay=weight_decay,
             weight_decay_mode=opt_arg_weight_decay_mode.value,
             eta=eta,
@@ -1000,6 +998,7 @@ def __init__(  # noqa C901
             )
             rowwise = optimizer in [
                 OptimType.EXACT_ROWWISE_ADAGRAD,
+                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
             ]
             self._apply_split(
                 construct_split_state(
@@ -1029,7 +1028,6 @@ def __init__(  # noqa C901
             rowwise = optimizer in (
                 OptimType.PARTIAL_ROWWISE_ADAM,
                 OptimType.PARTIAL_ROWWISE_LAMB,
-                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
             )
             momentum2_dtype = (
                 torch.float32
@@ -1059,9 +1057,7 @@ def __init__(  # noqa C901
         else:
            # NOTE: make TorchScript work!
            self._register_nonpersistent_buffers("momentum2")
-        if self._used_rowwise_adagrad_with_counter or (
-            optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
-        ):
+        if self._used_rowwise_adagrad_with_counter:
            self._apply_split(
                construct_split_state(
                    embedding_specs,
@@ -1129,7 +1125,6 @@ def __init__(  # noqa C901
                 OptimType.LAMB,
                 OptimType.PARTIAL_ROWWISE_ADAM,
                 OptimType.PARTIAL_ROWWISE_LAMB,
-                OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
             )
             or self._used_rowwise_adagrad_with_global_weight_decay
         ):
@@ -1865,18 +1860,15 @@ def forward(  # noqa: C901
             assert self._feature_is_enabled(
                 FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD
             ), "ENSEMBLE_ROWWISE_ADAGRAD is an inactive or deprecated feature!"
+            with torch.no_grad():
+                if self.training:
+                    self.ensemble_and_swap()
             return self._report_io_size_count(
                 "fwd_output",
-                invokers.lookup_ensemble_rowwise_adagrad.invoke(
+                invokers.lookup_rowwise_adagrad.invoke(
                     common_args,
                     self.optimizer_args,
                     momentum1,
-                    momentum2,
-                    prev_iter,
-                    row_counter,
-                    iter=int(self.iter.item()),
-                    apply_global_weight_decay=False,
-                    gwd_lower_bound=0.0,
                 ),
             )

@@ -1935,6 +1927,27 @@ def forward(  # noqa: C901

         raise ValueError(f"Invalid OptimType: {self.optimizer}")

+    @torch.jit.ignore
+    def ensemble_and_swap(self) -> None:
+        should_ema = self.iter.item() % int(self.ensemble_mode.step_ema) == 0
+        should_swap = self.iter.item() % int(self.ensemble_mode.step_swap) == 0
+        if should_ema or should_swap:
+            weights = self.split_embedding_weights()
+            states = self.split_optimizer_states()
+            for i in range(len(self.embedding_specs)):
+                if should_ema:
+                    coef_ema = (
+                        self.optimizer_args.momentum
+                        if self.iter.item() > self.ensemble_mode.step_start
+                        else 0.0
+                    )
+                    weights_cpu = weights[i].to(
+                        dtype=states[i][1].dtype, device=states[i][1].device
+                    )
+                    states[i][1].lerp_(weights_cpu, 1.0 - coef_ema)
+                if should_swap:
+                    weights[i].copy_(states[i][1], non_blocking=True)
+
     def reset_uvm_cache_stats(self) -> None:
         assert (
             self.gather_uvm_cache_stats
@@ -2343,9 +2356,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
             list_of_state_dict = [
                 {
                     "sum": states[0],
-                    "exp_avg": states[1],
-                    "prev_iter": states[2],
-                    "row_counter": states[3],
+                    "sparse_ema": states[1],
                 }
                 for states in split_optimizer_states
             ]
@@ -2390,8 +2401,7 @@ def split_optimizer_states(

             (8) `PARTIAL_ROWWISE_LAMB`: `momentum1`, `momentum2` (rowwise)

-            (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum2` (rowwise), `momentum1`,
-                `prev_iter` (rowwise), `row_counter` (rowwise)
+            (9) `ENSEMBLE_ROWWISE_ADAGRAD`: `momentum1` (rowwise), `momentum2`

             (10) `NONE`: no states (throwing an error)

@@ -2428,19 +2438,6 @@ def get_optimizer_states(
             return splits

        states: List[List[torch.Tensor]] = []
-        # For ensemble_rowwise_adagrad, momentum2 ("sum") should go first,
-        # as it is the default optimizer state for embedding pruning later.
-        if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
-            states.append(
-                get_optimizer_states(
-                    self.momentum2_dev,
-                    self.momentum2_host,
-                    self.momentum2_uvm,
-                    self.momentum2_physical_offsets,
-                    self.momentum2_physical_placements,
-                    rowwise=True,
-                )
-            )
        if self.optimizer not in (OptimType.EXACT_SGD,):
            states.append(
                get_optimizer_states(
@@ -2452,6 +2449,7 @@ def get_optimizer_states(
                    rowwise=self.optimizer
                    in [
                        OptimType.EXACT_ROWWISE_ADAGRAD,
+                        OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
                    ],
                )
            )
@@ -2460,6 +2458,7 @@ def get_optimizer_states(
            OptimType.PARTIAL_ROWWISE_ADAM,
            OptimType.LAMB,
            OptimType.PARTIAL_ROWWISE_LAMB,
+            OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
        ):
            states.append(
                get_optimizer_states(
@@ -2475,7 +2474,6 @@ def get_optimizer_states(
        if (
            self._used_rowwise_adagrad_with_counter
            or self._used_rowwise_adagrad_with_global_weight_decay
-            or self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
        ):
            states.append(
                get_optimizer_states(
@@ -2487,10 +2485,7 @@ def get_optimizer_states(
                    rowwise=True,
                )
            )
-        if (
-            self._used_rowwise_adagrad_with_counter
-            or self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD
-        ):
+        if self._used_rowwise_adagrad_with_counter:
            states.append(
                get_optimizer_states(
                    self.row_counter_dev,
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 2c02db4b9..cd4a4f118 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -37,7 +37,6 @@
     apply_split_helper,
     CounterBasedRegularizationDefinition,
     CowClipDefinition,
-    StepMode,
     UVMCacheStatsIndex,
     WeightDecayMode,
 )
@@ -116,10 +115,6 @@ def __init__(
         eta: float = 0.001,  # used by LARS-SGD,
         beta1: float = 0.9,  # used by LAMB and ADAM
         beta2: float = 0.999,  # used by LAMB and ADAM
-        step_ema: float = 10000,  # used by ENSEMBLE_ROWWISE_ADAGRAD
-        step_swap: float = 10000,  # used by ENSEMBLE_ROWWISE_ADAGRAD
-        step_start: float = 0,  # used by ENSEMBLE_ROWWISE_ADAGRAD
-        step_mode: StepMode = StepMode.USE_ITER,  # used by ENSEMBLE_ROWWISE_ADAGRAD
         counter_based_regularization: Optional[
             CounterBasedRegularizationDefinition
         ] = None,  # used by Rowwise Adagrad
@@ -535,10 +530,6 @@ def __init__(
             eps=eps,
             beta1=beta1,
             beta2=beta2,
-            step_ema=step_ema,
-            step_swap=step_swap,
-            step_start=step_start,
-            step_mode=step_mode.value,
             weight_decay=weight_decay,
             weight_decay_mode=weight_decay_mode.value,
             eta=eta,
diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
index adb96daaa..072afc5dc 100644
--- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -26,6 +26,7 @@
     CounterBasedRegularizationDefinition,
     CounterWeightDecayMode,
     CowClipDefinition,
+    EnsembleModeDefinition,
     GradSumDecay,
     LearningRateMode,
     SplitTableBatchedEmbeddingBagsCodegen,
@@ -311,15 +312,17 @@ def execute_backward_optimizers_(  # noqa C901
                 1e-4,
                 1.0,
                 1.0,
-                -1.0,
+                0.0,
                 StepMode.USE_ITER,
                 0.8,
             )
             optimizer_kwargs["eps"] = eps
-            optimizer_kwargs["step_ema"] = step_ema
-            optimizer_kwargs["step_swap"] = step_swap
-            optimizer_kwargs["step_start"] = step_start
-            optimizer_kwargs["step_mode"] = step_mode
+            optimizer_kwargs["ensemble_mode"] = EnsembleModeDefinition(
+                step_ema=step_ema,
+                step_swap=step_swap,
+                step_start=step_start,
+                step_mode=step_mode,
+            )
             optimizer_kwargs["momentum"] = momentum
             optimizer_kwargs["optimizer_state_dtypes"] = optimizer_state_dtypes

@@ -555,14 +558,14 @@ def execute_backward_optimizers_(  # noqa C901

         if optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
             for t in range(T):
                 iter_ = cc.iter.item()
-                (m2, m1, prev_iter, row_counter) = split_optimizer_states[t]
+                (m1, m2) = split_optimizer_states[t]
                 if (m1.dtype == torch.float) and (m2.dtype == torch.float):
                     tol = 1.0e-4
                 else:
                     tol = 1.0e-2  # Some optimizers have non-float momentums
-                dense_cpu_grad = bs[t].weight.grad.cpu().to_dense()
-                m2_ref = dense_cpu_grad.pow(2).mean(dim=1)
+                m2_ref = torch.mul(bs[t].weight.cpu(), 1.0 - momentum)
+                weights_ref = m2_ref.mul(1.0)
                 torch.testing.assert_close(
                     m2.float().cpu().index_select(dim=0, index=xs[t].view(-1).cpu()),
                     m2_ref.float()
@@ -571,15 +574,8 @@ def execute_backward_optimizers_(  # noqa C901
                     atol=tol,
                     rtol=tol,
                 )
-                v_hat_t = m2_ref.view(m2_ref.numel(), 1)
-                weights_new = split_weights[t]
-                weights_ref = torch.addcdiv(
-                    bs[t].weight.cpu(),
-                    value=-lr,
-                    tensor1=dense_cpu_grad,
-                    tensor2=v_hat_t.sqrt_().add_(eps),
-                )
-                m1_ref = torch.mul(weights_ref, 1.0 - momentum)
+                dense_cpu_grad = bs[t].weight.grad.cpu().to_dense()
+                m1_ref = dense_cpu_grad.pow(2).mean(dim=1)
                 torch.testing.assert_close(
                     m1.float().cpu().index_select(dim=0, index=xs[t].view(-1).cpu()),
                     m1_ref.float()
@@ -588,7 +584,14 @@ def execute_backward_optimizers_(  # noqa C901
                     atol=tol,
                     rtol=tol,
                 )
-                weights_ref = m1_ref.div(1.0)
+                v_hat_t = m1_ref.view(m1_ref.numel(), 1)
+                weights_new = split_weights[t]
+                weights_ref = torch.addcdiv(
+                    weights_ref,
+                    value=-lr,
+                    tensor1=dense_cpu_grad,
+                    tensor2=v_hat_t.sqrt_().add_(eps),
+                )
                 torch.testing.assert_close(
                     weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(),
                     weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()),
@@ -599,9 +602,7 @@ def execute_backward_optimizers_(  # noqa C901
                 optimizer_states_dict = get_optimizer_states[t]
                 assert set(optimizer_states_dict.keys()) == {
                     "sum",
-                    "exp_avg",
-                    "prev_iter",
-                    "row_counter",
+                    "sparse_ema",
                 }

         if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
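
Note (not part of the patch): after this change, the ensemble behavior is configured through the new `ensemble_mode` argument instead of the removed `step_ema`/`step_swap`/`step_start`/`step_mode` keyword arguments. The following is a minimal sketch of how a caller might construct the module with the new argument. The import paths for `OptimType`, `EmbeddingLocation`, and `ComputeDevice`, as well as the specific spec and hyperparameter values, are assumptions for illustration rather than something this patch prescribes.

```python
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    EnsembleModeDefinition,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Hypothetical single-table config: (num_embeddings, embedding_dim, location, compute_device).
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(1000, 16, EmbeddingLocation.DEVICE, ComputeDevice.CUDA)],
    optimizer=OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
    learning_rate=0.01,
    eps=1.0e-4,
    momentum=0.8,  # EMA coefficient consumed by ensemble_and_swap()
    ensemble_mode=EnsembleModeDefinition(
        step_ema=10000,
        step_swap=10000,
        step_start=0,
    ),
)
```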
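For reference, here is a standalone sketch of the host-side EMA-and-swap flow that the new `ensemble_and_swap` method performs in place of the removed `ensemble_rowwise_adagrad` GPU kernel: the fast weights are updated by plain rowwise Adagrad, while a slow copy (exposed as the `sparse_ema` state) is periodically blended with them and swapped back. The tensor names, sizes, and the `iter_`/`momentum` values below are hypothetical; the real module applies the same `lerp_`/`copy_` updates to the tensors returned by `split_embedding_weights()` and `split_optimizer_states()`.

```python
import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_training import EnsembleModeDefinition

# Stand-ins for one table's fast weights and its slow "sparse_ema" copy.
fast_weights = torch.randn(100, 16)
sparse_ema = fast_weights.clone()

mode = EnsembleModeDefinition(step_ema=100, step_swap=200, step_start=0)
momentum = 0.8  # corresponds to optimizer_args.momentum
iter_ = 200     # current training iteration

should_ema = iter_ % int(mode.step_ema) == 0
should_swap = iter_ % int(mode.step_swap) == 0

if should_ema:
    # Before step_start the coefficient is 0, so the slow copy is reset to the
    # fast weights; afterwards it decays toward them with factor `momentum`.
    coef_ema = momentum if iter_ > mode.step_start else 0.0
    # lerp_(end, w): ema = ema + w * (fast - ema) = momentum * ema + (1 - momentum) * fast
    sparse_ema.lerp_(fast_weights, 1.0 - coef_ema)

if should_swap:
    # Periodically copy the averaged slow table back into the fast weights.
    fast_weights.copy_(sparse_ema)
```

A consequence of this design, visible in the rest of the diff, is that the optimizer now only needs the rowwise Adagrad state (`momentum1`, reported as `"sum"`) plus the dense EMA table (`momentum2`, reported as `"sparse_ema"`); the `prev_iter` and `row_counter` buffers previously required by the dedicated kernel are no longer allocated for this optimizer.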