diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
index 6c68ae841..943f4541c 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
@@ -811,14 +811,14 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
     const bool is_experimental = false,
     const bool use_uniq_cache_locations_bwd = false,
     const bool use_homogeneous_placements = false,
-    const bool apply_global_weight_decay = false,
+    const c10::optional<Tensor>& uvm_cache_stats = c10::nullopt,
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     const c10::optional<Tensor>& prev_iter_dev = c10::nullopt,
     {%- endif %}
     {%- if "iter" not in args.split_function_arg_names %}
     const int64_t iter = 0,
     {%- endif %}
-    const c10::optional<Tensor>& uvm_cache_stats = c10::nullopt
+    const bool apply_global_weight_decay = false
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}
@@ -888,14 +888,14 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
        " bool is_experimental=False, "
        " bool use_uniq_cache_locations_bwd=False, "
        " bool use_homogeneous_placements=False, "
-       " bool apply_global_weight_decay=False, "
+       " Tensor? uvm_cache_stats=None, "
        {%- if "prev_iter_dev" not in args.split_function_arg_names %}
        " Tensor? prev_iter_dev=None, "
        {%- endif %}
        {%- if "iter" not in args.split_function_arg_names %}
        " int iter=0, "
        {%- endif %}
-       " Tensor? uvm_cache_stats=None"
+       " bool apply_global_weight_decay=False "
        ") -> Tensor",
        {PT2_COMPLIANT_TAG});
   // We're playing a funny trick here: we're using the autograd