diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu index 92a5bb806..546e1a95b 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu @@ -19,6 +19,10 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +//////////////////////////////////////////////////////////////////////////////// +// Kernel Template Definition +//////////////////////////////////////////////////////////////////////////////// + template < typename emb_t, typename grad_t, @@ -384,42 +388,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ } // for each run } +//////////////////////////////////////////////////////////////////////////////// +// Explicit Template Instantiations +//////////////////////////////////////////////////////////////////////////////// + /* Explicitly instantiate the kernel function template. The instantiations are based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in embedding_backward_split_template.cu */ -{%- for grad_type in ['float', 'at::Half'] %} -{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %} -{%- for cache_type in ['float', 'at::Half'] %} - -//////////////////////////////////////////////////////////////////////////////// -#ifdef FBGEMM_USE_SUBWARP_SHUFFLE -//////////////////////////////////////////////////////////////////////////////// - -{#- /* - Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize) - in the FBGEMM_USE_SUBWARP_SHUFFLE case - - constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1); - constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); - - This is needed to compute the unique tuples to use for explicit instantiation, - so that we can avoid duplicate template instantiations. -*/ #} -{%- set tuples = [] %} -{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} -{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} - {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %} - {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} - {%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %} -{%- endif %} -{%- endfor %} - -{#- /* Enumerate over the unique tuples */ #} -{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} - +{%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %} template __global__ __launch_bounds__(kMaxThreads) void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1 < {{ emb_type }}, @@ -484,7 +463,51 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize const int32_t max_segment_length_per_cta, const bool use_deterministic_algorithms, {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}); +{%- endmacro %} + +{%- macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %} + {%- for grad_type in ['float', 'at::Half'] %} + {%- for emb_type in ['uint8_t', 'float', 'at::Half'] %} + {%- for cache_type in ['float', 'at::Half'] %} + {{ template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) }} + {%- endfor %} + {%- endfor %} + {%- endfor %} +{%- endmacro %} + +{%- if is_experimental_optimizer %} + +{{ bulk_template_instantiations(max_embedding_dim // items_per_warp, 'kWarpSize') }} + +{%- else %} + +//////////////////////////////////////////////////////////////////////////////// +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE +//////////////////////////////////////////////////////////////////////////////// + +{#- /* + Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize) + in the FBGEMM_USE_SUBWARP_SHUFFLE case + + constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1); + constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); + + This is needed to compute the unique tuples to use for explicit instantiation, + so that we can avoid duplicate template instantiations. +*/ #} +{%- set tuples = [] %} +{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} +{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} + {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %} + {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} + {%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %} +{%- endif %} +{%- endfor %} + +{#- /* Enumerate over the unique tuples */ #} +{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} + {{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }} {%- endfor %} //////////////////////////////////////////////////////////////////////////////// @@ -502,87 +525,18 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %} - {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} {%- set temp = tuples.append((t0, "kWarpSize")) %} {%- endif %} {%- endfor %} {#- /* Enumerate over the unique tuples */ #} {%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} - -template __global__ __launch_bounds__(kMaxThreads) -void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1 -< {{ emb_type }}, - {{ grad_type }}, - {{ cache_type }}, - {{ kMaxVecsPerThread }}, - {{ kThreadGroupSize }} -> ( - const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output, - {%- if optimizer != "none" %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, - {%- if not dense %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, - pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, - const pta::PackedTensorAccessor32 - weights_placements, - {%- endif %} - {%- endif %} // if optimizer != "none" - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, - const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - const pta::PackedTensorAccessor32 long_run_ids, - const pta::PackedTensorAccessor32 num_long_run_ids, - {%- if not nobag %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- else %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 - sorted_lxu_cache_locations, - {%- endif %} - {%- if weighted %} - const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, - {%- endif %} - {%- if not dense and optimizer != "none" %} - bool stochastic_rounding, - at::PhiloxCudaState stochastic_rounding_philox_args, - {%- else %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, - {%- if optimizer == "none" %} - const int32_t max_D, - {%- endif %} - {%- endif %} // if not dense and optimizer != "none" - {%- if not nobag and vbe %} - const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, - {%- endif %} - {%- if not nobag %} - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- endif %} - const pta::PackedTensorAccessor32 long_run_id_to_really_long_run_ids, - pta::PackedTensorAccessor32, 2, at::RestrictPtrTraits> temp_grad_accum, - pta::PackedTensorAccessor32 grad_accum_counter, - const int32_t max_segment_length_per_cta, - const bool use_deterministic_algorithms, - {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}); - + {{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }} {%- endfor %} //////////////////////////////////////////////////////////////////////////////// #endif //////////////////////////////////////////////////////////////////////////////// -{%- endfor %} -{%- endfor %} -{%- endfor %} - +{%- endif %} // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu index 96d174c0f..b6d090852 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu @@ -19,6 +19,10 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +//////////////////////////////////////////////////////////////////////////////// +// Kernel Template Definition +//////////////////////////////////////////////////////////////////////////////// + template < typename emb_t, typename grad_t, @@ -228,42 +232,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ } +//////////////////////////////////////////////////////////////////////////////// +// Explicit Template Instantiations +//////////////////////////////////////////////////////////////////////////////// + /* Explicitly instantiate the kernel function template. The instantiations are based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in embedding_backward_split_template.cu */ -{%- for grad_type in ['float', 'at::Half'] %} -{%- for emb_type in ['uint8_t', 'float', 'at::Half'] %} -{%- for cache_type in ['float', 'at::Half'] %} - -//////////////////////////////////////////////////////////////////////////////// -#ifdef FBGEMM_USE_SUBWARP_SHUFFLE -//////////////////////////////////////////////////////////////////////////////// - -{#- /* - Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize) - in the FBGEMM_USE_SUBWARP_SHUFFLE case - - constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1); - constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); - - This is needed to compute the unique tuples to use for explicit instantiation, - so that we can avoid duplicate template instantiations. -*/ #} -{%- set tuples = [] %} -{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} -{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} - {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %} - {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} - {%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %} -{%- endif %} -{%- endfor %} - -{#- /* Enumerate over the unique tuples */ #} -{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} - +{%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %} template __global__ __launch_bounds__(kBackwardMaxThreads) void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1 < {{ emb_type }}, @@ -321,7 +300,51 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize const uint32_t info_B_mask, {%- endif %} {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}); +{%- endmacro %} + +{%- macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %} + {%- for grad_type in ['float', 'at::Half'] %} + {%- for emb_type in ['uint8_t', 'float', 'at::Half'] %} + {%- for cache_type in ['float', 'at::Half'] %} + {{ template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) }} + {%- endfor %} + {%- endfor %} + {%- endfor %} +{%- endmacro %} + + +{%- if is_experimental_optimizer %} + +{{ bulk_template_instantiations(max_embedding_dim // items_per_warp, 'kWarpSize') }} + +{%- else %} + +//////////////////////////////////////////////////////////////////////////////// +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE +//////////////////////////////////////////////////////////////////////////////// + +{#- /* + Compute the Cartesian product of (kMaxVecsPerThread, kThreadGroupSize) + in the FBGEMM_USE_SUBWARP_SHUFFLE case + + constexpr int kMaxVecsPerThread = std::max({{ kMaxElemPerThread }} / 4, 1); + constexpr int kThreadGroupSize = kWarpSize / std::max(4 / {{ kMaxElemPerThread }}, 1); + + This is needed to compute the unique tuples to use for explicit instantiation, + so that we can avoid duplicate template instantiations. +*/ #} +{%- set tuples = [] %} +{%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} +{%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} + {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %} + {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} + {%- set temp = tuples.append((t0, "(kWarpSize / " ~ t1 ~ ")")) %} +{%- endif %} +{%- endfor %} +{#- /* Enumerate over the unique tuples */ #} +{%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} + {{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }} {%- endfor %} //////////////////////////////////////////////////////////////////////////////// @@ -339,80 +362,18 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize {%- for kMaxElemPerThread in range(1, max_embedding_dim // (items_per_warp // 4) + 1) %} {%- if kMaxElemPerThread in [1, 2] or kMaxElemPerThread % 4 == 0 %} {%- set t0 = [ (kMaxElemPerThread // 4), 1 ] | max %} - {%- set t1 = [ 4 // kMaxElemPerThread, 1] | max %} {%- set temp = tuples.append((t0, "kWarpSize")) %} {%- endif %} {%- endfor %} {#- /* Enumerate over the unique tuples */ #} {%- for (kMaxVecsPerThread, kThreadGroupSize) in tuples | unique %} - -template __global__ __launch_bounds__(kBackwardMaxThreads) -void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1 -< {{ emb_type }}, - {{ grad_type }}, - {{ cache_type }}, - {{ kMaxVecsPerThread }}, - {{ kThreadGroupSize }} -> ( - const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output, - {%- if optimizer != "none" %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, - {%- if not dense %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, - pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, - const pta::PackedTensorAccessor32 weights_placements, - {%- endif %} - {%- endif %} - const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} - const pta::PackedTensorAccessor32 D_offsets, - {%- else %} - int64_t D, - {%- endif %} - const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32 sorted_linear_indices_run, - const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - {%- if not nobag %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- else %} - const pta::PackedTensorAccessor32 sorted_infos, - {%- endif %} - {%- if not dense %} - const pta::PackedTensorAccessor32 sorted_lxu_cache_locations, - {%- endif %} - {%- if weighted %} - const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, - {%- endif %} - const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, - int32_t max_segment_length_per_warp, - {%- if not dense and optimizer != "none" %} - bool stochastic_rounding, - at::PhiloxCudaState stochastic_rounding_philox_args, - {%- else %} - pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, - {%- if optimizer == "none" %} - const int32_t max_D, - {%- endif %} - {%- endif %} // if not dense and optimizer != "none" - {%- if not nobag and vbe %} - const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, - {%- endif %} - {%- if not nobag %} - const int32_t info_B_num_bits, - const uint32_t info_B_mask, - {%- endif %} - {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}); - + {{ bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) }} {%- endfor %} //////////////////////////////////////////////////////////////////////////////// #endif //////////////////////////////////////////////////////////////////////////////// -{%- endfor %} -{%- endfor %} -{%- endfor %} - +{%- endif %} // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index be793ca86..74e8e041a 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -177,14 +177,15 @@ grad_mean{{ vbe_desc }}_kernel( // Utility Macros //////////////////////////////////////////////////////////////////////////////// -{%- if experimental_optimizer %} +{%- if is_experimental_optimizer %} + /* For the experimental optimizers, kMaxVecsPerThread and kThreadGroupSize are - fixed to 1024 and kWarpSize, respectively. + fixed to 8 (1024 elements) and kWarpSize, respectively. */ #define DISPATCH_KERNEL_BODY(MAX_D, ...) \ [&] { \ - constexpr auto kMaxVecsPerThread = 1024; \ + constexpr auto kMaxVecsPerThread = {{ max_embedding_dim // items_per_warp }}; \ constexpr auto kThreadGroupSize = kWarpSize; \ return __VA_ARGS__(); \ }() @@ -237,6 +238,7 @@ grad_mean{{ vbe_desc }}_kernel( }() #endif + {%- endif %} diff --git a/fbgemm_gpu/codegen/embedding_common_code_generator.py b/fbgemm_gpu/codegen/embedding_common_code_generator.py index 97d5bb96d..c67e40a28 100644 --- a/fbgemm_gpu/codegen/embedding_common_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_common_code_generator.py @@ -971,6 +971,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]: return { "optimizer": "rowwise_weighted_adagrad", + "is_experimental_optimizer": True, "args": make_args( [ (TENSOR, "momentum1"), @@ -1088,6 +1089,7 @@ def lamb() -> Dict[str, Any]: return { "optimizer": "lamb", + "is_experimental_optimizer": True, "args": make_args( [ (TENSOR, "momentum1"), @@ -1232,6 +1234,7 @@ def adam() -> Dict[str, Any]: return { "optimizer": "adam", + "is_experimental_optimizer": True, "args": make_args( [ (TENSOR, "momentum1"), @@ -1361,6 +1364,7 @@ def lars_sgd() -> Dict[str, Any]: return { "optimizer": "lars_sgd", + "is_experimental_optimizer": True, "args": make_args( [ (TENSOR, "momentum1"), diff --git a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template index 3ff429aac..874fc04ed 100644 --- a/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template @@ -6,11 +6,13 @@ # LICENSE file in the root directory of this source tree. import torch +{%- if is_experimental_optimizer %} +import warnings +{%- endif %} from .lookup_args import * -{% if is_fbcode %} - +{%- if is_fbcode %} # Provide compatibility to downstream packages for eventual migration to the split training / inference packages try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training") @@ -22,44 +24,61 @@ except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") -torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings" -) +torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update_cpu") -{% else %} -#import os -#torch.ops.load_library(os.path.join(os.path.join(os.path.dirname(os.path.dirname(__file__)), "fbgemm_gpu_py.so"))) -{% endif %} +{%- else %} +# import os +# torch.ops.load_library(os.path.join(os.path.join(os.path.dirname(os.path.dirname(__file__)), "fbgemm_gpu_py.so"))) +{%- endif %} + + +{%- if is_experimental_optimizer %} +_{{ optimizer }}_first_invocation = True +{%- endif %} def invoke( common_args: CommonArgs, optimizer_args: OptimizerArgs, - {% if "momentum1_dev" in args.split_function_arg_names %} + {%- if "momentum1_dev" in args.split_function_arg_names %} momentum1: Momentum, - {% endif %} - {% if "momentum2_dev" in args.split_function_arg_names %} + {%- endif %} + {%- if "momentum2_dev" in args.split_function_arg_names %} momentum2: Momentum, - {% endif %} - {% if "prev_iter_dev" in args.split_function_arg_names %} + {%- endif %} + {%- if "prev_iter_dev" in args.split_function_arg_names %} prev_iter: Momentum, - {% endif %} - {% if "row_counter_dev" in args.split_function_arg_names %} + {%- endif %} + {%- if "row_counter_dev" in args.split_function_arg_names %} row_counter: Momentum, - {% endif %} - {% if "iter" in args.split_function_arg_names %} + {%- endif %} + {%- if "iter" in args.split_function_arg_names %} iter: int, - {% endif %} - {% if "max_counter" in args.split_function_arg_names %} + {%- endif %} + {%- if "max_counter" in args.split_function_arg_names %} max_counter: float, - {% endif %} - {% if "total_unique_indices" in args.split_function_arg_names %} + {%- endif %} + {%- if "total_unique_indices" in args.split_function_arg_names %} total_unique_indices: int, - {% endif %} + {%- endif %} ) -> torch.Tensor: - {% if has_cpu_support %} + {%- if is_experimental_optimizer %} + global _{{ optimizer }}_first_invocation + if _{{ optimizer }}_first_invocation: + warnings.warn( + f"""\033[93m + [FBGEMM_GPU] NOTE: The training optimizer '{{ optimizer }}' is marked as + EXPERIMENTAL and thus not optimized, in order to reduce code compilation + times and build sizes! + \033[0m""", + RuntimeWarning, + ) + _{{ optimizer }}_first_invocation = False + {%- endif %} + + {%- if has_cpu_support %} if (common_args.host_weights.numel() > 0): return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( # common_args @@ -80,97 +99,97 @@ def invoke( gradient_clipping = optimizer_args.gradient_clipping, max_gradient=optimizer_args.max_gradient, stochastic_rounding=optimizer_args.stochastic_rounding, - {% if "learning_rate" in args.split_function_arg_names %} + {%- if "learning_rate" in args.split_function_arg_names %} learning_rate=optimizer_args.learning_rate, - {% endif %} - {% if "eps" in args.split_function_arg_names %} + {%- endif %} + {%- if "eps" in args.split_function_arg_names %} eps=optimizer_args.eps, - {% endif %} - {% if "beta1" in args.split_function_arg_names %} + {%- endif %} + {%- if "beta1" in args.split_function_arg_names %} beta1=optimizer_args.beta1, - {% endif %} - {% if "beta2" in args.split_function_arg_names %} + {%- endif %} + {%- if "beta2" in args.split_function_arg_names %} beta2=optimizer_args.beta2, - {% endif %} - {% if "weight_decay" in args.split_function_arg_names %} + {%- endif %} + {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=optimizer_args.weight_decay, - {% endif %} - {% if "weight_decay_mode" in args.split_function_arg_names %} + {%- endif %} + {%- if "weight_decay_mode" in args.split_function_arg_names %} weight_decay_mode=optimizer_args.weight_decay_mode, - {% endif %} - {% if "eta" in args.split_function_arg_names %} + {%- endif %} + {%- if "eta" in args.split_function_arg_names %} eta=optimizer_args.eta, - {% endif %} - {% if "momentum" in args.split_function_arg_names %} + {%- endif %} + {%- if "momentum" in args.split_function_arg_names %} momentum=optimizer_args.momentum, - {% endif %} - {% if "counter_halflife" in args.split_function_arg_names %} + {%- endif %} + {%- if "counter_halflife" in args.split_function_arg_names %} counter_halflife=optimizer_args.counter_halflife, - {% endif %} - {% if "adjustment_iter" in args.split_function_arg_names %} + {%- endif %} + {%- if "adjustment_iter" in args.split_function_arg_names %} adjustment_iter=optimizer_args.adjustment_iter, - {% endif %} - {% if "adjustment_ub" in args.split_function_arg_names %} + {%- endif %} + {%- if "adjustment_ub" in args.split_function_arg_names %} adjustment_ub=optimizer_args.adjustment_ub, - {% endif %} - {% if "learning_rate_mode" in args.split_function_arg_names %} + {%- endif %} + {%- if "learning_rate_mode" in args.split_function_arg_names %} learning_rate_mode=optimizer_args.learning_rate_mode, - {% endif %} - {% if "grad_sum_decay" in args.split_function_arg_names %} + {%- endif %} + {%- if "grad_sum_decay" in args.split_function_arg_names %} grad_sum_decay=optimizer_args.grad_sum_decay, - {% endif %} - {% if "tail_id_threshold" in args.split_function_arg_names %} + {%- endif %} + {%- if "tail_id_threshold" in args.split_function_arg_names %} tail_id_threshold=optimizer_args.tail_id_threshold, - {% endif %} - {% if "is_tail_id_thresh_ratio" in args.split_function_arg_names %} + {%- endif %} + {%- if "is_tail_id_thresh_ratio" in args.split_function_arg_names %} is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio, - {% endif %} + {%- endif %} # momentum1 - {% if "momentum1_dev" in args.split_function_arg_names %} + {%- if "momentum1_dev" in args.split_function_arg_names %} momentum1_host=momentum1.host, momentum1_offsets=momentum1.offsets, momentum1_placements=momentum1.placements, - {% endif %} + {%- endif %} # momentum2 - {% if "momentum2_dev" in args.split_function_arg_names %} + {%- if "momentum2_dev" in args.split_function_arg_names %} momentum2_host=momentum2.host, momentum2_offsets=momentum2.offsets, momentum2_placements=momentum2.placements, - {% endif %} + {%- endif %} # prev_iter - {% if "prev_iter_dev" in args.split_function_arg_names %} + {%- if "prev_iter_dev" in args.split_function_arg_names %} prev_iter_host=prev_iter.host, prev_iter_offsets=prev_iter.offsets, prev_iter_placements=prev_iter.placements, - {% endif %} + {%- endif %} # row_counter - {% if "row_counter_dev" in args.split_function_arg_names %} + {%- if "row_counter_dev" in args.split_function_arg_names %} row_counter_host=row_counter.host, row_counter_offsets=row_counter.offsets, row_counter_placements=row_counter.placements, - {% endif %} + {%- endif %} # iter - {% if "iter" in args.split_function_arg_names %} + {%- if "iter" in args.split_function_arg_names %} iter=iter, - {% endif %} + {%- endif %} # max counter - {% if "max_counter" in args.split_function_arg_names %} + {%- if "max_counter" in args.split_function_arg_names %} max_counter=max_counter, - {% endif %} + {%- endif %} ) - {% if not has_gpu_support %} + {%- if not has_gpu_support %} else: assert False, "{{ optimizer }} has only CPU support. host_weights.numel() must be greater than 0." - {% endif %} - {% endif %} + {%- endif %} + {%- endif %} - {% if has_gpu_support %} + {%- if has_gpu_support %} vbe_metadata = common_args.vbe_metadata return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function( # common_args - {% if not dense %} + {%- if not dense %} placeholder_autograd_tensor=common_args.placeholder_autograd_tensor, - {% endif %} + {%- endif %} dev_weights=common_args.dev_weights, uvm_weights=common_args.uvm_weights, lxu_cache_weights=common_args.lxu_cache_weights, @@ -195,99 +214,99 @@ def invoke( max_B_feature_rank=vbe_metadata.max_B_feature_rank, vbe_output_size=vbe_metadata.output_size, # optimizer_args - {% if optimizer == "none" %} + {%- if optimizer == "none" %} total_hash_size = optimizer_args.total_hash_size, - {% else %} + {%- else %} gradient_clipping = optimizer_args.gradient_clipping, max_gradient=optimizer_args.max_gradient, stochastic_rounding=optimizer_args.stochastic_rounding, - {% endif %} # if optimizer == none - {% if "learning_rate" in args.split_function_arg_names %} + {%- endif %} # if optimizer == none + {%- if "learning_rate" in args.split_function_arg_names %} learning_rate=optimizer_args.learning_rate, - {% endif %} - {% if "eps" in args.split_function_arg_names %} + {%- endif %} + {%- if "eps" in args.split_function_arg_names %} eps=optimizer_args.eps, - {% endif %} - {% if "beta1" in args.split_function_arg_names %} + {%- endif %} + {%- if "beta1" in args.split_function_arg_names %} beta1=optimizer_args.beta1, - {% endif %} - {% if "beta2" in args.split_function_arg_names %} + {%- endif %} + {%- if "beta2" in args.split_function_arg_names %} beta2=optimizer_args.beta2, - {% endif %} - {% if "weight_decay" in args.split_function_arg_names %} + {%- endif %} + {%- if "weight_decay" in args.split_function_arg_names %} weight_decay=optimizer_args.weight_decay, - {% endif %} - {% if "weight_decay_mode" in args.split_function_arg_names %} + {%- endif %} + {%- if "weight_decay_mode" in args.split_function_arg_names %} weight_decay_mode=optimizer_args.weight_decay_mode, - {% endif %} - {% if "eta" in args.split_function_arg_names %} + {%- endif %} + {%- if "eta" in args.split_function_arg_names %} eta=optimizer_args.eta, - {% endif %} - {% if "momentum" in args.split_function_arg_names %} + {%- endif %} + {%- if "momentum" in args.split_function_arg_names %} momentum=optimizer_args.momentum, - {% endif %} - {% if "counter_halflife" in args.split_function_arg_names %} + {%- endif %} + {%- if "counter_halflife" in args.split_function_arg_names %} counter_halflife=optimizer_args.counter_halflife, - {% endif %} - {% if "adjustment_iter" in args.split_function_arg_names %} + {%- endif %} + {%- if "adjustment_iter" in args.split_function_arg_names %} adjustment_iter=optimizer_args.adjustment_iter, - {% endif %} - {% if "adjustment_ub" in args.split_function_arg_names %} + {%- endif %} + {%- if "adjustment_ub" in args.split_function_arg_names %} adjustment_ub=optimizer_args.adjustment_ub, - {% endif %} - {% if "learning_rate_mode" in args.split_function_arg_names %} + {%- endif %} + {%- if "learning_rate_mode" in args.split_function_arg_names %} learning_rate_mode=optimizer_args.learning_rate_mode, - {% endif %} - {% if "grad_sum_decay" in args.split_function_arg_names %} + {%- endif %} + {%- if "grad_sum_decay" in args.split_function_arg_names %} grad_sum_decay=optimizer_args.grad_sum_decay, - {% endif %} - {% if "tail_id_threshold" in args.split_function_arg_names %} + {%- endif %} + {%- if "tail_id_threshold" in args.split_function_arg_names %} tail_id_threshold=optimizer_args.tail_id_threshold, - {% endif %} - {% if "is_tail_id_thresh_ratio" in args.split_function_arg_names %} + {%- endif %} + {%- if "is_tail_id_thresh_ratio" in args.split_function_arg_names %} is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio, - {% endif %} + {%- endif %} # momentum1 - {% if "momentum1_dev" in args.split_function_arg_names %} + {%- if "momentum1_dev" in args.split_function_arg_names %} momentum1_dev=momentum1.dev, momentum1_uvm=momentum1.uvm, momentum1_offsets=momentum1.offsets, momentum1_placements=momentum1.placements, - {% endif %} + {%- endif %} # momentum2 - {% if "momentum2_dev" in args.split_function_arg_names %} + {%- if "momentum2_dev" in args.split_function_arg_names %} momentum2_dev=momentum2.dev, momentum2_uvm=momentum2.uvm, momentum2_offsets=momentum2.offsets, momentum2_placements=momentum2.placements, - {% endif %} + {%- endif %} # prev_iter - {% if "prev_iter_dev" in args.split_function_arg_names %} + {%- if "prev_iter_dev" in args.split_function_arg_names %} prev_iter_dev=prev_iter.dev, prev_iter_uvm=prev_iter.uvm, prev_iter_offsets=prev_iter.offsets, prev_iter_placements=prev_iter.placements, - {% endif %} + {%- endif %} # row_counter - {% if "row_counter_dev" in args.split_function_arg_names %} + {%- if "row_counter_dev" in args.split_function_arg_names %} row_counter_dev=row_counter.dev, row_counter_uvm=row_counter.uvm, row_counter_offsets=row_counter.offsets, row_counter_placements=row_counter.placements, - {% endif %} + {%- endif %} # iter - {% if "iter" in args.split_function_arg_names %} + {%- if "iter" in args.split_function_arg_names %} iter=iter, - {% endif %} + {%- endif %} # max counter - {% if "max_counter" in args.split_function_arg_names %} + {%- if "max_counter" in args.split_function_arg_names %} max_counter=max_counter, - {% endif %} + {%- endif %} # total_unique_indices - {% if "total_unique_indices" in args.split_function_arg_names %} + {%- if "total_unique_indices" in args.split_function_arg_names %} total_unique_indices = total_unique_indices, - {% endif %} + {%- endif %} output_dtype=common_args.output_dtype, is_experimental=common_args.is_experimental, ) - {% endif %} + {%- endif %}