From 45fef7123e385c4c2d46649c6eaf40b5541b127d Mon Sep 17 00:00:00 2001
From: Supadchaya Puangpontip
Date: Fri, 3 May 2024 20:18:55 -0700
Subject: [PATCH] Fix argument order (#2561)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2561

Move global weight decay parameters to after `uvm_cache_stats`

Reviewed By: sryap

Differential Revision: D56960935

fbshipit-source-id: ec66f1c416756880950d31dd4ae093d4e62a5ec8
---
 .../backward/embedding_backward_split_host_template.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
index 6c68ae841..943f4541c 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp
@@ -811,14 +811,14 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
     const bool is_experimental = false,
     const bool use_uniq_cache_locations_bwd = false,
     const bool use_homogeneous_placements = false,
-    const bool apply_global_weight_decay = false,
+    const c10::optional<Tensor>& uvm_cache_stats = c10::nullopt,
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     const c10::optional<Tensor>& prev_iter_dev = c10::nullopt,
     {%- endif %}
     {%- if "iter" not in args.split_function_arg_names %}
     const int64_t iter = 0,
     {%- endif %}
-    const c10::optional<Tensor>& uvm_cache_stats = c10::nullopt
+    const bool apply_global_weight_decay = false
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}
@@ -888,14 +888,14 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
         " bool is_experimental=False, "
         " bool use_uniq_cache_locations_bwd=False, "
         " bool use_homogeneous_placements=False, "
-        " bool apply_global_weight_decay=False, "
+        " Tensor? uvm_cache_stats=None, "
         {%- if "prev_iter_dev" not in args.split_function_arg_names %}
         " Tensor? prev_iter_dev=None, "
         {%- endif %}
         {%- if "iter" not in args.split_function_arg_names %}
         " int iter=0, "
         {%- endif %}
-        " Tensor? uvm_cache_stats=None"
+        " bool apply_global_weight_decay=False "
         ") -> Tensor",
         {PT2_COMPLIANT_TAG});
   // We're playing a funny trick here: we're using the autograd