From a5d8432278333df8852e8b435302f899d41b2e24 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Fri, 13 Sep 2024 04:52:53 -0700 Subject: [PATCH 01/27] Redefine FBGEMM targets with gpu_cpp_library [10/N] (#3131) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3131 X-link: https://github.com/facebookresearch/FBGEMM/pull/218 - Redefine `input_combine_*` targets using `gpu_cpp_library` Reviewed By: spcyppt Differential Revision: D62611707 fbshipit-source-id: f6ddbb989eaa99b82925e25e1522b8a4718b396b --- fbgemm_gpu/src/placeholder.cpp | 5 ++--- fbgemm_gpu/test/combine/common.py | 10 ++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/fbgemm_gpu/src/placeholder.cpp b/fbgemm_gpu/src/placeholder.cpp index 16d443485..82cf3aa52 100644 --- a/fbgemm_gpu/src/placeholder.cpp +++ b/fbgemm_gpu/src/placeholder.cpp @@ -6,7 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -/* - This is placeholder code to force compilation and generation of .so file. -*/ +/// This is a placeholder source file that is used to force compilation and +/// generation of an .SO file. namespace fbgemm_gpu {} diff --git a/fbgemm_gpu/test/combine/common.py b/fbgemm_gpu/test/combine/common.py index a25d7cc9f..75c8bb529 100644 --- a/fbgemm_gpu/test/combine/common.py +++ b/fbgemm_gpu/test/combine/common.py @@ -11,14 +11,16 @@ import fbgemm_gpu import torch +from fbgemm_gpu.utils.loader import load_torch_module + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") + load_torch_module( + "//deeplearning/fbgemm/fbgemm_gpu:input_combine", + hip_path="//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip", + ) torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") From ad3cf9f4549a6c7c349c345af9c4915641d69a88 Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Fri, 13 Sep 2024 09:04:08 -0700 Subject: [PATCH 02/27] add to device for tbe module inputs (#3117) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/225 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3117 This makes it easier to handle h2d data transfer when running tgif python models. Differential Revision: D62535037 fbshipit-source-id: fdb8ae482c95a6b42fae1f33e839584a15fa7e4d --- ..._table_batched_embeddings_ops_inference.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index f8b691be4..721e4f248 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -153,6 +153,26 @@ def random_quant_scaled_tensor( ) +@torch.fx.wrap +def inputs_to_device( + indices: torch.Tensor, + offsets: torch.Tensor, + per_sample_weights: Optional[torch.Tensor], + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if device.type == "meta": + return indices, offsets, per_sample_weights + + non_blocking = device.type != "cpu" + if indices.device != device: + indices = indices.to(device, non_blocking=non_blocking) + if offsets.device != device: + offsets = offsets.to(device, non_blocking=non_blocking) + if per_sample_weights is not None and per_sample_weights.device != device: + per_sample_weights = per_sample_weights.to(device, non_blocking=non_blocking) + return indices, offsets, per_sample_weights + + # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized. class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module): """ @@ -756,6 +776,10 @@ def forward( self.weight_initialized ), "weight needs to be initialized before forward function" + indices, offsets, per_sample_weights = inputs_to_device( + indices, offsets, per_sample_weights, self.bounds_check_warning.device + ) + # First bound check: check if the indices/offsets are within the boundary # of the original embedding rows before pruning. # Note that this is only applied when we enable pruning (if the perf becomes From 6f1c0c3a47f4dca7a7a29cd2e0b083b5e387d3e1 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Fri, 13 Sep 2024 10:08:38 -0700 Subject: [PATCH 03/27] use correct operator signature - SymIntArrayRef and IntArrayRef (#3133) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3133 X-link: https://github.com/facebookresearch/FBGEMM/pull/223 # context * make fbgemm operator `permute_multi_embedding` PT2 compatible. * `out_lengths` is the list of sizes for all the output KT, which should be dynamic dims. * change the `out_lengths` from `std::vector` to `c10::SymIntArrayRef`, and other type compatibility fixes. * actually should use `c10::IntArrayRef` for the cpu and gpu impl, **only use SymInt in meta function** Reviewed By: ezyang Differential Revision: D62622020 fbshipit-source-id: 78be244e8bd837ab3c47f6320b169694bec51336 --- .../fbgemm_gpu/permute_multi_embedding_function.h | 4 ++-- .../permute_multi_embedding_function.cpp | 4 ++-- .../permute_multi_embedding_ops.cu | 12 +++--------- .../permute_multi_embedding_ops_cpu.cpp | 12 +++--------- 4 files changed, 10 insertions(+), 22 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h index 4bbe877f0..1cfc1a987 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h @@ -48,7 +48,7 @@ std::vector permute_multi_embedding_function_cpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute); std::vector permute_multi_embedding_function_meta( @@ -64,7 +64,7 @@ std::vector permute_multi_embedding_function_gpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute); std::tuple, std::vector, std::vector> diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp index ed8daa31b..1eba2be30 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp @@ -42,7 +42,7 @@ variable_list PermuteMultiEmbeddingOp::forward( const auto permute_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") - .typed(); + .typed(); return permute_op.call( pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); @@ -64,7 +64,7 @@ variable_list PermuteMultiEmbeddingOp::backward( const auto permute_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") - .typed(); + .typed(); auto grad_input = permute_op.call( grad_output, permutes, out_shapes, in_shapes, in_lengths, true); grad_input.push_back(torch::autograd::Variable()); // permutes diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu index 9ef3c62a4..11948bb22 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu @@ -222,7 +222,7 @@ std::vector permute_multi_embedding_function_gpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute) { // we assume that there's at least one input tensor in the list // it should be enforced from the caller side who has the knowledge. @@ -302,7 +302,7 @@ std::vector permute_multi_embedding_gpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths) { + const c10::IntArrayRef out_lengths) { return permute_multi_embedding_function_gpu( pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } @@ -314,14 +314,8 @@ std::vector regroup_keyed_tensor_gpu( const std::vector>& groups) { auto [permutes, in_shapes, out_shapes, out_lengths] = kt_regroup_arguments_gpu(pooled_embs[0], keys, lengths, groups); - std::vector out; - std::transform( - out_lengths.begin(), - out_lengths.end(), - std::back_inserter(out), - [](const int32_t v) { return c10::SymInt(v); }); return permute_multi_embedding_function_gpu( - pooled_embs, permutes, in_shapes, out_shapes, out, false); + pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp index 9ca8f7d2a..75d37c2d3 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp @@ -19,7 +19,7 @@ std::vector permute_multi_embedding_function_cpu( const Tensor& permutes, const Tensor& /* in_shapes */, const Tensor& /* out_shapes */, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute) { std::vector inputs; inputs.reserve(pooled_embs.size()); @@ -171,7 +171,7 @@ std::vector permute_multi_embedding_cpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths) { + const c10::IntArrayRef out_lengths) { return permute_multi_embedding_function_cpu( pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } @@ -321,14 +321,8 @@ std::vector regroup_keyed_tensor_cpu( const std::vector>& groups) { auto [permutes, in_shapes, out_shapes, out_lengths] = kt_regroup_arguments_cpu(pooled_embs[0], keys, lengths, groups); - std::vector out; - std::transform( - out_lengths.begin(), - out_lengths.end(), - std::back_inserter(out), - [](const int32_t v) { return c10::SymInt(v); }); return permute_multi_embedding_function_cpu( - pooled_embs, permutes, in_shapes, out_shapes, out, false); + pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } std::vector regroup_keyed_tensor_meta( From 8d3e30430589e0018d42a277d3d6f6d316b89d40 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Fri, 13 Sep 2024 10:38:37 -0700 Subject: [PATCH 04/27] Update Unified TBE API (#3070) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/204 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3070 Reviewed By: sryap Differential Revision: D62167861 fbshipit-source-id: d0a359bcc804f6be68b35db3722408da22de98f5 --- .../genscript/generate_backward_split.py | 50 +- .../genscript/generate_forward_split.py | 11 + .../codegen/genscript/optimizer_args.py | 102 +- ...dding_split_host_pt2_autograd_template.cpp | 1222 ++++++++++------- ...ng_split_host_pt2_cpu_wrapper_template.cpp | 4 +- ...g_split_host_pt2_cuda_wrapper_template.cpp | 388 +++--- 6 files changed, 1040 insertions(+), 737 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index c7d77e21a..afdcb8b3c 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -163,38 +163,38 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: # has_gpu_support=True if not kwargs.get("dense"): # Generate CUDA autograd - template_filepath = ( - "training/backward/embedding_backward_split_host_template.cpp" - ) + for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]: + template_filepath = ( + "training/backward/embedding_backward_split_host_template.cpp" + ) desc = "ssd" if ssd else "split" + sdesc = "_ssd" if ssd else "" filename = f"gen_embedding_backward_{desc}_{optimizer}.cpp" CodeTemplate.load(template_filepath).write( filename, is_forward=False, ssd=ssd, **kwargs ) - # Generate PT2 unified autograd, and PT2 backward wrapper - for template_filepath, filename in [ - ( - "training/pt2/embedding_split_host_pt2_autograd_template.cpp", - f"gen_embedding_split_{optimizer}_pt2_autograd.cpp", - ), - ( - "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", - f"gen_embedding_backward_split_{optimizer}_pt2_cuda_wrapper.cpp", - ), - ]: - CodeTemplate.load(template_filepath).write( - filename, is_forward=False, **kwargs - ) - - if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): - # Generates Python invoker for CUDA + CPU, and PT2 - template = CodeTemplate.load( - "training/python/split_embedding_codegen_lookup_invoker.template" - ) - for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]: - sdesc = "_ssd" if ssd else "" + # Generate PT2 unified autograd, and PT2 backward wrapper for all optimizers + for template_filepath, filename in [ + ( + "training/pt2/embedding_split_host_pt2_autograd_template.cpp", + f"gen_embedding_{desc}_{optimizer}_pt2_autograd.cpp", + ), + ( + "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", + f"gen_embedding_backward_{desc}_{optimizer}_pt2_cuda_wrapper.cpp", + ), + ]: + CodeTemplate.load(template_filepath).write( + filename, is_forward=False, ssd=ssd, **kwargs + ) + + if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): + # Generates Python invoker for CUDA + CPU, and PT2 + template = CodeTemplate.load( + "training/python/split_embedding_codegen_lookup_invoker.template" + ) for filename in [ f"lookup_{optimizer}{sdesc}.py", f"lookup_{optimizer}{sdesc}_pt2.py", diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index 8844712b0..285cf9a55 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -85,6 +85,17 @@ def generate_pt2_wrappers() -> None: is_forward=True, ) + # Generate PT2 forward wrapper (CUDA) + CodeTemplate.load( + "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", + ).write( + f"gen_embedding_forward_ssd_pt2_cuda_wrapper.cpp", + has_gpu_support=True, + is_forward=True, + has_vbe_support=True, + ssd=True, + ) + @staticmethod def generate_small_kernels() -> None: # Generate the small kernels (for nobag only) for the forward splits diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py index bd10fa8d6..dab0bc5fa 100644 --- a/fbgemm_gpu/codegen/genscript/optimizer_args.py +++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py @@ -27,6 +27,23 @@ from torch_type_utils import arg_type_to_tensor_type, ArgType, TensorType +###################################################################### +# Optimizer Args Set Item +###################################################################### + + +@dataclass +class OptimizerArgsSetItem: + ty: ArgType # type + name: str + default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL + ph_tys: Optional[List[ArgType]] = None # placeholder types + + +# Alias b/c the name is too long +OptimItem = OptimizerArgsSetItem + + ###################################################################### ## Helper functions for the code generator script ## ###################################################################### @@ -117,6 +134,11 @@ def int_tensor_arg(name: str, gpu: bool = True, pass_by_ref: bool = False) -> st return _arg("int32_t", name, gpu=gpu, pass_by_ref=pass_by_ref) +def tensor_list_arg_no_default(name: str, pass_by_ref: bool) -> str: + ref = "&" if pass_by_ref else "" + return f"at::TensorList{ref} {name}" + + def tensor_arg(name: str) -> str: return f"Tensor {name}" @@ -165,6 +187,10 @@ def schema_sym_int_arg_no_default(name: str) -> str: return f"SymInt {name}" +def schema_tensor_list_arg_no_default(name: str) -> str: + return f"Tensor[] {name}" + + def make_kernel_arg( # pyre-fixme[11]: Annotation `ArgType` is not defined as a type. ty: ArgType, @@ -282,21 +308,69 @@ def make_ivalue_cast(ty: ArgType) -> str: }[ty] -###################################################################### -# Optimizer Args Set Item -###################################################################### - - @dataclass -class OptimizerArgsSetItem: - ty: ArgType # type - name: str - default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL - ph_tys: Optional[List[ArgType]] = None # placeholder types - +class PT2ArgsSet: + split_function_args: List[str] + split_function_arg_names: List[str] + split_function_schemas: List[str] + split_saved_tensor_list: List[str] -# Alias b/c the name is too long -OptimItem = OptimizerArgsSetItem + @staticmethod + # pyre-ignore[3] + def create( + split_arg_spec: List[OptimItem], + ): + """ + PT2ArgsSet.create() is a method that creates different formats given the optimization arguments + to be used in TBE codegen PT2 templates. + + Parameters: + split_arg_spec: List[OptimItem] - list of argument specs + + Returns: + PT2ArgsSet object with the following attributes: + split_function_args: List[str] - List of function arguments + e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay']. + split_function_arg_names: List[str] - List of argument names + e.g., ['momentum1', 'eps', 'weight_decay']. + split_function_schemas: List[str] - List of arguments in the schema format + e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay']. + split_saved_tensor_list: List[str] - List of saved tensors for the split function + e.g., ['momentum1']. + """ + split_function_arg_names = [] + split_function_args = [] + split_function_schemas = [] + split_saved_tensor_list = [] + for s in split_arg_spec: + if s.ty in ( + ArgType.TENSOR, + ArgType.INT_TENSOR, + ArgType.LONG_TENSOR, + ArgType.PLACEHOLDER_TENSOR, + ): + name = s.name.rsplit("_", 1)[0] + if name not in split_function_arg_names: + split_function_arg_names.append(name) + split_saved_tensor_list.append(name) + split_function_args.append( + tensor_list_arg_no_default(name, pass_by_ref=False) + ) + split_function_schemas.append( + schema_tensor_list_arg_no_default(name) + ) + else: + split_function_arg_names.append(s.name) + split_function_args.append(make_function_arg(s.ty, s.name, s.default)) + split_function_schemas.append( + make_function_schema_arg(s.ty, s.name, s.default) + ) + return PT2ArgsSet( + split_function_args=split_function_args, + split_function_arg_names=split_function_arg_names, + split_function_schemas=split_function_schemas, + split_saved_tensor_list=split_saved_tensor_list, + ) ###################################################################### @@ -324,6 +398,7 @@ class OptimizerArgs: placeholder_tensor_names: List[str] # pyre-fixme[11]: Annotation `TensorType` is not defined as a type. placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]] + unified_pt2: PT2ArgsSet @staticmethod # pyre-ignore[3] @@ -419,6 +494,7 @@ def create( ], placeholder_tensor_names=ph_tensor_names, placeholder_type_combos=ph_combos, + unified_pt2=PT2ArgsSet.create(split_arg_spec), ) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 40245fba2..b623b92d0 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -36,22 +36,468 @@ #include #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/config/feature_gates.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; +{#/* Module description */#} +{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %} +{%- set bwd_mdesc = "ssd" if ssd else "split" %} + +{%- if ssd %} +enum SSDTensor { + {%- for tensor in ssd_tensors %} + {{ tensor | upper }} = {{ loop.index - 1 }}, + {%- endfor %} +}; +{%- endif %} + +//////////////////////////////////////////////////////////////////////////////// +// Macro Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +// TO DO: Refactor +{# +/* This macro generates a code blob for dispatching corresponding weighted and + unweighted forward op from via Pytorch dispatcher +*/ +#} +{%- macro call_forward_op_dispatch(nobag, weighted, vbe, is_gwd) %} + {%- set forward_op = "{}_embedding{}_codegen_forward_{}{}{}_pt2_wrapper".format( + fwd_mdesc, + "_nobag" if nobag else "", + "weighted" if weighted else "unweighted", + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", + ) + %} + {%- set has_experimental = has_experimental_support( + dense, nobag, vbe, is_index_select=False, is_rocm=is_rocm, ssd=ssd + ) and not is_gwd + %} + static auto embedding_codegen_forward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::{{ forward_op }}", "") + .typed(); + + return { + embedding_codegen_forward_op.call( + weights_host, + flatten_weights_dev, + weights_uvm, + lxu_cache_weights, + weights_placements, + weights_offsets, + {%- if nobag %} + D, + {%- else %} + D_offsets, + total_D, + max_D, + {%- endif %} + hash_size_cumsum, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + indice_weights_value, + {%- endif %} {# /* if not nobag */ #} + {%- if not dense %} + {{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }}, + uvm_cache_stats_, + {%- endif %} + {%- if not nobag %} + {%- if vbe %} + vbe_row_output_offsets, + vbe_b_t_map, + vbe_output_size, + info_B_num_bits, + info_B_mask_int64, + {%- endif %} {# /* if vbe */ #} + {%- if is_gwd %} + prev_iter_dev_, + learning_rate, + weight_decay, + iter, + gwd_lower_bound, + {%- endif %} {# /* if is_gwd */ #} + {%- endif %} {# /* if not nobag */ #} + is_experimental, + output_dtype + ) + }; +{%- endmacro %} + +/* This macro generates a code blob for dispatching corresponding weighted and + unweighted backward op via Pytorch dispatcher +*/ +{%- macro call_backward_op_dispatch(nobag, weighted, vbe, is_gwd) %} + {%- set wdesc = "_weighted" if weighted else "_unweighted" %} + {%- set backward_op = "{}_embedding{}_backward_codegen_{}{}{}{}_pt2_wrapper".format( + bwd_mdesc, + "_nobag" if nobag else "", + optimizer, + wdesc, + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", + ) + %} + static auto embedding_codegen{{ wdesc }}_backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") + .typed(); + + grad_weights_dev = embedding_codegen{{ wdesc }}_backward_op.call( + grad_output, + weights_host, + weights_dev, + weights_uvm, + lxu_cache_weights, + weights_placements, + weights_offsets, + {% if nobag %} + D, + {%- else %} + D_offsets, + max_D, + {%- endif %} {# /* if nobag */ #} + hash_size_cumsum, + total_hash_size_bits, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + indice_weights, + {%- endif %} {# /* if not nobag */ #} + {%- if ssd %} + ssd_row_addrs, + {%- else %} + lxu_cache_locations, + {%- endif %} + BT_block_size, + max_segment_length_per_warp, + {%- if optimizer != "none" %} + stochastic_rounding, + {%- endif %} + info_B_num_bits, + info_B_mask_int64, + {%- if vbe %} + B_offsets, + vbe_row_output_offsets, + vbe_b_t_map, + {%- endif %} {# /* if vbe */ #} + {%- if not dense %} + use_uniq_cache_locations_bwd, + use_homogeneous_placements, + {%- endif %} + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + prev_iter_dev, + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + iter, + {%- endif %} + gwd_lower_bound, + {%- endif %} {# /* if is_gwd */ #} + {{ args_pt2.split_function_arg_names | join(", ") }} + {%- if not nobag %} + , output_dtype + {%- endif %} + ); + return { + {%- if not dense %} + Tensor(), // placeholder autograd tensor + {%- endif %} + Variable(), // output_dtype + Variable(), // weights_host + grad_weights_dev, // weights_dev + {%- if not dense %} + Variable(), // weights_uvm + Variable(), // lxu_cache_weights + Variable(), // weights_placements + {%- endif %} + Variable(), // weights_offsets + {%- if nobag %} + Variable(), // D + {%- else %} + Variable(), // D_offsets + Variable(), // total_D + Variable(), // max_D + {%- endif %} + Variable(), // hash_size_cumsum + Variable(), //total_hash_size_bits + Variable(), // indices + Variable(), // offsets + {%- if not nobag %} + Variable(), // pooling_mode + grad_indice_weights, // indice_weights + Variable(), // feature_requires_grad + {%- endif %} + {%- if not dense %} + Variable(), // lxu_cache_locations + Variable(), // uvm_cache_stats + {%- endif %} + {%- if optimizer != "none" and not dense %} + Variable(), // gradient_clipping + Variable(), // max_gradient + Variable(), // stochastic_rounding + {%- endif %} + {%- if vbe %} + Variable(), // B_offsets + Variable(), // vbe_output_offsets_feature_rank + Variable(), // vbe_B_offsets_rank_per_feature + Variable(), // max_B + Variable(), // max_B_feature_rank + Variable(), // vbe_output_size + {%- endif %} + {%- if not dense %} + Variable(), // is_experimental + Variable(), // use_uniq_cache_locations_bwd + Variable(), // use_homogeneous_placements + {%- endif %} + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + Variable(), // prev_iter_dev + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + Variable(), // iter + {%- endif %} + Variable(), // gwd_lower_bound + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + Variable(), // {{ tensor }} + {%- endfor %} + {%- endif %} + {{ args_pt2.split_variables | join(", ") }} + }; +{%- endmacro %} + +/* This macro generates a code blob that calls corresponding autograd function + from lookup_function +*/ +{%- macro call_autograd(nobag, vbe, is_gwd) %} + {%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op_pt2".format( + "SSD" if ssd else "Split", + "NoBag" if nobag else "", + "VBE" if vbe else "", + "GWD" if is_gwd else "", + optimizer + ) + %} + return {{ autograd_func }}::apply( + {%- if not dense %} + placeholder_autograd_tensor, + {%- endif %} + output_dtype, + weights, + {%- if not dense %} + lxu_cache_weights, + {%- endif %} + {%- if nobag %} + max_D, + {%- else %} + D_offsets, + total_D, + max_D, + {%- endif %} + hash_size_cumsum, + total_hash_size_bits, + indices, + {%- if not nobag and dense and not vbe %} + offsets, + pooling_mode, + indice_weights, + feature_requires_grad + {%- elif not nobag %} + offsets, + pooling_mode, + indice_weights, + feature_requires_grad, + {%- elif nobag and dense and not vbe %} + offsets + {%- else %} + offsets, + {%- endif %} + {%- if not dense %} + lxu_cache_locations, + uvm_cache_stats, + {%- endif %} + {%- if optimizer != "none" and not dense %} + gradient_clipping, + max_gradient, + stochastic_rounding, + {%- endif %} + {%- if vbe %} + B_offsets, + vbe_output_offsets_feature_rank, + vbe_B_offsets_rank_per_feature, + max_B, + max_B_feature_rank, + vbe_output_size, + {%- endif %} + {%- if not dense %} + is_experimental, + use_uniq_cache_locations_bwd, + use_homogeneous_placements, + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + prev_iter_dev, + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + iter, + {%- endif %} + gwd_lower_bound, + {%- endif %} + {%- if ssd %} + ssd_tensors.value(), + {%- endif %} + {{ args_pt2.unified_pt2.split_function_arg_names | join(", ") }} + {%- endif %} + )[0]; +{%- endmacro %} + +/* This macro generates a code blob for unpacking the tensor list +*/ +{%- macro unpack_tensor_list(tensor_list) %} + const Tensor {{ tensor_list }}_host = {{ tensor_list }}[0]; + const Tensor {{ tensor_list }}_dev = {{ tensor_list }}[1]; + const Tensor {{ tensor_list }}_uvm = {{ tensor_list }}[2]; + const Tensor {{ tensor_list }}_placements = {{ tensor_list }}[3]; + const Tensor {{ tensor_list }}_offsets = {{ tensor_list }}[4]; +{%- endmacro %} + + +//////////////////////////////////////////////////////////////////////////////// +// Autograd Function Declarations +//////////////////////////////////////////////////////////////////////////////// + {%- if has_gpu_support or has_cpu_support %} {%- for vbe in ([True, False] if has_vbe_support else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} -{%- for nobag in [True, False] %} -{%- if not nobag or not vbe %} {#-/* nobag does not support vbe */#} -{%- set autograd_func = "Split{}{}LookupFunction_{}_Op_pt2".format( - "NoBag" if nobag else "", - "VBE" if vbe else "", - optimizer +{%- for nobag in ([False] if (weighted or vbe) else [True, False]) %} +{%- set ndesc = "_nobag" if nobag else "" %} + +{%- for is_gwd in ([True, False] + if is_valid_gwd_config( + dense, + nobag, + vbe, + is_index_select, + has_global_weight_decay_support, + ssd) + else [False]) %} +{%- set gwddesc = "_gwd" if is_gwd else "" %} + +{%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op_pt2".format( + "SSD" if ssd else "Split", + "NoBag" if nobag else "", + "VBE" if vbe else "", + "GWD" if is_gwd else "", + optimizer, ) %} @@ -68,18 +514,14 @@ class {{ autograd_func }} : torch::autograd::AutogradContext* ctx, const Tensor& placeholder_autograd_tensor, const int64_t output_dtype, - const Tensor& host_weights, - const Tensor& dev_weights, - const Tensor& uvm_weights, + const at::TensorList weights, const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, {%- if not nobag %} const Tensor& D_offsets, - const int64_t total_D, - const int64_t max_D, + const c10::SymInt total_D, + const c10::SymInt max_D, {%- else %} - const int64_t D, + const c10::SymInt D, {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, @@ -108,7 +550,25 @@ class {{ autograd_func }} : const bool is_experimental, const bool use_uniq_cache_locations_bwd, const bool use_homogeneous_placements, - {{ args_pt2.split_function_args | join(", ") }}) { + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + const std::optional& prev_iter_dev, + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + const int64_t iter, + {%- endif %} + const double gwd_lower_bound, + {%- endif %} + {%- if ssd %} + const at::TensorList& ssd_tensors, + {%- endif %} + {{ args_pt2.unified_pt2.split_function_args | join(", ") }}) { + + // unpack Tensor lists + {{ unpack_tensor_list("weights") }} + {%- for arg_name in args_pt2.unified_pt2.split_saved_tensor_list %} + {{ unpack_tensor_list(arg_name) }} + {%- endfor %} const auto T = weights_offsets.sym_numel(); {%- if vbe %} @@ -124,7 +584,7 @@ class {{ autograd_func }} : // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t // TODO: Hook up with frontend code const auto uvm_cache_stats_ = uvm_cache_stats - .value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt))); + .value_or(at::empty({0}, weights_uvm.options().dtype(at::kInt))); // Default values for Dynamo tracing // SymInt does not support bitshifts operator @@ -147,7 +607,16 @@ class {{ autograd_func }} : static auto generate_vbe_metadata_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::generate_vbe_metadata", "") - .typed(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const c10::SymInt, const bool, const c10::SymInt, const int64_t, const c10::SymInt)>(); + .typed( + const Tensor&, + const Tensor&, + const Tensor&, + const Tensor&, + const int64_t, + const bool, + const c10::SymInt, + const int64_t, + const c10::SymInt)>(); auto [ vbe_row_output_offsets, @@ -168,18 +637,22 @@ class {{ autograd_func }} : {%- endif %} max_B_feature_rank, info_B_num_bits, - /*total_B=*/offsets.size(0) - 1 + /*total_B=*/offsets.sym_size(0) - 1 ); {%- endif %} // vbe + {%- if is_gwd %} + const auto prev_iter_dev_ = prev_iter_dev.value_or(Tensor()); + {%- endif %} + {%- if not nobag %} const auto indice_weights_value = indice_weights.value_or(Tensor()); {%- endif %} ctx->save_for_backward({ - host_weights, - dev_weights, - uvm_weights, + weights_host, + weights_dev, + weights_uvm, lxu_cache_weights, weights_placements, weights_offsets, @@ -199,6 +672,14 @@ class {{ autograd_func }} : vbe_row_output_offsets, vbe_b_t_map, {%- endif %} + {%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %} + prev_iter_dev_, + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + ssd_tensors[SSDTensor::{{ tensor | upper }}], + {%- endfor %} + {%- endif %} {{ args_pt2.split_saved_tensors | join(", ") }} }); @@ -220,6 +701,12 @@ class {{ autograd_func }} : ctx->saved_data["info_B_mask"] = info_B_mask_int64; ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd; ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements; + {%- if is_gwd %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + ctx->saved_data["iter"] = iter; + {%- endif %} + ctx->saved_data["gwd_lower_bound"] = gwd_lower_bound; + {%- endif %} {%- if not nobag %} ctx->saved_data["output_dtype"] = output_dtype; {%- endif %} @@ -230,136 +717,42 @@ class {{ autograd_func }} : {%- if optimizer == "none" %} // Flatten - const auto& flatten_dev_weights = dev_weights.flatten(); + const auto& flatten_weights_dev = weights_dev.flatten(); {%- else %} - const auto& flatten_dev_weights = dev_weights; + const auto& flatten_weights_dev = weights_dev; {%- endif %} - - {%- if not nobag %} - {%- for weighted in [False, True] %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- if not weighted %} - if (!indice_weights_value.defined()) { + {%- if nobag %} + // nobag + {{ + call_forward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- else %} - else { - {%- endif %} - {%- set forward_op = "split_embedding_codegen_forward_{}{}_pt2_wrapper".format( - wdesc, vdesc - ) - %} - static auto split_embedding_codegen_forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ forward_op }}", "") - .typed(); - return { - split_embedding_codegen_forward_op.call( - host_weights, - flatten_dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D_offsets, - total_D, - max_D, - hash_size_cumsum, - indices, - offsets, - pooling_mode, - indice_weights_value, - lxu_cache_locations, - uvm_cache_stats_, - {%- if vbe %} - vbe_row_output_offsets, - vbe_b_t_map, - vbe_output_size, - info_B_num_bits, - info_B_mask_int64, - {%- endif %} - is_experimental, - output_dtype - ) - }; + if (indice_weights) { + // weighted + {{ + call_forward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, + ) + }} } - {%- endfor %} - {%- else %} - {%- set forward_nobag_op = "split_embedding_nobag_codegen_forward_unweighted_pt2_wrapper" %} - static auto split_embedding_codegen_forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ forward_nobag_op }}", "") - .typed(); - return { - split_embedding_codegen_forward_op.call( - host_weights, - flatten_dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D, - hash_size_cumsum, - indices, - offsets, - lxu_cache_locations, - uvm_cache_stats_, - /*is_experimental=*/false, - output_dtype - ) - }; - {%- endif %} + // unweighted + {{ + call_forward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} + {%- endif %} {#-/* if not nobag */ #} } static torch::autograd::variable_list backward( @@ -367,9 +760,9 @@ static torch::autograd::variable_list backward( torch::autograd::variable_list grad_outputs) { const auto saved = ctx->get_saved_variables(); auto savedItr = std::begin(saved); - auto host_weights = *savedItr++; - auto dev_weights = *savedItr++; - auto uvm_weights = *savedItr++; + auto weights_host = *savedItr++; + auto weights_dev = *savedItr++; + auto weights_uvm = *savedItr++; auto lxu_cache_weights = *savedItr++; auto weights_placements = *savedItr++; auto weights_offsets = *savedItr++; @@ -389,6 +782,14 @@ static torch::autograd::variable_list backward( auto vbe_row_output_offsets = *savedItr++; auto vbe_b_t_map = *savedItr++; {%- endif %} + {%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %} + auto prev_iter_dev = *savedItr++; + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + auto ssd_{{ tensor }} = *savedItr++; + {%- endfor %} + {%- endif %} {%- for tensor in args_pt2.split_saved_tensors %} auto {{ tensor }} = *savedItr++; @@ -411,6 +812,13 @@ static torch::autograd::variable_list backward( const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt(); const auto use_uniq_cache_locations_bwd = ctx->saved_data["use_uniq_cache_locations_bwd"].toBool(); const auto use_homogeneous_placements = ctx->saved_data["use_homogeneous_placements"].toBool(); + {%- if is_gwd %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + const auto iter = ctx->saved_data["iter"].toInt(); + {%- endif %} + const auto gwd_lower_bound = ctx->saved_data["gwd_lower_bound"].toDouble(); + {%- endif %} + {%- if not nobag %} auto output_dtype = ctx->saved_data["output_dtype"].toInt(); {%- endif %} @@ -438,22 +846,22 @@ static torch::autograd::variable_list backward( {%- if not nobag %} {%- if optimizer == "none" %} - // Flatten (dev_weights is used in - // split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) - dev_weights = dev_weights.flatten(); + // Flatten (weights_dev is used in + // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) + weights_dev = weights_dev.flatten(); {%- endif %} {%- set grad_indice_weights_op = - "split_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(vdesc) + "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc) %} - static auto split_embedding_codegen_grad_indice_weights_op = + static auto embedding_codegen_grad_indice_weights_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "") .typed(); + )>(); const auto grad_indice_weights = !indice_weights.defined() ? Variable() : - split_embedding_codegen_grad_indice_weights_op.call( + embedding_codegen_grad_indice_weights_op.call( grad_output, - host_weights, - dev_weights, - uvm_weights, + weights_host, + weights_dev, + weights_uvm, lxu_cache_weights, weights_placements, weights_offsets, @@ -487,7 +899,11 @@ static torch::autograd::variable_list backward( max_D, indices, offsets, + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} {%- if vbe %} feature_requires_grad, vbe_row_output_offsets, @@ -498,276 +914,56 @@ static torch::autograd::variable_list backward( feature_requires_grad {%- endif %} ); - - {%- for weighted in [False, True] %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- set backward_op = "split_embedding_backward_codegen_{}_{}_exact{}_pt2_wrapper".format( - optimizer, wdesc, vdesc + + Tensor grad_weights_dev; + if (indice_weights.defined()) + { + // weighted + {{ + call_backward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, ) - %} - static auto split_embedding_codegen_{{ wdesc }}_backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") - .typed(); - {%- endfor %} {#-/* for weighted */#} - - const auto grad_dev_weights = !indice_weights.defined() ? - {%- for weighted in [False, True] %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - split_embedding_codegen_{{ wdesc }}_backward_op.call( - grad_output, - host_weights, - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D_offsets, - max_D, - hash_size_cumsum, - total_hash_size_bits, - indices, - offsets, - pooling_mode, - indice_weights, - lxu_cache_locations, - BT_block_size, - max_segment_length_per_warp, - {%- if optimizer != "none" %} - stochastic_rounding, - {%- endif %} - info_B_num_bits, - info_B_mask_int64, - {%- if vbe %} - B_offsets, - vbe_row_output_offsets, - vbe_b_t_map, - {%- endif %} - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {{ args_pt2.split_function_arg_names | join(", ") }} - {%- if not nobag %} - , output_dtype - {%- endif %} - ) {{ ":" if not weighted else ";" }} - {%- endfor %} {#-/* for weighted in [False, True] */#} - return { - Tensor(), // placeholder autograd tensor - Variable(), // output_dtype - Variable(), // host_weights - grad_dev_weights, // dev_weights - Variable(), // uvm_weights - Variable(), // lxu_cache_weights - Variable(), // weights_placements - Variable(), // weights_offsets - Variable(), // D_offsets - Variable(), // total_D - Variable(), // max_D - Variable(), // hash_size_cumsum - Variable(), //total_hash_size_bits - Variable(), // indices - Variable(), // offsets - Variable(), // pooling_mode - grad_indice_weights, // indice_weights - Variable(), // feature_requires_grad - Variable(), // lxu_cache_locations - Variable(), // uvm_cache_stats - {%- if optimizer != "none" %} - Variable(), // gradient_clipping - Variable(), // max_gradient - Variable(), // stochastic_rounding - {%- endif %} - {%- if vbe %} - Variable(), // B_offsets - Variable(), // vbe_output_offsets_feature_rank - Variable(), // vbe_B_offsets_rank_per_feature - Variable(), // max_B - Variable(), // max_B_feature_rank - Variable(), // vbe_output_size - {%- endif %} - Variable(), // is_experimental - Variable(), // use_uniq_cache_locations_bwd - Variable(), // use_homogeneous_placements - {{ args_pt2.split_variables | join(", ") }} - }; - {%- else %} - {%- set backward_nobag_op = - "split_embedding_nobag_backward_codegen_{}_unweighted_exact_pt2_wrapper".format( - optimizer + }} + } + // unweighted + {{ + call_backward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} + {%- else %} {#-/* if not nobag */#} + // nobag + Tensor grad_weights_dev; + {{ + call_backward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, ) - %} - - static auto split_embedding_nobag_codegen_backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ backward_nobag_op }}", "") - .typed(); + }} + {%- endif %} {#-/* if not nobag */#} - const auto grad_dev_weights = split_embedding_nobag_codegen_backward_op.call( - grad_output, - host_weights, - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D, - hash_size_cumsum, - total_hash_size_bits, - indices, - offsets, - lxu_cache_locations, - BT_block_size, - max_segment_length_per_warp, - {%- if optimizer != "none" %} - stochastic_rounding, - {%- endif %} - info_B_num_bits, - info_B_mask_int64, - {%- if vbe %} - B_offsets, - vbe_row_output_offsets, - vbe_b_t_map, - {%- endif %} - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {{ args_pt2.split_function_arg_names | join(", ") }} - ); - return { - Tensor(), // placeholder autograd tensor - Variable(), // output_dtype - Variable(), // host_weights - grad_dev_weights, // dev_weights - Variable(), // uvm_weights - Variable(), // lxu_cache_weights - Variable(), // weights_placements - Variable(), // weights_offsets - Variable(), // D - Variable(), // hash_size_cumsum - Variable(), // total_hash_size_bits - Variable(), // indices - Variable(), // offsets - Variable(), // lxu_cache_locations - Variable(), // uvm_cache_stats - {%- if optimizer != "none" %} - Variable(), // gradient_clipping - Variable(), // max_gradient - Variable(), // stochastic_rounding - {%- endif %} - {%- if vbe %} - Variable(), // B_offsets - Variable(), // vbe_output_offsets_feature_rank - Variable(), // vbe_B_offsets_rank_per_feature - Variable(), // max_B - Variable(), // max_B_feature_rank - Variable(), // vbe_output_size - {%- endif %} - Variable(), // is_experimental - Variable(), // use_uniq_cache_locations_bwd - Variable(), // use_homogeneous_placements - {{ args_pt2.split_variables | join(", ") }} - }; - {%- endif %} } }; -{%- endif %} {#-/* if not nobag or not vbe */#} +{%- endfor %} {#-/* for is_gwd */#} {%- endfor %} {#-/* for nobag in [True, False] */#} {%- endfor %} {#-/* for vbe in [True, False] */#} ///@ingroup embedding-cuda -Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_pt2( +Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( const Tensor& placeholder_autograd_tensor, - const Tensor& host_weights, - const Tensor& dev_weights, - const Tensor& uvm_weights, + const at::TensorList weights, const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, const Tensor& D_offsets, - const int64_t total_D, - const int64_t max_D, + const c10::SymInt total_D, + const c10::SymInt max_D, const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, @@ -781,105 +977,96 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_pt2( const double max_gradient, const bool stochastic_rounding, {%- endif %} - {{ args_pt2.split_function_args | join(", ") }}, + {{ args_pt2.unified_pt2.split_function_args | join(", ") }}, const int64_t output_dtype = static_cast(SparseType::FP32), - const std::optional& B_offsets = std::optional(), - const std::optional& vbe_output_offsets_feature_rank = std::optional(), - const std::optional& vbe_B_offsets_rank_per_feature = std::optional(), + const std::optional& B_offsets = c10::nullopt, + const std::optional& vbe_output_offsets_feature_rank = c10::nullopt, + const std::optional& vbe_B_offsets_rank_per_feature = c10::nullopt, const c10::SymInt max_B = -1, const c10::SymInt max_B_feature_rank = -1, const c10::SymInt vbe_output_size = -1, - const bool is_experimental = false, + const bool is_experimental_tbe = false, // formerly named is_experimental const bool use_uniq_cache_locations_bwd = false, const bool use_homogeneous_placements = false, - const std::optional& uvm_cache_stats = std::optional()) { - {%- for vbe in ([True, False] if has_vbe_support else [False]) %} - {%- if has_vbe_support %} - {%- if vbe %} - if (B_offsets.has_value()) { - {%- else %} - else { // if (B_offsets.has_value()) - {%- endif %} - {%- endif %} {#-/* if has_vbe_support */#} - {%- for nobag in [True, False] %} - {%- set vbe = False if nobag else vbe %} - {%- set autograd_func = "Split{}{}LookupFunction_{}_Op_pt2".format( - "NoBag" if nobag else "", - "VBE" if vbe else "", - optimizer - ) - %} - {%- if nobag %} - if (static_cast(pooling_mode) == PoolingMode::NONE) { - {%- else %} - else { + const std::optional& uvm_cache_stats = c10::nullopt, + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + const std::optional& prev_iter_dev = c10::nullopt, {%- endif %} - return {{ autograd_func }}::apply( - placeholder_autograd_tensor, - output_dtype, - host_weights, - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - {%- if nobag %} - max_D, - {%- else %} - D_offsets, - total_D, - max_D, - {%- endif %} - hash_size_cumsum, - total_hash_size_bits, - indices, - offsets, - {%- if not nobag %} - pooling_mode, - indice_weights, - feature_requires_grad, - {%- endif %} - lxu_cache_locations, - uvm_cache_stats, - {%- if optimizer != "none" %} - gradient_clipping, - max_gradient, - stochastic_rounding, - {%- endif %} - {%- if vbe %} - B_offsets, - vbe_output_offsets_feature_rank, - vbe_B_offsets_rank_per_feature, - max_B, - max_B_feature_rank, - vbe_output_size, - {%- endif %} - is_experimental, - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {{ args_pt2.split_function_arg_names | join(", ") }} - )[0]; - } - {%- endfor %} {#-/* for nobag */#} + {%- if "iter" not in args_pt2.split_function_arg_names %} + const int64_t iter = 0, + {%- endif %} + const bool apply_global_weight_decay = false, + {%- if ssd %} + const c10::optional& ssd_tensors = c10::nullopt, + {%- endif %} + const double gwd_lower_bound = 0 +) { + {%- if has_gpu_support or has_cpu_support %} + + {%- if not dense %} + // Load the config value from JK once + static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2); + + // Set to experimental if either the feature is enabled in JK, or the user specifies to use TBEv2 + const auto is_experimental = is_tbev2_enabled || is_experimental_tbe; + {%- endif %} + + {%- if not ssd %} {%- if has_vbe_support %} + // has vbe support and on gpu + if (B_offsets.has_value() && !(weights[0].numel() > 0)) { + {%- if has_global_weight_decay_support %} + // vbe and has gwd support + if (apply_global_weight_decay && weight_decay > 0) { + {{ call_autograd(nobag=False, vbe=True, is_gwd=True) }} + } + {%- endif %} {#-/* if has_global_weight_decay_support */ #} + // vbe and no gwd support + {{ call_autograd(nobag=False, vbe=True, is_gwd=False) }} + } + {%- endif %} {#-/* if has_vbe_support */ #} + + {%- if has_global_weight_decay_support %} + // has gwd support + if (apply_global_weight_decay && weight_decay > 0) { + // not vbe and gwd + {{ call_autograd(nobag=False, vbe=False, is_gwd=True) }} } + {%- endif %} {#-/* if has_global_weight_decay_support */ #} + {%- endif %} {#-/* if not ssd */#} + + {%- if ssd %} + TORCH_CHECK( + ssd_tensors.value().size() == {{ ssd_tensors | length }}, + "SSD TBE expects {{ ssd_tensors | length }} in ssd_tensors"); {%- endif %} - {%- endfor %} {#-/* vbe */#} + + if (static_cast(pooling_mode) == PoolingMode::NONE) { + // no bag + {{ call_autograd(nobag=True, vbe=False, is_gwd=False) }} + } + else { + {{ call_autograd(nobag=False, vbe=False, is_gwd=False) }} + } + {%- else %} + TORCH_CHECK( + false, + "{{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail." + ); + return Tensor(); + {%- endif %} } TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_pt2(" + {%- set op_name = "{}_embedding_codegen_lookup_{}_function_pt2".format(bwd_mdesc, optimizer) %} + m.def("{{ op_name }}(" " Tensor placeholder_autograd_tensor, " - " Tensor host_weights, " - " Tensor dev_weights, " - " Tensor uvm_weights, " + " Tensor[] weights, " " Tensor lxu_cache_weights, " - " Tensor weights_placements, " - " Tensor weights_offsets, " " Tensor D_offsets, " - " int total_D, " - " int max_D, " + " SymInt total_D, " + " SymInt max_D, " " Tensor hash_size_cumsum, " " int total_hash_size_bits, " " Tensor indices, " @@ -893,7 +1080,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " float max_gradient, " " bool stochastic_rounding, " {%- endif %} - " {{ args_pt2.split_function_schemas | join(", ") }}, " + " {{ args_pt2.unified_pt2.split_function_schemas | join(", ") }}, " " int output_dtype=0, " " Tensor? B_offsets=None, " " Tensor? vbe_output_offsets_feature_rank=None, " @@ -901,10 +1088,21 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " SymInt max_B=-1, " " SymInt max_B_feature_rank=-1, " " SymInt vbe_output_size=-1, " - " bool is_experimental=False, " + " bool is_experimental_tbe=False, " " bool use_uniq_cache_locations_bwd=False, " " bool use_homogeneous_placements=False, " - " Tensor? uvm_cache_stats=None" + " Tensor? uvm_cache_stats=None," + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + " Tensor? prev_iter_dev=None, " + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + " int iter=0, " + {%- endif %} + " bool apply_global_weight_decay=False, " + {%- if ssd %} + " Tensor[]? ssd_tensors=None," + {%- endif %} + " float gwd_lower_bound=0 " ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd @@ -913,18 +1111,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // no autograd enabled, and all of the internal implementations redispatch // appropriately m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function_pt2", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Autograd, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_pt2))); + TORCH_FN({{ op_name }}))); m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function_pt2", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Meta, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_pt2))); + TORCH_FN({{ op_name }}))); DISPATCH_TO_CUDA( - "split_embedding_codegen_lookup_{{ optimizer }}_function_pt2", - split_embedding_codegen_lookup_{{ optimizer }}_function_pt2); + " {{ op_name }} ", + {{ op_name }} ); } {%- endif %} {#-/* if has_gpu_support or has_cpu_support */#} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index 7cafb32cd..c74355207 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -107,7 +107,7 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( } {% else %} {#-/* PT2 wrapper function for backward CPU */#} -Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact_pt2_cpu_wrapper( +Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -208,7 +208,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_exact_pt2".format( + {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format( optimizer, wdesc ) %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp index ded9bf261..623947e2e 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp @@ -30,6 +30,18 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{#/* Module description */#} +{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %} +{%- set bwd_mdesc = "ssd" if ssd else "split" %} + +{%- if ssd %} +enum SSDTensor { + {%- for tensor in ssd_tensors %} + {{ tensor | upper }} = {{ loop.index - 1 }}, + {%- endfor %} +}; +{%- endif %} + {%- for vbe in ([True, False] if has_vbe_support else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -37,10 +49,22 @@ using namespace fbgemm_gpu; {%- for nobag in ([False] if (weighted or vbe) else [True, False]) %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} +{%- for is_gwd in ([True, False] + if is_valid_gwd_config( + dense, + nobag, + vbe, + is_index_select, + True, + ssd) + else [False]) %} +{%- set gwddesc = "_gwd" if is_gwd else "" %} +{%- set desc_suffix = wdesc + vdesc + gwddesc %} {%- if is_forward %} {#-/* PT2 wrapper function for forward CUDA */#} -Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cuda_wrapper( +{%- for dispatch_type in ["cuda", "meta"] %} +Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt2_{{ dispatch_type }}_wrapper( const Tensor& /*host_weights*/, const Tensor& dev_weights, const Tensor& uvm_weights, @@ -54,14 +78,14 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud const c10::SymInt total_D, const c10::SymInt max_D, {%- endif %} - const Tensor& /*hash_size_cumsum*/, + const Tensor& hash_size_cumsum, const Tensor& indices, const Tensor& offsets, {%- if not nobag %} const int64_t pooling_mode, const Tensor& indice_weights, // CPU always takes indice_weights {%- endif %} - const Tensor& lxu_cache_locations, + const Tensor& {{ "ssd_row_addrs" if ssd else "lxu_cache_locations" }}, const Tensor& uvm_cache_stats, {%- if vbe %} const Tensor& vbe_row_output_offsets, @@ -70,48 +94,61 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud const int64_t info_B_num_bits, const int64_t info_B_mask_int64, {%- endif %} + {%- if is_gwd %} + const Tensor& prev_iter_dev, + const double learning_rate, + const double weight_decay, + const int64_t iter, + const double gwd_lower_bound, + {%- endif %} const bool is_experimental, const int64_t output_dtype ){ - {%- set op = "split_embedding{}_codegen_forward_{}{}_cuda".format( - ndesc, wdesc, vdesc + {%- set op = "{}_embedding{}_codegen_forward_{}_cuda".format( + fwd_mdesc, ndesc, desc_suffix ) %} static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ op }}", "") .typed(); @@ -137,7 +174,11 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud {%- if weighted %} indice_weights, {%- endif %} + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} uvm_cache_stats, output_dtype, {%- if vbe %} @@ -147,131 +188,22 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud info_B_num_bits, info_B_mask_int64, {%- endif %} + {%- if is_gwd %} + hash_size_cumsum, + prev_iter_dev, + learning_rate, + weight_decay, + iter, + gwd_lower_bound, + {%- endif %} {# /* if is_gwd */ #} is_experimental ); }; - -{#-/* PT2 wrapper function for forward META */#} -Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_meta_wrapper( - const Tensor& /*host_weights*/, - const Tensor& dev_weights, - const Tensor& uvm_weights, - const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, - {%- if nobag %} - const c10::SymInt D, - {%- else %} - const Tensor& D_offsets, - const c10::SymInt total_D, - const c10::SymInt max_D, - {%- endif %} - const Tensor& /*hash_size_cumsum*/, - const Tensor& indices, - const Tensor& offsets, - {%- if not nobag %} - const int64_t pooling_mode, - const Tensor& indice_weights, // CPU always takes indice_weights - {%- endif %} - const Tensor& lxu_cache_locations, - const Tensor& uvm_cache_stats, - {%- if vbe %} - const Tensor& vbe_row_output_offsets, - const Tensor& vbe_b_t_map, - const c10::SymInt vbe_output_size, - const int64_t info_B_num_bits, - const int64_t info_B_mask_int64, - {%- endif %} - const bool is_experimental, - const int64_t output_dtype - ){ - {%- set op = "split_embedding{}_codegen_forward_{}{}_cuda".format( - ndesc, wdesc, vdesc - ) - %} - static auto op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ op }}", "") - .typed(); - - return op.call( - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - {%- if not nobag %} - D_offsets, - {%- else %} - D, - {%- endif %} - {%- if not nobag %} - total_D, - {%- endif %} - {%- if not nobag %} - max_D, - {%- endif %} - indices, - offsets, - {%- if not nobag %} - pooling_mode, - {%- endif %} - {%- if weighted %} - indice_weights, - {%- endif %} - lxu_cache_locations, - uvm_cache_stats, - output_dtype, - {%- if vbe %} - vbe_row_output_offsets, - vbe_b_t_map, - vbe_output_size, - info_B_num_bits, // int32_t - info_B_mask_int64, // uint32_t - {%- endif %} - is_experimental - ); - } +{%- endfor %} {#-/*for dispatch_type in ["cuda", "meta"]*/#} {%- else %} {#-/* PT2 wrapper function for backward CUDA */#} -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}_pt2_cuda_wrapper( +Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_pt2_cuda_wrapper( const Tensor& grad_output, const Tensor& /*host_weights*/, const Tensor& dev_weights, @@ -293,7 +225,11 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const int64_t pooling_mode, const Tensor& indice_weights, // currently supports no bag with unweighted {%- endif %} + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- elif not dense %} const Tensor& lxu_cache_locations, + {%- endif %} const int64_t BT_block_size, const int64_t max_segment_length_per_warp, {%- if optimizer != "none" %} @@ -308,55 +244,73 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} const bool use_uniq_cache_locations, const bool use_homogeneous_placements, + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args.split_function_arg_names %} + const Tensor& prev_iter_dev, + {%- endif %} + {%- if "iter" not in args.split_function_arg_names %} + const int64_t iter, + {%- endif %} + const double gwd_lower_bound, + {%- endif %} {{ args_pt2.split_function_args | join(", ") }} {%- if not nobag %} , const int64_t output_dtype = static_cast(SparseType::FP32) {%- endif %}){ - {%- set backward_op = "split_embedding{}_backward_codegen_{}_{}_exact{}_cuda".format( - ndesc, optimizer, wdesc, vdesc + {%- set backward_op = "{}_embedding{}_backward_codegen_{}_{}_exact_cuda".format( + bwd_mdesc, ndesc, optimizer, desc_suffix ) %} static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") .typed Tensor" @@ -547,8 +547,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { ); m.impl("{{ embedding_codegen_forward_op }}_wrapper", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ embedding_codegen_forward_op }}_meta_wrapper))); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding{}_backward_codegen_{}_{}_exact{}_pt2".format( - ndesc, optimizer, wdesc, vdesc + {%- set embedding_codegen_backward_op = "{}_embedding{}_backward_codegen_{}_{}_pt2".format( + bwd_mdesc, ndesc, optimizer, desc_suffix ) %} m.def("{{ embedding_codegen_backward_op }}_wrapper(" @@ -573,7 +573,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int pooling_mode, " " Tensor indice_weights, " {%- endif %} + {%- if ssd %} + " Tensor ssd_row_addrs, " + {%- else %} " Tensor lxu_cache_locations, " + {%- endif %} " int BT_block_size, " " int max_segment_length_per_warp, " {%- if optimizer != "none" %} @@ -588,6 +592,15 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endif %} " bool use_uniq_cache_locations, " " bool use_homogeneous_placements," + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args.split_function_arg_names %} + " Tensor prev_iter_dev, " + {%- endif %} + {%- if "iter" not in args.split_function_arg_names %} + " int iter, " + {%- endif %} + " float gwd_lower_bound, " + {%- endif %} " {{ args_pt2.split_function_schemas | join(", ") }} " {%- if not nobag %} " , int output_dtype=0 " @@ -598,12 +611,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {{ embedding_codegen_backward_op }}_cuda_wrapper ); {%- endif %} + {%- endfor %} {#-/*for is_gwd*/#} {%- endfor %} {#-/*for nobag*/#} {%- endfor %} {#-/*for weighted*/#} {%- if is_forward %} {%- set embedding_codegen_grad_indice_weights_op = - "split_embedding_codegen_grad_indice_weights{}_pt2".format( - vdesc + "{}_embedding_codegen_grad_indice_weights{}_pt2".format( + fwd_mdesc, vdesc ) %} m.def("{{ embedding_codegen_grad_indice_weights_op }}_wrapper(" @@ -618,7 +632,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " SymInt max_D, " " Tensor indices, " " Tensor offsets, " + {%- if ssd %} + " Tensor ssd_row_addrs, " + {%- else %} " Tensor lxu_cache_locations, " + {%- endif %} {%- if vbe %} " Tensor feature_requires_grad, " " Tensor vbe_row_output_offsets, " From 78a52d690f755fb888baebd30361cba32ad6ec1f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 13 Sep 2024 11:43:02 -0700 Subject: [PATCH 05/27] fix gen_ai rocm header (#3048) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/226 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3048 Reviewed By: sryap, shintaro-iwasaki Differential Revision: D62329253 Pulled By: q10 fbshipit-source-id: 929af5f52eac72665297b6be218d8fc87265d802 --- .../gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip | 2 +- .../gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip | 2 +- .../gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip index 17d460483..ca44bba42 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip @@ -12,7 +12,7 @@ #include #include -#include +#include #include #if defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index afb306432..4b9946872 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -15,7 +15,7 @@ #include #include -#include +#include #include #if defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip index 6170675a2..04b2ba3b0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip @@ -12,7 +12,7 @@ #include #include -#include +#include #include #if defined(USE_ROCM) From b295fc9e13bb7ac47078a20dea983a6680172fc5 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Fri, 13 Sep 2024 14:18:08 -0700 Subject: [PATCH 06/27] cachelib optimization (#3123) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3123 X-link: https://github.com/facebookresearch/FBGEMM/pull/211 1. formulate bucket power and lock power in setAccessConfig to reduce hashtable contention 2. modify cache put logic to do find first, this will avoid always kicking out an existing item on cach hit cases when cache is full. Reviewed By: q10 Differential Revision: D62041988 fbshipit-source-id: e0aa55a52d83fb9324cb4e2ff79bfdc9ea223c1b --- .../split_embeddings_cache/cachelib_cache.h | 8 +++ .../split_embeddings_cache/cachelib_cache.cpp | 53 ++++++++++++++++--- .../kv_db_table_batched_embeddings.cpp | 18 +++---- .../kv_db_table_batched_embeddings.h | 35 +++++++++++- 4 files changed, 96 insertions(+), 18 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h index 4965c3731..7dc8efe0c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h @@ -74,6 +74,14 @@ class CacheLibCache { /// deterministic mapping from a embedding index to a specific pool id facebook::cachelib::PoolId get_pool_id(int64_t key); + /// update the LRU queue in cachelib, this is detached from cache->find() + /// so that we could boost up the lookup perf without worrying about LRU queue + /// contention + /// + /// @param read_handles the read handles that record what cache item has been + /// accessed + void batchMarkUseful(const std::vector& read_handles); + /// Add an embedding index and embeddings into cachelib /// /// @param key embedding index to insert diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp index 2510d4e81..4b3bee846 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp @@ -7,6 +7,7 @@ */ #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h" +#include #include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h" #include "fbgemm_gpu/utils/dispatch_macros.h" @@ -38,7 +39,12 @@ CacheLibCache::CacheLibCache(const CacheConfig& cache_config) for (size_t i = 0; i < cache_config_.num_shards; i++) { pool_ids_.push_back(cache_->addPool( fmt::format("shard_{}", i), - cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards)); + cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards, + std::set{}, + Cache::MMConfig{ + 0, /* promote on every access*/ + true, /*enable promotion on write*/ + true /*enable promotion on read*/})); } } @@ -69,10 +75,24 @@ std::unique_ptr CacheLibCache::initializeCacheLib( }); }; Cache::Config cacheLibConfig; + int64_t rough_num_items = + cache_config_.cache_size_bytes / cache_config_.item_size_bytes; + unsigned int bucket_power = std::log(rough_num_items) / std::log(2) + 1; + // 15 here is a magic number between 10 and 20 + unsigned int lock_power = + std::log(cache_config_.num_shards * 15) / std::log(2) + 1; + XLOG(INFO) << fmt::format( + "Setting up Cachelib for L2 cache, capacity: {}GB, " + "item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}", + config.cache_size_bytes / 1024 / 1024 / 1024, + cache_config_.item_size_bytes, + rough_num_items, + bucket_power, + lock_power); cacheLibConfig.setCacheSize(static_cast(config.cache_size_bytes)) .setRemoveCallback(eviction_cb) .setCacheName("TBEL2Cache") - .setAccessConfig({25 /* bucket power */, 10 /* lock power */}) + .setAccessConfig({bucket_power, lock_power}) .setFullCoredump(false) .validate(); return std::make_unique(cacheLibConfig); @@ -104,16 +124,35 @@ facebook::cachelib::PoolId CacheLibCache::get_pool_id(int64_t key) { return pool_ids_[get_shard_id(key)]; } +void CacheLibCache::batchMarkUseful( + const std::vector& read_handles) { + if (read_handles.empty()) { + return; + } + for (auto& handle : read_handles) { + if (handle) { + cache_->markUseful(handle, facebook::cachelib::AccessMode::kRead); + } + } +} + bool CacheLibCache::put(int64_t key, const at::Tensor& data) { auto key_str = folly::StringPiece(reinterpret_cast(&key), sizeof(int64_t)); - auto item = cache_->allocate(get_pool_id(key), key_str, data.nbytes()); + auto item = cache_->findToWrite(key_str); if (!item) { - XLOG(ERR) << fmt::format("Failed to allocate item {} in cache, skip", key); - return false; + auto alloc_item = + cache_->allocate(get_pool_id(key), key_str, data.nbytes()); + if (!alloc_item) { + XLOG(ERR) << fmt::format( + "Failed to allocate item {} in cache, skip", key); + return false; + } + std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes()); + cache_->insertOrReplace(std::move(alloc_item)); + } else { + std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes()); } - std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes()); - cache_->insertOrReplace(std::move(item)); return true; } diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 6507d2fa6..3aa630744 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -29,7 +29,7 @@ inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) { }; // namespace -std::tuple tensor_copy( +QueueItem tensor_copy( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count) { @@ -58,7 +58,7 @@ std::tuple tensor_copy( new_weightss_addr); // dst_start }); *new_count.data_ptr() = num_sets; - return std::make_tuple(new_indices, new_weights, new_count); + return QueueItem{new_indices, new_weights, new_count}; } EmbeddingKVDB::EmbeddingKVDB( @@ -94,9 +94,9 @@ EmbeddingKVDB::EmbeddingKVDB( if (stop_) { return; } - auto& indices = std::get<0>(*filling_item_ptr); - auto& weights = std::get<1>(*filling_item_ptr); - auto& count = std::get<2>(*filling_item_ptr); + auto& indices = filling_item_ptr->indices; + auto& weights = filling_item_ptr->weights; + auto& count = filling_item_ptr->count; if (l2_cache_) { auto evicted_pairs_opt = set_cache(indices, weights, count); @@ -228,8 +228,8 @@ void EmbeddingKVDB::set( // parallelized with other cuda kernels, as long as all updates are finished // before the next L2 cache lookup auto tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto new_tuple = tensor_copy(indices, weights, count); - weights_to_fill_queue_.enqueue(new_tuple); + auto new_item = tensor_copy(indices, weights, count); + weights_to_fill_queue_.enqueue(new_item); set_tensor_copy_for_cache_update_ += facebook::WallClockUtil::NowInUsecFast() - tensor_copy_start_ts; } @@ -264,8 +264,8 @@ void EmbeddingKVDB::get( // be parallelized with other cuda kernels, as long as all updates are // finished before the next L2 cache lookup auto tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto new_tuple = tensor_copy(indices, weights, count); - weights_to_fill_queue_.enqueue(new_tuple); + auto new_item = tensor_copy(indices, weights, count); + weights_to_fill_queue_.enqueue(new_item); get_tensor_copy_for_cache_update_ += facebook::WallClockUtil::NowInUsecFast() - tensor_copy_start_ts; } else { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 8970f2f33..a77d1c2b7 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -62,6 +62,38 @@ class CacheContext { std::vector cached_addr_list; }; +/// @ingroup embedding-ssd +/// +/// @brief queue item for background L2/rocksdb update +/// +/// indices/weights/count are the corresponding set() params +/// +/// read_handles is cachelib abstracted indices/embedding pair metadata, will be +/// later used on updating cachelib LRU queue as we separate it from +/// EmbeddingKVDB::get_cache() +/// +/// mode is used for monitoring rocksdb write as there are 3 writes in each +/// train iteration, +/// - cache lookup will move uncached data from rocksdb into L2 cache on fwd +/// path +/// - L1 cache eviciton will evict data into L2 cache on fwd path +/// - L1 conflict miss will insert into L2 on bwd path +/// those L2 cache filling will potentially trigger rocksdb write once L2 cache +/// is full +struct QueueItem { + at::Tensor indices; + at::Tensor weights; + at::Tensor count; + QueueItem( + at::Tensor src_indices, + at::Tensor src_weights, + at::Tensor src_count) { + indices = src_indices; + weights = src_weights; + count = src_count; + } +}; + /// @ingroup embedding-ssd /// /// @brief A class for interacting with different cache layers and storage @@ -239,8 +271,7 @@ class EmbeddingKVDB : public std::enable_shared_from_this { std::atomic stop_{false}; // buffer queue that stores all the needed indices/weights/action_count to // fill up cache - folly::USPSCQueue, true> - weights_to_fill_queue_; + folly::USPSCQueue weights_to_fill_queue_; // perf stats // -- perf of get() function From 0535bcf6fabf5d1804fecabbc9a7b058bb6459bb Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Fri, 13 Sep 2024 14:18:08 -0700 Subject: [PATCH 07/27] add metrics for ssd logging (#3122) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3122 X-link: https://github.com/facebookresearch/FBGEMM/pull/212 1. number evictions from L2 cache 2. cache memcpy duration(in parallel with rocksdb io read) 3. break down rocksdb write io for 2 calls for fwd path and 1 call for bwd path Reviewed By: q10 Differential Revision: D62208506 fbshipit-source-id: ea312cf07153310d8cd937c3f4cc89530aa2187c --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 63 ++++++++++++++----- .../split_embeddings_cache/cachelib_cache.h | 5 +- .../ps_table_batched_embeddings.h | 3 +- .../split_embeddings_cache/cachelib_cache.cpp | 6 +- .../kv_db_table_batched_embeddings.cpp | 48 +++++++++----- .../kv_db_table_batched_embeddings.h | 43 ++++++++++--- .../ssd_table_batched_embeddings.h | 51 ++++++++++----- 7 files changed, 158 insertions(+), 61 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 04fac4a62..f0e16b840 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -659,6 +659,9 @@ def __init__( self.l2_num_cache_lookups_stats_name: str = ( f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_lookups" ) + self.l2_num_cache_evictions_stats_name: str = ( + f"l2_cache.perf.tbe_id{tbe_unique_id}.num_l2_cache_evictions" + ) self.l2_cache_free_mem_stats_name: str = ( f"l2_cache.mem.tbe_id{tbe_unique_id}.free_mem_bytes" ) @@ -686,6 +689,7 @@ def __init__( self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name) # pyre-ignore self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name) + self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name) self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name) self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name) @@ -1744,13 +1748,15 @@ def _report_ssd_io_stats(self) -> None: self.step, self.stats_reporter.report_interval # pyre-ignore ) - if len(ssd_io_duration) != 3: - logging.error("ssd io duration should have 3 elements") + if len(ssd_io_duration) != 5: + logging.error("ssd io duration should have 5 elements") return ssd_read_dur_us = ssd_io_duration[0] - fwd_ssd_write_dur_us = ssd_io_duration[1] - bwd_ssd_write_dur_us = ssd_io_duration[2] + fwd_rocksdb_read_dur = ssd_io_duration[1] + fwd_l1_eviction_dur = ssd_io_duration[2] + bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3] + flush_write_dur = ssd_io_duration[4] # pyre-ignore self.stats_reporter.report_duration( @@ -1762,15 +1768,29 @@ def _report_ssd_io_stats(self) -> None: # pyre-ignore self.stats_reporter.report_duration( iteration_step=self.step, - event_name="ssd.io_duration.fwd_write_us", - duration_ms=fwd_ssd_write_dur_us, + event_name="ssd.io_duration.write.fwd_rocksdb_read_us", + duration_ms=fwd_rocksdb_read_dur, + time_unit="us", + ) + # pyre-ignore + self.stats_reporter.report_duration( + iteration_step=self.step, + event_name="ssd.io_duration.write.fwd_l1_eviction_us", + duration_ms=fwd_l1_eviction_dur, + time_unit="us", + ) + # pyre-ignore + self.stats_reporter.report_duration( + iteration_step=self.step, + event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us", + duration_ms=bwd_l1_cnflct_miss_write_back_dur, time_unit="us", ) # pyre-ignore self.stats_reporter.report_duration( iteration_step=self.step, - event_name="ssd.io_duration.bwd_write_us", - duration_ms=bwd_ssd_write_dur_us, + event_name="ssd.io_duration.write.flush_write_us", + duration_ms=flush_write_dur, time_unit="us", ) @@ -1830,8 +1850,8 @@ def _report_l2_cache_perf_stats(self) -> None: self.step, stats_reporter.report_interval # pyre-ignore ) - if len(l2_cache_perf_stats) != 11: - logging.error("l2 perf stats should have 11 elements") + if len(l2_cache_perf_stats) != 13: + logging.error("l2 perf stats should have 13 elements") return num_cache_misses = l2_cache_perf_stats[0] @@ -1840,12 +1860,14 @@ def _report_l2_cache_perf_stats(self) -> None: get_cache_lookup_total_duration = l2_cache_perf_stats[3] get_cache_lookup_wait_filling_thread_duration = l2_cache_perf_stats[4] get_weights_fillup_total_duration = l2_cache_perf_stats[5] - total_cache_update_duration = l2_cache_perf_stats[6] - get_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[7] - set_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[8] + get_cache_memcpy_duration = l2_cache_perf_stats[6] + total_cache_update_duration = l2_cache_perf_stats[7] + get_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[8] + set_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[9] + num_l2_evictions = l2_cache_perf_stats[10] - l2_cache_free_bytes = l2_cache_perf_stats[9] - l2_cache_capacity = l2_cache_perf_stats[10] + l2_cache_free_bytes = l2_cache_perf_stats[11] + l2_cache_capacity = l2_cache_perf_stats[12] stats_reporter.report_data_amount( iteration_step=self.step, @@ -1857,6 +1879,11 @@ def _report_l2_cache_perf_stats(self) -> None: event_name=self.l2_num_cache_lookups_stats_name, data_bytes=num_lookups, ) + stats_reporter.report_data_amount( + iteration_step=self.step, + event_name=self.l2_num_cache_evictions_stats_name, + data_bytes=num_l2_evictions, + ) stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.l2_cache_capacity_stats_name, @@ -1892,6 +1919,12 @@ def _report_l2_cache_perf_stats(self) -> None: duration_ms=get_weights_fillup_total_duration, time_unit="us", ) + stats_reporter.report_duration( + iteration_step=self.step, + event_name="l2_cache.perf.get.cache_memcpy_duration_us", + duration_ms=get_cache_memcpy_duration, + time_unit="us", + ) stats_reporter.report_duration( iteration_step=self.step, event_name="l2_cache.perf.total.cache_update_duration_us", diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h index 7dc8efe0c..0f3bacbb8 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h @@ -128,8 +128,9 @@ class CacheLibCache { const at::Tensor& count); /// reset slot pointer that points to the next available slot in the eviction - /// tensors - void reset_eviction_states(); + /// tensors and returns number of slots filled + /// @return number evictions + int64_t reset_eviction_states(); /// get the filled indices and weights tensors from L2 eviction, could be all /// invalid if no eviction happened diff --git a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h index 897672943..0498bda96 100644 --- a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h @@ -48,7 +48,8 @@ class EmbeddingParameterServer : public kv_db::EmbeddingKVDB { const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - const bool is_bwd = false) override { + const kv_db::RocksdbWriteMode w_mode = + kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ) override { RECORD_USER_SCOPE("EmbeddingParameterServer::set"); co_await tps_client_->set(indices, weights, count.item().toLong()); } diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp index 4b3bee846..f12b6e947 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp @@ -209,8 +209,10 @@ void CacheLibCache::init_tensor_for_l2_eviction( at::TensorOptions().device(weights.device()).dtype(weights.dtype()))); } -void CacheLibCache::reset_eviction_states() { - eviction_row_id = 0; +int64_t CacheLibCache::reset_eviction_states() { + int64_t reset_val = 0; + auto num_eviction = eviction_row_id.exchange(reset_val); + return num_eviction; } folly::Optional> diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 3aa630744..5d12dbbc6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -32,7 +32,8 @@ inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) { QueueItem tensor_copy( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) { + const at::Tensor& count, + kv_db::RocksdbWriteMode mode) { auto num_sets = get_maybe_uvm_scalar(count); auto new_indices = at::empty( num_sets, at::TensorOptions().device(at::kCPU).dtype(indices.dtype())); @@ -58,7 +59,7 @@ QueueItem tensor_copy( new_weightss_addr); // dst_start }); *new_count.data_ptr() = num_sets; - return QueueItem{new_indices, new_weights, new_count}; + return QueueItem{new_indices, new_weights, new_count, mode}; } EmbeddingKVDB::EmbeddingKVDB( @@ -97,6 +98,7 @@ EmbeddingKVDB::EmbeddingKVDB( auto& indices = filling_item_ptr->indices; auto& weights = filling_item_ptr->weights; auto& count = filling_item_ptr->count; + auto& rocksdb_wmode = filling_item_ptr->mode; if (l2_cache_) { auto evicted_pairs_opt = set_cache(indices, weights, count); @@ -104,11 +106,12 @@ EmbeddingKVDB::EmbeddingKVDB( auto& evicted_indices = evicted_pairs_opt->first; auto& evicted_weights = evicted_pairs_opt->second; - folly::coro::blockingWait( - set_kv_db_async(evicted_indices, evicted_weights, count)); + folly::coro::blockingWait(set_kv_db_async( + evicted_indices, evicted_weights, count, rocksdb_wmode)); } } else { - folly::coro::blockingWait(set_kv_db_async(indices, weights, count)); + folly::coro::blockingWait( + set_kv_db_async(indices, weights, count, rocksdb_wmode)); } weights_to_fill_queue_.dequeue(); @@ -128,7 +131,8 @@ void EmbeddingKVDB::flush() { auto& indices = std::get<0>(tensor_tuple); auto& weights = std::get<1>(tensor_tuple); auto& count = std::get<2>(tensor_tuple); - folly::coro::blockingWait(set_kv_db_async(indices, weights, count)); + folly::coro::blockingWait(set_kv_db_async( + indices, weights, count, kv_db::RocksdbWriteMode::FLUSH)); } } @@ -175,11 +179,12 @@ void EmbeddingKVDB::set_cuda( std::vector EmbeddingKVDB::get_l2cache_perf( const int64_t step, const int64_t interval) { - std::vector ret(11, 0); // num metrics + std::vector ret(13, 0); // num metrics if (step > 0 && step % interval == 0) { int reset_val = 0; auto num_cache_misses = num_cache_misses_.exchange(reset_val); auto num_lookups = num_lookups_.exchange(reset_val); + auto num_evictions = num_evictions_.exchange(reset_val); auto get_total_duration = get_total_duration_.exchange(reset_val); auto get_cache_lookup_total_duration = get_cache_lookup_total_duration_.exchange(reset_val); @@ -187,6 +192,9 @@ std::vector EmbeddingKVDB::get_l2cache_perf( get_cache_lookup_wait_filling_thread_duration_.exchange(reset_val); auto get_weights_fillup_total_duration = get_weights_fillup_total_duration_.exchange(reset_val); + auto get_cache_memcpy_duration = + get_cache_memcpy_duration_.exchange(reset_val); + auto total_cache_update_duration = total_cache_update_duration_.exchange(reset_val); auto get_tensor_copy_for_cache_update_dur = @@ -199,13 +207,15 @@ std::vector EmbeddingKVDB::get_l2cache_perf( ret[3] = (double(get_cache_lookup_total_duration) / interval); ret[4] = (double(get_cache_lookup_wait_filling_thread_duration) / interval); ret[5] = (double(get_weights_fillup_total_duration) / interval); - ret[6] = (double(total_cache_update_duration) / interval); - ret[7] = (double(get_tensor_copy_for_cache_update_dur) / interval); - ret[8] = (double(set_tensor_copy_for_cache_update_dur) / interval); + ret[6] = (double(get_cache_memcpy_duration) / interval); + ret[7] = (double(total_cache_update_duration) / interval); + ret[8] = (double(get_tensor_copy_for_cache_update_dur) / interval); + ret[9] = (double(set_tensor_copy_for_cache_update_dur) / interval); + ret[10] = (double(num_evictions) / interval); if (l2_cache_) { auto cache_mem_stats = l2_cache_->get_cache_usage(); - ret[9] = (cache_mem_stats[0]); // free cache in bytes - ret[10] = (cache_mem_stats[1]); // total cache capacity in bytes + ret[11] = (cache_mem_stats[0]); // free cache in bytes + ret[12] = (cache_mem_stats[1]); // total cache capacity in bytes } } return ret; @@ -228,7 +238,10 @@ void EmbeddingKVDB::set( // parallelized with other cuda kernels, as long as all updates are finished // before the next L2 cache lookup auto tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto new_item = tensor_copy(indices, weights, count); + kv_db::RocksdbWriteMode write_mode = is_bwd + ? kv_db::RocksdbWriteMode::BWD_L1_CNFLCT_MISS_WRITE_BACK + : kv_db::RocksdbWriteMode::FWD_L1_EVICTION; + auto new_item = tensor_copy(indices, weights, count, write_mode); weights_to_fill_queue_.enqueue(new_item); set_tensor_copy_for_cache_update_ += facebook::WallClockUtil::NowInUsecFast() - tensor_copy_start_ts; @@ -264,7 +277,8 @@ void EmbeddingKVDB::get( // be parallelized with other cuda kernels, as long as all updates are // finished before the next L2 cache lookup auto tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto new_item = tensor_copy(indices, weights, count); + auto new_item = tensor_copy( + indices, weights, count, kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ); weights_to_fill_queue_.enqueue(new_item); get_tensor_copy_for_cache_update_ += facebook::WallClockUtil::NowInUsecFast() - tensor_copy_start_ts; @@ -420,7 +434,8 @@ folly::Optional> EmbeddingKVDB::set_cache( .scheduleOn(executor_tp_.get())); } folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); - l2_cache_->reset_eviction_states(); + auto num_evictions = l2_cache_->reset_eviction_states(); + num_evictions_ += num_evictions; total_cache_update_duration_ += facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts; return l2_cache_->get_evicted_indices_and_weights(); @@ -429,6 +444,7 @@ folly::Optional> EmbeddingKVDB::set_cache( folly::coro::Task EmbeddingKVDB::cache_memcpy( const at::Tensor& weights, const std::vector& cached_addr_list) { + auto cache_memcpy_start_ts = facebook::WallClockUtil::NowInUsecFast(); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "cache_memcpy", [&] { auto weights_data_ptr = weights.data_ptr(); @@ -443,6 +459,8 @@ folly::coro::Task EmbeddingKVDB::cache_memcpy( &weights_data_ptr[row_id * max_D_]); // dst_start } }); + get_cache_memcpy_duration_ += + facebook::WallClockUtil::NowInUsecFast() - cache_memcpy_start_ts; co_return; } diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index a77d1c2b7..dfbd6c230 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -62,6 +62,30 @@ class CacheContext { std::vector cached_addr_list; }; +/// @ingroup embedding-ssd +/// +/// @brief rocksdb write mode +/// +/// In SSD offloading there are 3 writes in each train iteration +/// FWD_ROCKSDB_READ: cache lookup will move uncached data from rocksdb into L2 +/// cache on fwd path +/// +/// FWD_L1_EVICTION: L1 cache eviciton will evict data into L2 cache on fwd path +/// +/// BWD_L1_CNFLCT_MISS_WRITE_BACK: L1 conflict miss will insert into L2 for +/// embedding update on bwd path +/// +/// All the L2 cache filling above will potentially trigger rocksdb write once +/// L2 cache is full +/// +/// Additionally we will do ssd io on L2 flush +enum RocksdbWriteMode { + FWD_ROCKSDB_READ = 0, + FWD_L1_EVICTION = 1, + BWD_L1_CNFLCT_MISS_WRITE_BACK = 2, + FLUSH = 3, +}; + /// @ingroup embedding-ssd /// /// @brief queue item for background L2/rocksdb update @@ -72,25 +96,22 @@ class CacheContext { /// later used on updating cachelib LRU queue as we separate it from /// EmbeddingKVDB::get_cache() /// -/// mode is used for monitoring rocksdb write as there are 3 writes in each -/// train iteration, -/// - cache lookup will move uncached data from rocksdb into L2 cache on fwd -/// path -/// - L1 cache eviciton will evict data into L2 cache on fwd path -/// - L1 conflict miss will insert into L2 on bwd path -/// those L2 cache filling will potentially trigger rocksdb write once L2 cache -/// is full +/// mode is used for monitoring rocksdb write, checkout RocksdbWriteMode for +/// detailed explanation struct QueueItem { at::Tensor indices; at::Tensor weights; at::Tensor count; + RocksdbWriteMode mode; QueueItem( at::Tensor src_indices, at::Tensor src_weights, - at::Tensor src_count) { + at::Tensor src_count, + RocksdbWriteMode src_mode) { indices = src_indices; weights = src_weights; count = src_count; + mode = src_mode; } }; @@ -169,7 +190,7 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - const bool is_bwd = false) = 0; + const RocksdbWriteMode w_mode = RocksdbWriteMode::FWD_ROCKSDB_READ) = 0; virtual void compact() = 0; @@ -279,10 +300,12 @@ class EmbeddingKVDB : public std::enable_shared_from_this { // instead of SUM(cache miss per interval) / SUM(lookups per interval) std::atomic num_cache_misses_{0}; std::atomic num_lookups_{0}; + std::atomic num_evictions_{0}; std::atomic get_total_duration_{0}; std::atomic get_cache_lookup_total_duration_{0}; std::atomic get_cache_lookup_wait_filling_thread_duration_{0}; std::atomic get_weights_fillup_total_duration_{0}; + std::atomic get_cache_memcpy_duration_{0}; std::atomic get_tensor_copy_for_cache_update_{0}; // -- perf of set() function diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 096745314..77588b5b9 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -345,7 +345,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - const bool is_bwd = false) override { + const kv_db::RocksdbWriteMode w_mode = + kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ) override { RECORD_USER_SCOPE("EmbeddingRocksDB::set"); #ifdef FBGEMM_FBCODE auto start_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -398,10 +399,19 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { co_await folly::coro::collectAllRange(std::move(tasks)); #ifdef FBGEMM_FBCODE auto duration = facebook::WallClockUtil::NowInUsecFast() - start_ts; - if (is_bwd) { - bwd_write_total_duration_ += duration; - } else { - fwd_write_total_duration_ += duration; + switch (w_mode) { + case kv_db::RocksdbWriteMode::BWD_L1_CNFLCT_MISS_WRITE_BACK: + bwd_l1_cnflct_miss_write_back_dur_ += duration; + break; + case kv_db::RocksdbWriteMode::FWD_L1_EVICTION: + fwd_l1_eviction_dur_ += duration; + break; + case kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ: + fwd_rocksdb_read_dur_ += duration; + break; + case kv_db::RocksdbWriteMode::FLUSH: + flush_write_dur_ += duration; + break; } #endif } @@ -560,17 +570,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { const int64_t step, const int64_t interval) { std::vector ret; - ret.reserve(3); + ret.reserve(5); if (step > 0 && step % interval == 0) { - auto read_dur = read_total_duration_.load(); - auto fwd_write_dur = fwd_write_total_duration_.load(); - auto bwd_write_dur = bwd_write_total_duration_.load(); + int64_t reset_val = 0; + auto read_dur = read_total_duration_.exchange(reset_val); + + auto fwd_rocksdb_read_dur = fwd_rocksdb_read_dur_.exchange(reset_val); + auto fwd_l1_eviction_dur = fwd_l1_eviction_dur_.exchange(reset_val); + auto bwd_l1_cnflct_miss_write_back_dur = + bwd_l1_cnflct_miss_write_back_dur_.exchange(reset_val); + auto flush_write_dur = flush_write_dur_.exchange(reset_val); + ret.push_back(double(read_dur) / interval); - ret.push_back(double(fwd_write_dur) / interval); - ret.push_back(double(bwd_write_dur) / interval); - read_total_duration_ = 0; - fwd_write_total_duration_ = 0; - bwd_write_total_duration_ = 0; + ret.push_back(double(fwd_rocksdb_read_dur) / interval); + ret.push_back(double(fwd_l1_eviction_dur) / interval); + ret.push_back(double(bwd_l1_cnflct_miss_write_back_dur) / interval); + ret.push_back(double(flush_write_dur) / interval); } return ret; } @@ -650,9 +665,13 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { int64_t memtable_flush_period_; int64_t compaction_period_; int64_t l0_files_per_compact_; + + // break down on rocksdb write duration for details checkout RocksdbWriteMode std::atomic read_total_duration_{0}; - std::atomic fwd_write_total_duration_{0}; - std::atomic bwd_write_total_duration_{0}; + std::atomic fwd_rocksdb_read_dur_{0}; + std::atomic fwd_l1_eviction_dur_{0}; + std::atomic bwd_l1_cnflct_miss_write_back_dur_{0}; + std::atomic flush_write_dur_{0}; }; // class EmbeddingKVDB } // namespace ssd From 48dad549a998c0deb8fe6970d9f214c8e1c29c5b Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Fri, 13 Sep 2024 14:18:08 -0700 Subject: [PATCH 08/27] parallelize cache memcpy (#3121) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3121 X-link: https://github.com/facebookresearch/FBGEMM/pull/209 as title Reviewed By: q10 Differential Revision: D62223021 fbshipit-source-id: 7f642f8e244204754fdc213706ed21ad454bb3c1 --- .../kv_db_table_batched_embeddings.cpp | 42 ++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 5d12dbbc6..020fb6672 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -43,7 +43,7 @@ QueueItem tensor_copy( auto new_count = at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( - weights.scalar_type(), "cache_memcpy", [&] { + weights.scalar_type(), "tensor_copy", [&] { auto indices_addr = indices.data_ptr(); auto new_indices_addr = new_indices.data_ptr(); std::copy( @@ -447,17 +447,39 @@ folly::coro::Task EmbeddingKVDB::cache_memcpy( auto cache_memcpy_start_ts = facebook::WallClockUtil::NowInUsecFast(); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "cache_memcpy", [&] { + std::vector> tasks; auto weights_data_ptr = weights.data_ptr(); - for (int row_id = 0; row_id < cached_addr_list.size(); row_id++) { - if (cached_addr_list[row_id] == nullptr) { - continue; - } - std::copy( - reinterpret_cast(cached_addr_list[row_id]), - reinterpret_cast(cached_addr_list[row_id]) + - max_D_, - &weights_data_ptr[row_id * max_D_]); // dst_start + auto num_shards = executor_tp_->numThreads(); + for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + tasks.emplace_back( + folly::coro::co_invoke( + [this, + &cached_addr_list, + &weights_data_ptr, + num_shards, + shard_id]() mutable -> folly::coro::Task { + for (int row_id = 0; row_id < cached_addr_list.size(); + row_id++) { + if (row_id % num_shards != shard_id) { + continue; + } + if (cached_addr_list[row_id] == nullptr) { + continue; + } + std::copy( + reinterpret_cast( + cached_addr_list[row_id]), + reinterpret_cast( + cached_addr_list[row_id]) + + max_D_, + &weights_data_ptr[row_id * max_D_]); // dst_start + } + co_return; + }) + .scheduleOn(executor_tp_.get())); } + folly::coro::blockingWait( + folly::coro::collectAllRange(std::move(tasks))); }); get_cache_memcpy_duration_ += facebook::WallClockUtil::NowInUsecFast() - cache_memcpy_start_ts; From d699a454452bff63eaf97f8fc108d0e9bb0bbf62 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Fri, 13 Sep 2024 14:18:08 -0700 Subject: [PATCH 09/27] don't update rocksdb on empty eviction (#3120) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3120 X-link: https://github.com/facebookresearch/FBGEMM/pull/210 before this diff, we always allocate a new space in cachelib even if the key already exists after this diff, we will not allocate a new space but using the existing key space to overwrite weights Reviewed By: q10 Differential Revision: D62584593 fbshipit-source-id: 6e7f3c180aa22e0185a4b8b2cfeea4476a1ec2a2 --- .../split_embeddings_cache/cachelib_cache.h | 11 ++--- .../split_embeddings_cache/cachelib_cache.cpp | 46 +++++++++++-------- .../kv_db_table_batched_embeddings.cpp | 24 ++++++---- .../kv_db_table_batched_embeddings.h | 8 ++-- 4 files changed, 51 insertions(+), 38 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h index 0f3bacbb8..bb6edff00 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h @@ -129,13 +129,12 @@ class CacheLibCache { /// reset slot pointer that points to the next available slot in the eviction /// tensors and returns number of slots filled - /// @return number evictions - int64_t reset_eviction_states(); + void reset_eviction_states(); /// get the filled indices and weights tensors from L2 eviction, could be all /// invalid if no eviction happened - folly::Optional> - get_evicted_indices_and_weights(); + folly::Optional> + get_tensors_and_reset(); /// get L2 cache utilization stats std::vector get_cache_usage(); @@ -146,8 +145,8 @@ class CacheLibCache { std::vector pool_ids_; std::unique_ptr admin_; - std::shared_ptr evicted_indices_ptr_{nullptr}; - std::shared_ptr evicted_weights_ptr_{nullptr}; + folly::Optional evicted_indices_opt_{folly::none}; + folly::Optional evicted_weights_opt_{folly::none}; std::atomic eviction_row_id{0}; }; diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp index f12b6e947..d7eb220d8 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp @@ -53,15 +53,15 @@ std::unique_ptr CacheLibCache::initializeCacheLib( auto eviction_cb = [this](const facebook::cachelib::LruAllocator::RemoveCbData& data) { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( - evicted_weights_ptr_->scalar_type(), "l2_eviction_handling", [&] { + evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] { if (data.context == facebook::cachelib::RemoveContext::kEviction) { auto indices_data_ptr = - evicted_indices_ptr_->data_ptr(); + evicted_indices_opt_->data_ptr(); auto weights_data_ptr = - evicted_weights_ptr_->data_ptr(); + evicted_weights_opt_->data_ptr(); auto row_id = eviction_row_id++; - auto weight_dim = evicted_weights_ptr_->size(1); + auto weight_dim = evicted_weights_opt_->size(1); const auto key_ptr = reinterpret_cast(data.item.getKey().data()); indices_data_ptr[row_id] = *key_ptr; @@ -199,30 +199,38 @@ void CacheLibCache::init_tensor_for_l2_eviction( const at::Tensor& weights, const at::Tensor& count) { auto num_lookups = count.item(); - evicted_indices_ptr_ = std::make_shared( + evicted_indices_opt_ = at::ones( num_lookups, at::TensorOptions().device(indices.device()).dtype(indices.dtype())) * - -1); - evicted_weights_ptr_ = std::make_shared(at::empty( + -1; + evicted_weights_opt_ = at::empty( {num_lookups, weights.size(1)}, - at::TensorOptions().device(weights.device()).dtype(weights.dtype()))); + at::TensorOptions().device(weights.device()).dtype(weights.dtype())); } -int64_t CacheLibCache::reset_eviction_states() { - int64_t reset_val = 0; - auto num_eviction = eviction_row_id.exchange(reset_val); - return num_eviction; +void CacheLibCache::reset_eviction_states() { + evicted_indices_opt_.reset(); + evicted_weights_opt_.reset(); + eviction_row_id = 0; + return; } -folly::Optional> -CacheLibCache::get_evicted_indices_and_weights() { - if (evicted_indices_ptr_) { - assert(evicted_weights_ptr_ != nullptr); - return std::make_pair(*evicted_indices_ptr_, *evicted_weights_ptr_); - } else { - return folly::none; +folly::Optional> +CacheLibCache::get_tensors_and_reset() { + folly::Optional> ret = + folly::none; + if (evicted_indices_opt_.has_value()) { + assert(evicted_weights_opt_.has_value()); + if (eviction_row_id > 0) { + ret = std::make_tuple( + evicted_indices_opt_.value(), + evicted_weights_opt_.value(), + at::tensor(eviction_row_id, evicted_indices_opt_->options())); + } } + reset_eviction_states(); + return ret; } std::vector CacheLibCache::get_cache_usage() { diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 020fb6672..fcf04fff0 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -101,13 +101,14 @@ EmbeddingKVDB::EmbeddingKVDB( auto& rocksdb_wmode = filling_item_ptr->mode; if (l2_cache_) { - auto evicted_pairs_opt = set_cache(indices, weights, count); - if (evicted_pairs_opt.has_value()) { - auto& evicted_indices = evicted_pairs_opt->first; - auto& evicted_weights = evicted_pairs_opt->second; + auto evicted_tuples_opt = set_cache(indices, weights, count); + if (evicted_tuples_opt.has_value()) { + auto& evicted_indices = std::get<0>(evicted_tuples_opt.value()); + auto& evicted_weights = std::get<1>(evicted_tuples_opt.value()); + auto& evicted_count = std::get<2>(evicted_tuples_opt.value()); folly::coro::blockingWait(set_kv_db_async( - evicted_indices, evicted_weights, count, rocksdb_wmode)); + evicted_indices, evicted_weights, evicted_count, rocksdb_wmode)); } } else { folly::coro::blockingWait( @@ -380,7 +381,8 @@ void EmbeddingKVDB::wait_util_filling_work_done() { facebook::WallClockUtil::NowInUsecFast() - start_ts; } -folly::Optional> EmbeddingKVDB::set_cache( +folly::Optional> +EmbeddingKVDB::set_cache( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count) { @@ -434,11 +436,15 @@ folly::Optional> EmbeddingKVDB::set_cache( .scheduleOn(executor_tp_.get())); } folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); - auto num_evictions = l2_cache_->reset_eviction_states(); - num_evictions_ += num_evictions; total_cache_update_duration_ += facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts; - return l2_cache_->get_evicted_indices_and_weights(); + auto tensor_tuple_opt = l2_cache_->get_tensors_and_reset(); + if (tensor_tuple_opt.has_value()) { + auto& num_evictions_tensor = std::get<2>(tensor_tuple_opt.value()); + auto num_evictions = get_maybe_uvm_scalar(num_evictions_tensor); + num_evictions_ += num_evictions; + } + return tensor_tuple_opt; } folly::coro::Task EmbeddingKVDB::cache_memcpy( diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index dfbd6c230..d63ff5418 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -254,10 +254,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this { /// @param count A single element tensor that contains the number of indices /// to be processed /// - /// @return None if L2 is missing, other wise return pair of tensors with - /// length of containing L2 evicted embedding indices and embeddings, - /// invalid pairs will have sentinel value(-1) on - folly::Optional> set_cache( + /// @return None if L2 is missing or no eviction, other wise return tuple of + /// tensors with length of containing L2 evicted embedding indices and + /// embeddings, invalid pairs will have sentinel value(-1) on + folly::Optional> set_cache( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count); From a90aac18feb73baf9835d8919d64762e3f418250 Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Fri, 13 Sep 2024 15:26:57 -0700 Subject: [PATCH 10/27] Increase test time out for test workflow OSS CI (#3134) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3134 Failures on OSS CI test workflow are due to time-out during unit testing. https://github.com/pytorch/FBGEMM/actions/runs/10851972782 This diff increases time-out so all tests can be run. Reviewed By: q10, sryap Differential Revision: D62649448 fbshipit-source-id: 5911b2dca3351eab60cf32eb008b433d2b47236f --- .github/workflows/fbgemm_gpu_pip.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemm_gpu_pip.yml b/.github/workflows/fbgemm_gpu_pip.yml index 20cac9dc9..33125145f 100644 --- a/.github/workflows/fbgemm_gpu_pip.yml +++ b/.github/workflows/fbgemm_gpu_pip.yml @@ -160,7 +160,7 @@ jobs: run: . $PRELUDE; install_fbgemm_gpu_pip $BUILD_ENV ${{ github.event.inputs.fbgemm_gpu_channel_version || 'nightly' }} cuda/${{ matrix.cuda-version }} - name: Test with PyTest - timeout-minutes: 20 + timeout-minutes: 40 run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV From 49fa9a55ba97a0d655cf2c51bf595caab37638e1 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Mon, 16 Sep 2024 02:12:40 -0700 Subject: [PATCH 11/27] Call cudaGetDeviceProperties once in masked_index_impl (#3136) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3136 X-link: https://github.com/facebookresearch/FBGEMM/pull/229 This diff caches the default number of SMs for pipeline prefetching to avoid calling `cudaGetDeviceProperties` in every `masked_index_*` call Reviewed By: chrisxcai Differential Revision: D62672190 fbshipit-source-id: a5865f602f394ba4909ab1334661df4689b8cfd5 --- .../ssd_split_embeddings_cache_cuda.cu | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index d6d926a96..d371c2845 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -22,12 +22,27 @@ #include "fbgemm_gpu/utils/tensor_utils.h" #include "fbgemm_gpu/utils/vec4.cuh" -constexpr int ALL_TO_PREFETCH_SM_RATIO = 8; - using Tensor = at::Tensor; using namespace fbgemm_gpu; +int get_masked_index_default_pipeline_sms(int device) { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device); + + // The default number of SMs for use_pipeline=true is set based on an + // empirical study + if (prop.major == 8) { + // Assume A100 + return 4; + } else if (prop.major == 9) { + // Assume H100 + return 16; + } + constexpr int ALL_TO_PREFETCH_SM_RATIO = 8; + return div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO); +} + template DEVICE_INLINE void vec4_copy(scalar_t* dst, const scalar_t* src, const int32_t D) { @@ -103,26 +118,11 @@ Tensor masked_index_impl( const auto full_grid_size = div_round_up(N, kMaxThreads / tx); - // The default number of SMs for use_pipeline=true is set based on an - // empirical study - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, at::cuda::current_device()); - - int DEFAULT_PIPELINE_SMS; - if (prop.major == 8) { - // Assume A100 - DEFAULT_PIPELINE_SMS = 4; - } else if (prop.major == 9) { - // Assume H100 - DEFAULT_PIPELINE_SMS = 16; - } else { - DEFAULT_PIPELINE_SMS = - div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO); - } + static int masked_index_default_pipeline_sms = + get_masked_index_default_pipeline_sms(at::cuda::current_device()); const int pipeline_grid_size = - preferred_sms == -1 ? DEFAULT_PIPELINE_SMS : preferred_sms; + preferred_sms == -1 ? masked_index_default_pipeline_sms : preferred_sms; TORCH_CHECK( !use_pipeline || pipeline_grid_size >= 1, "preferred_sms must >= 1"); From 9634774cfe3526e785f050976fcfd6f2b05d0632 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 16 Sep 2024 13:17:48 -0700 Subject: [PATCH 12/27] EMU1.6 CK FP8 Tuning (#3138) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3138 X-link: https://github.com/facebookresearch/FBGEMM/pull/232 This diff adds a few more tuned shapes to our CK FP8 Kernel dispatch for better EMU1.6 performance. The before / after performance can be seen [here](https://docs.google.com/spreadsheets/d/1SAymyghA8V0ZXD1G7ButMy7GzohYogPISlQ6rHVtVdE/edit?usp=sharing). The quick summary is that we see small to medium improvements across all EMU shapes. E2E performance should improve a bit but not massively as it seems the heuristics did acceptably in this case. Reviewed By: mxz297 Differential Revision: D62761000 fbshipit-source-id: 378fabb7cb2e37ffc415b7e464044a266a24ec9c --- .../ck_extensions/fp8_rowwise_gemm.hip | 26 ++++++- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 69 +++++++++++++++++++ ...4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 69 +++++++++++++++++++ .../kernels/fp8_rowwise_kernel_manifest.h | 18 +++++ 4 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index 4b9946872..3450d2cad 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -126,13 +126,33 @@ static const std::unordered_map< fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, // EMU 1.6 Shapes. {{1536, 3584, 3584}, - fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{8192, 9728, 3584}, - fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, {{8192, 3584, 9728}, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5}, {{8192, 3584, 3584}, - fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{4096, 3584, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{768, 3584, 3584}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{4096, 9728, 3584}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{4096, 3584, 9728}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{7200, 3584, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{7200, 9728, 3584}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{7200, 3584, 9728}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{3600, 3584, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{3600, 9728, 3584}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{3600, 3584, 9728}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, // Pro Shapes. {{32768, 128, 8192}, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..acf927c05 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 128 != 0) || (N % 32 != 0) || (K % 128 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip new file mode 100644 index 000000000..b47c082fd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 256 != 0) || (N % 256 != 0) || (K % 64 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h index db38343dc..146450ec1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h @@ -231,3 +231,21 @@ fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y); + +// Decent mid-size kernel. +at::Tensor +fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +// Kernel for small but not too small batch sizes. +at::Tensor +fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); From 98bc1320f2df84fbc39203663ee8337ae7bec0db Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 17 Sep 2024 13:56:48 -0700 Subject: [PATCH 13/27] Better type checks for `pta::PackedTensorAccessor` (#3141) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3141 X-link: https://github.com/facebookresearch/FBGEMM/pull/234 - Add better type checks and diagnostic messages for `pta::PackedTensorAccessor` (extracted from D62794566) Reviewed By: basilwong, spcyppt Differential Revision: D62814065 fbshipit-source-id: b77b314a925dee58f23a699449199916985da05d --- ...ward_quantized_split_nbit_host_template.cu | 2 +- .../fbgemm_gpu/utils/tensor_accessor.h | 74 ++++++++++++++++++- .../transpose_embedding_input.cu | 4 +- 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index e7b908cdd..bc4e7ba74 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -62,7 +62,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no {%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name_{{ emb_weight_type }} = func_name_{{ emb_weight_type }}; + const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}"; #endif #ifdef X diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h index 3f5ed08f2..24ad14125 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -472,6 +473,53 @@ template < using PackedTensorAccessor64 = GenericPackedTensorAccessor; +template +inline at::ScalarType scalar_type_for() { +#define TYPE_CASE(U, name) \ + if constexpr (std::is_same_v) { \ + return at::ScalarType::name; \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_CASE) + +#undef TYPE_CASE + + return at::ScalarType::Undefined; +} + +template +inline void check_scalar_type( + const at::TensorBase& tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + const char* const func_name, + const char* const tensor_name +#endif +) { + const auto expected_type = scalar_type_for(); + + TORCH_CHECK( + tensor.scalar_type() == expected_type || + (isQIntType(tensor.scalar_type()) && + toUnderlying(tensor.scalar_type()) == expected_type), +#ifdef FBGEMM_GPU_MEMCHECK + "[ ", + func_name, + " ]: ", +#endif + "Expected tensor ", +#ifdef FBGEMM_GPU_MEMCHECK + "'", + tensor_name, + "' ", +#endif + "to have scalar type ", + expected_type, + ", but found ", + tensor.scalar_type(), + " instead!"); +} + } // namespace fbgemm_gpu #ifdef FBGEMM_GPU_MEMCHECK @@ -521,10 +569,21 @@ pta::PackedTensorAccessor32 make_packed_tensor_accessor32( #else const at::Tensor& tensor) { #endif + TORCH_CHECK( tensor.numel() <= static_cast(std::numeric_limits::max()), "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); + + fbgemm_gpu::check_scalar_type( + tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + func_name, + ptr_name +#endif + ); + #ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( tensor, ptr_name, func_name); @@ -542,10 +601,23 @@ pta::PackedTensorAccessor64 make_packed_tensor_accessor64( const at::Tensor& tensor, const char* const ptr_name, const char* const func_name) { +#else + const at::Tensor& tensor) { +#endif + + fbgemm_gpu::check_scalar_type( + tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + func_name, + ptr_name +#endif + ); + +#ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( tensor, ptr_name, func_name); #else - const at::Tensor& tensor) { return tensor.packed_accessor64(); #endif } diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index 16d46e2e3..83a06d78a 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -274,10 +274,10 @@ transpose_embedding_input( } AT_DISPATCH_INDEX_TYPES( - infos.scalar_type(), "transpose_embedding_input1", [&] { + infos.scalar_type(), "transpose_embedding_input_1", [&] { using info_t = index_t; AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), "transpose_embedding_input2", [&] { + indices.scalar_type(), "transpose_embedding_input_2", [&] { if (!is_index_select) { if (!nobag) { INVOKE_LINEARIZE_INDEX_KERNEL(int32_t, false); From fabf5e96fc7542a4e045a4cd05db72c517940389 Mon Sep 17 00:00:00 2001 From: Quinn Zhu Date: Tue, 17 Sep 2024 14:43:09 -0700 Subject: [PATCH 14/27] Enforce same dtype for indices and offsets in cpu ops (#3132) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/224 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3132 Convert offsets.dtype to indices.dtype in cpu ops Reviewed By: houseroad Differential Revision: D62622955 fbshipit-source-id: f32edd4a25eb414651c1a1e1570fac4d7f69a4ce --- .../codegen/inference/embedding_forward_quantized_host_cpu.cpp | 3 +++ fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp | 3 +++ 2 files changed, 6 insertions(+) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index e799120a6..41fd137dd 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -103,6 +103,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( std::optional max_float8_D, std::optional fp8_exponent_bits, std::optional fp8_exponent_bias) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } if (static_cast(pooling_mode) == PoolingMode::NONE) { std::vector max_D_list{ max_int2_D, diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 85d23cc94..1098378d0 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -49,6 +49,9 @@ void bounds_check_indices_cpu( const std::optional& weights, const std::optional& B_offsets, const int64_t max_B) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } const auto vbe = B_offsets.has_value(); if (vbe) { TENSOR_NDIM_EQUALS(B_offsets.value(), 1); From e27c5e166f69ac138e839f7f62b189a731923da1 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Tue, 17 Sep 2024 18:41:08 -0700 Subject: [PATCH 15/27] Improve one/two_shot_all_reduce (#3139) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3139 X-link: https://github.com/facebookresearch/FBGEMM/pull/231 - Adjust one/two_shot_all_reduce for better performance on AMD/NVIDIA GPUs - Improve all_reduce time for 70B decoding, BS96 - AMD MI300X: 33.14us to 27.00us - NVIDIA H100: 22.32us to 20.83us Reviewed By: jianyuh, xw285cornell Differential Revision: D62753553 fbshipit-source-id: d3556bdad3f129936a843bc6ebb11810ab8fc7bc --- .../experimental/gen_ai/src/comm/car.cu | 161 ++++++++++++------ 1 file changed, 113 insertions(+), 48 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu index f0c689570..4712f620f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu @@ -67,17 +67,26 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) { return c; } -template -__global__ void one_shot_all_reduce( - int32_t rank, - int32_t world_size, - int32_t flag, - std::array barriers, - std::array inputs, +template +#if defined(USE_ROCM) +__launch_bounds__(512) +#endif + __global__ void one_shot_all_reduce( + int32_t rank, + int32_t world_size, + int32_t flag, + std::array barriers, + std::array inputs, +#if defined(USE_ROCM) + at::BFloat16* __restrict__ ar_input, + at::BFloat16* __restrict__ acc, + at::BFloat16* __restrict__ output, +#else at::BFloat16* ar_input, at::BFloat16* acc, at::BFloat16* output, - int32_t N) { +#endif + int32_t N) { // It is expensive to launch hipMemcpyAsync on ROCm // Move data copy here. Each block copies part of input data at::BFloat16* input = inputs[rank]; @@ -143,11 +152,11 @@ __global__ void one_shot_all_reduce( // Sum the values from the different ranks. bf16x8 sums; - if (acc) { + if constexpr (has_acc) { *reinterpret_cast(&sums) = *reinterpret_cast(&acc[i]); } else { - memset(reinterpret_cast(&sums), 0, sizeof(sums)); + *reinterpret_cast(&sums) = uint4{0}; } #pragma unroll kWorldSize @@ -336,15 +345,24 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) { #endif } -template +template +#if defined(USE_ROCM) +__launch_bounds__(512) __global__ void two_shot_all_reduce( +#else __launch_bounds__(1024) __global__ void two_shot_all_reduce( +#endif int32_t rank, int32_t world_size, int32_t flag, std::array barriers, std::array inputs, +#if defined(USE_ROCM) + at::BFloat16* __restrict__ acc, + at::BFloat16* __restrict__ output, +#else at::BFloat16* acc, at::BFloat16* output, +#endif int32_t N) { int32_t N_per_rank = N / kWorldSize; int32_t N_start = N_per_rank * rank; @@ -374,13 +392,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( __syncthreads(); at::BFloat16* src_d[kWorldSize]; - int dst_rank[kWorldSize]; #pragma unroll kWorldSize for (int ii = 0; ii < kWorldSize; ++ii) { int d_rank = (rank + ii) % kWorldSize; src_d[ii] = inputs[d_rank]; - dst_rank[ii] = d_rank; } // Each block accumulates the values from the different GPUs on the same @@ -395,11 +411,12 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( } bf16x8 sums; - if (acc) { + + if constexpr (has_acc) { *reinterpret_cast(&sums) = *reinterpret_cast(&acc[i + N_start]); } else { - memset(reinterpret_cast(&sums), 0, sizeof(sums)); + *reinterpret_cast(&sums) = uint4{0}; } #pragma unroll kWorldSize @@ -433,11 +450,19 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( // Gather all needed elts from other intra-node ranks for (size_t i = threadIdx.x * 8 + blockIdx.x * blockDim.x * 8; i < N_per_rank; i += gridDim.x * blockDim.x * 8) { + uint4 temp[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int d_rank = (rank + ii) % kWorldSize; + int i_r = N_start + i + (d_rank - rank) * N_per_rank; + temp[ii] = reinterpret_cast(&src_d[ii][i_r])[0]; + } + #pragma unroll kWorldSize for (int ii = 0; ii < kWorldSize; ++ii) { - int i_r = N_start + i + (dst_rank[ii] - rank) * N_per_rank; - *reinterpret_cast(&output[i_r]) = - reinterpret_cast(&src_d[ii][i_r])[0]; + int d_rank = (rank + ii) % kWorldSize; + int i_r = N_start + i + (d_rank - rank) * N_per_rank; + *reinterpret_cast(&output[i_r]) = temp[ii]; } } } @@ -474,7 +499,11 @@ void one_shot_car_allreduce( constexpr int32_t N_per_thread = 8; constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp; TORCH_CHECK(N % N_per_warp == 0); +#if defined(USE_ROCM) + constexpr int32_t kThreadsPerBlock = 512; +#else constexpr int32_t kThreadsPerBlock = 1024; +#endif constexpr int32_t kMaxBlocks = 24; dim3 threads(0, 1, 1); @@ -494,21 +523,37 @@ void one_shot_car_allreduce( threads.x = threads_per_block; } -#define X(kWorldSize) \ - if (state->world_size_ == kWorldSize) { \ - one_shot_all_reduce \ - <<>>( \ - state->rank_, \ - state->world_size_, \ - state->flag_ * state->world_size_, \ - barriers, \ - inputs, \ - y.data_ptr(), \ - z ? z->data_ptr() : nullptr, \ - y_allreduce.data_ptr(), \ - N); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - return; \ +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + if (z) { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + y.data_ptr(), \ + z->data_ptr(), \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } else { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + y.data_ptr(), \ + nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } \ } TORCH_CHECK( @@ -520,7 +565,7 @@ void one_shot_car_allreduce( #undef X return; -} +} // namespace fbgemm_gpu void two_shot_car_allreduce( at::Tensor y_allreduce, @@ -565,26 +610,46 @@ void two_shot_car_allreduce( TORCH_CHECK(N_per_rank % N_per_thread == 0); auto threads_per_rank = N_per_rank / N_per_thread; +#if defined(USE_ROCM) + constexpr int32_t kThreadsPerBlock = 512; +#else constexpr int32_t kThreadsPerBlock = 1024; +#endif + constexpr int32_t kMaxBlocks = 24; auto blocks = std::min( cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks); -#define X(kWorldSize) \ - if (state->world_size_ == kWorldSize) { \ - two_shot_all_reduce \ - <<>>( \ - state->rank_, \ - state->world_size_, \ - state->flag_ * state->world_size_, \ - barriers, \ - inputs, \ - z ? z->data_ptr() : nullptr, \ - y_allreduce.data_ptr(), \ - N); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - return; \ +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + if (z) { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + z->data_ptr(), \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } else { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } \ } TORCH_CHECK( From d0e5fb3315ba88716f4005c07ede416e780d7f68 Mon Sep 17 00:00:00 2001 From: Quinn Zhu Date: Wed, 18 Sep 2024 10:03:14 -0700 Subject: [PATCH 16/27] Add unit test for indices and offsets with mixed precision (#3144) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/237 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3144 Add unit test for the changed in https://github.com/pytorch/FBGEMM/pull/3132 Reviewed By: q10 Differential Revision: D62900038 fbshipit-source-id: 49a025f3a79ac86b33e6fda0049020619322dd44 --- .../inference/nbit_split_embeddings_test.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index ed4245dbd..439797688 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -347,6 +347,62 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( ) torch.testing.assert_close(output_uvm, output_ref, equal_nan=True) + @given( + weights_ty=st.sampled_from( + [ + SparseType.FP32, + SparseType.FP16, + SparseType.INT8, + SparseType.INT4, + SparseType.INT2, + ] + ), + ) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_int_nbit_split_embedding_cpu_mixed_indices_offsets_dtypes( + self, + weights_ty: SparseType, + ) -> None: + T = random.randint(1, 5) + B = random.randint(1, 128) + L = random.randint(1, 20) + D = random.randint(2, 256) + log_E = random.randint(3, 5) + + iters = 4 + E = int(10**log_E) + + D_alignment = ( + 1 if weights_ty.bit_rate() % 8 == 0 else int(8 / weights_ty.bit_rate()) + ) + D = round_up(D, D_alignment) + + Ds = [D] * T + Es = [E] * T + cpu_locations = [EmbeddingLocation.HOST] * T + + cc = IntNBitTableBatchedEmbeddingBagsCodegen( + [("", E, D, weights_ty, M) for (E, D, M) in zip(Es, Ds, cpu_locations)], + device=torch.device("cpu"), + ) + cc.fill_random_weights() + + requests = generate_requests( + iters, B, T, L, min(Es), reuse=0.1, emulate_pruning=False, use_cpu=True + ) + dtypes_combo = [ + (torch.int64, torch.int64), + (torch.int32, torch.int32), + (torch.int32, torch.int64), + (torch.int64, torch.int32), + ] + for i, req in enumerate(requests): + indices, offsets = req.unpack_2() + indices_dtype, offsets_dtype = dtypes_combo[i] + indices = indices.to(indices_dtype) + offsets = offsets.to(offsets_dtype) + _ = cc(indices, offsets) + if __name__ == "__main__": unittest.main() From fd21e0df9a390e658dbc8cfd7743c24553dc912d Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Wed, 18 Sep 2024 10:59:46 -0700 Subject: [PATCH 17/27] Fix docutils version (#3145) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/238 - Fix docutils version so that netlify builds can pass Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3145 Reviewed By: duduyi2013 Differential Revision: D62906803 Pulled By: q10 fbshipit-source-id: 44467a3dc7065fc6690072434acfb2e9459fa49a --- fbgemm_gpu/docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/docs/requirements.txt b/fbgemm_gpu/docs/requirements.txt index 9f3bca439..533232736 100644 --- a/fbgemm_gpu/docs/requirements.txt +++ b/fbgemm_gpu/docs/requirements.txt @@ -14,7 +14,7 @@ sphinx<7 breathe bs4 -docutils +docutils<0.20,>=0.18.1 lxml myst-parser sphinx-lint From d64d1f650e7f9b95d67a4f70fe3aa6355c237b72 Mon Sep 17 00:00:00 2001 From: Sungmin Cho Date: Wed, 18 Sep 2024 15:30:09 -0700 Subject: [PATCH 18/27] support rope with block tables (#3146) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3146 X-link: https://github.com/facebookresearch/FBGEMM/pull/227 Modify `rope_xpos_qkv_varseq_prefill_kernel_` so that it uses page indirection for qparam tensors as well. Reviewed By: sgrigory Differential Revision: D61898380 fbshipit-source-id: fa8b9a3dc42333358ee056bfb793b72bf7d7e450 --- .../gen_ai/src/kv_cache/kv_cache.cu | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 0728e6b9a..51583495a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -795,11 +795,27 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_( } else { __half2* qparam_row = nullptr; auto T = cache_K.size(1); - auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h; - if (qkv == QKV::K) { - qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + if (block_tables == nullptr) { + auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h; + if (qkv == QKV::K) { + qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + } else { + qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } } else { - qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + // This is duplicate computation with get_dst_row above. + // TODO: Maybe clean up and merge later. + int page_logical_idx = cache_loc_t / page_size; + int page_offset = cache_loc_t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + auto idx = physical_t * N_KVH + h; + if (qkv == QKV::K) { + qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + } else { + qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } } quantize_fp8_kv(dst, dst_row_q, qparam_row); } From 396b3111e7fcadbfb413d94f40b0b30140239882 Mon Sep 17 00:00:00 2001 From: Sungmin Cho Date: Wed, 18 Sep 2024 15:30:09 -0700 Subject: [PATCH 19/27] support dequantize with block tables (#3135) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3135 X-link: https://github.com/facebookresearch/FBGEMM/pull/228 Add support for paged KV cache in `dequantize_fp8_cache`. Reviewed By: sgrigory Differential Revision: D61904297 fbshipit-source-id: 1384e4214aea803e0e070898596100c17c8dd583 --- .../gen_ai/src/kv_cache/kv_cache.cpp | 7 +- .../gen_ai/src/kv_cache/kv_cache.cu | 180 ++++++++++++++++-- 2 files changed, 165 insertions(+), 22 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp index 3d5d337ec..38ed3ec6b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp @@ -133,7 +133,9 @@ std::tuple dequantize_fp8_cache( at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v); + std::optional qparam_v, + std::optional block_tables, + int64_t page_size); at::Tensor mqa_attn( at::Tensor XQ, @@ -162,7 +164,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)"); m.impl("dequantize_int4_cache", dequantize_int4_cache); m.def( - "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None) -> (Tensor, Tensor)"); + "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)"); m.impl("dequantize_fp8_cache", dequantize_fp8_cache); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 51583495a..787c0547c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -1493,16 +1493,113 @@ __global__ void dequantize_fp8_cache_kernel( *reinterpret_cast(&kv_dq.vals[2]); } } + +// Cloned from dequantize_fp8_cache_kernel because +// branching inside the original kernel runs into +// "too many resources requested for launch" which +// necessitates decreasing the number of warps per block, +// which might have performance implications. Also we +// might have more diverging behaviors for paged kernel +// as noted in the comment below so we will keep a separate +// kernel for now. +__global__ void dequantize_fp8_cache_kernel_paged( + // This code currently represents FP8 version not int4 + at::PackedTensorAccessor64 + cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G] + at::PackedTensorAccessor32 kv_seqlen, + at::PackedTensorAccessor64 + cache_K_dq, // [1][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V_dq, // [1][MAX_T][N_KVH][D_H] + int32_t* qparam_k_ptr, + int32_t* qparam_v_ptr, + int32_t* block_tables, + int32_t block_tables_b_stride, + int32_t page_size) { + auto N_KVH = cache_K.size(2); + auto MAX_T = cache_K.size(1); + auto D_H = cache_K_dq.size(3); + auto D_H_q = cache_K.size(3); + CUDA_KERNEL_ASSERT(D_H == 128); + + auto b = blockIdx.x; + // only need to dequantize this far. + auto max_t = kv_seqlen[b]; + + // one warp per T/H + for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH; + t_h += blockDim.y * gridDim.y) { + auto h = t_h % N_KVH; + auto t = t_h / N_KVH; + + int page_logical_idx = t / page_size; + int page_offset = t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + + uint8_t* row_k = &cache_K[0][physical_t][h][0]; + uint8_t* row_v = &cache_V[0][physical_t][h][0]; + + bfx8 kv_dq; + uint8_t qparam_offset_bytes; + __half2* qparam_k_src; + __half2* qparam_v_src; + if (qparam_k_ptr) { + // read from standalone qparam tensor + qparam_offset_bytes = 0; + auto idx = physical_t * N_KVH + h; + qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } else { + // read from first row + qparam_offset_bytes = 4; + qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]); + qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]); + } + // Assert the quantized row dim is as expected + CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes); + if (4 * threadIdx.x >= D_H) { + continue; + } + // each thread reads 4 x 8 bits + + uint64_t kq = *reinterpret_cast( + &row_k[threadIdx.x * 4 + qparam_offset_bytes]); + uint64_t vq = *reinterpret_cast( + &row_v[threadIdx.x * 4 + qparam_offset_bytes]); + + uint64_t packed = kq | (vq << 32); + + kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src); + + // now, write our outputs + auto* row_k_dq = &cache_K_dq[0][physical_t][h][0]; + auto* row_v_dq = &cache_V_dq[0][physical_t][h][0]; + // each thread writes 4 elements of type bf16 + *reinterpret_cast(&row_k_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[0]); + *reinterpret_cast(&row_v_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[2]); + } +} std::tuple dequantize_fp8_cache( at::Tensor cache_K, at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v) { + std::optional qparam_v, + std::optional block_tables, + int64_t page_size) { TORCH_CHECK(cache_K.is_cuda()); TORCH_CHECK(cache_V.is_cuda()); TORCH_CHECK(kv_seqlen.is_cuda()); - auto B = cache_K.size(0); + auto B = kv_seqlen.size(0); + // vanilla: B_KV = B, paged: B_KV = 1 + auto B_KV = cache_K.size(0); + // vanilla: MAX_T = MAX_T, paged: MAX_T = MAX_PAGE * PAGE_SIZE auto MAX_T = cache_K.size(1); auto N_KVH = cache_K.size(2); auto D_HQ = cache_K.size(3); @@ -1516,31 +1613,72 @@ std::tuple dequantize_fp8_cache( } auto D_H = (D_HQ - fp8_qparam_offset); - auto cache_K_dq = - at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); - auto cache_V_dq = - at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + // TODO: + // The below allocates Tensors that have the same shape as cache_K and cache_V + // to store their dequantize results. For paged KV cache, this can be a bit + // inefficient because it has the shape of [1 x (MAX_PAGES * PAGE_SIZE) x + // N_KVH x D_H] to accommodate pages globally across batch instances, and + // if we have very large MAX_PAGES then we are essentially allocating a very + // huge Tensor here. The benefit is that the following users of this + // dequantized results can reuse the existing block_tables to access their + // elements. If we want to be more efficient, there are two possible + // approaches: (1) Allocate a shorter Tensor here and store the dequantize + // results in a more compact manner, but that requires creating a new + // block_tables here and making sure the following users all use the + // correct block_tables. (2) From outside, keep a persistent buffer that has a + // matching shape with the original paged KV and feed the same buffer + // into this function at every layer to reuse it and prevent allocation. + auto cache_K_dq = at::empty( + {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + auto cache_V_dq = at::empty( + {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); if (B == 0) { return {cache_K_dq, cache_V_dq}; } + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + constexpr int32_t kMaxBlocks = 256; dim3 blocks(B, std::max(1, kMaxBlocks / B)); dim3 threads(kThreadsPerWarp, kWarpsPerBlock); - dequantize_fp8_cache_kernel<<< - blocks, - threads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - cache_K.packed_accessor64(), - cache_V.packed_accessor64(), - kv_seqlen.packed_accessor32(), - cache_K_dq.packed_accessor64(), - cache_V_dq.packed_accessor64(), - qparam_k_ptr, - qparam_v_ptr); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + if (block_tables_ptr == nullptr) { + dequantize_fp8_cache_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64(), + qparam_k_ptr, + qparam_v_ptr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + dequantize_fp8_cache_kernel_paged<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64(), + qparam_k_ptr, + qparam_v_ptr, + block_tables_ptr, + block_tables_b_stride, + page_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } return {cache_K_dq, cache_V_dq}; } @@ -1622,7 +1760,9 @@ std::tuple dequantize_fp8_cache( at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v) { + std::optional qparam_v, + std::optional block_tables, + int64_t page_size) { throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } From 6e323f60fe239c792f9cab48d5f44c3bfaf43335 Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Wed, 18 Sep 2024 16:23:46 -0700 Subject: [PATCH 20/27] add the capability to read snapshot from rocksdb (#3148) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/242 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3148 * This is mostly cherry picking some related stuff from sryap's D62364335 to unblock. * The main difference is here we introduce a RAII class `SnapshotHandle` to make sure we release the snapshots when the lifetime of the object is over. Furthermore, since one snapshot actually contains N shard level snapshot, we bundle them together in the handle. * The handle will be wrapped in the torch class for python interoperability (see next diff) instead of raw pointers. * This diff only focuses the write path. Read path will be in a separate diff. Reviewed By: duduyi2013 Differential Revision: D62902451 fbshipit-source-id: cb30d805ac3be425fd760784e0180b208790214a --- .../ssd_table_batched_embeddings.h | 355 ++++++++++++------ 1 file changed, 240 insertions(+), 115 deletions(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 77588b5b9..3d1f2c977 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -8,11 +8,14 @@ #pragma once +#include +#include + #include +#include #include #include #include -#include #ifdef FBGEMM_FBCODE #include "common/strings/UUID.h" #include "common/time/Time.h" @@ -126,7 +129,50 @@ class Initializer { /// @brief An implementation of EmbeddingKVDB for RocksDB /// class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { + using snapshot_ptr_t = const rocksdb::Snapshot*; + public: + class SnapshotHandle { + public: + explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { + auto num_shards = db->num_shards(); + CHECK_GT(num_shards, 0); + shard_snapshots_.reserve(num_shards); + for (auto shard = 0; shard < num_shards; ++shard) { + const auto* snapshot = db->dbs_[shard]->GetSnapshot(); + CHECK(snapshot != nullptr) + << "ERROR: create_snapshot fails to create a snapshot " + << "for db shard " << shard << ". Please make sure that " + << "inplace_update_support is set to false" << std::endl; + shard_snapshots_.push_back(snapshot); + } + } + + ~SnapshotHandle() { + for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { + snapshot_ptr_t snapshot = shard_snapshots_[shard]; + CHECK(snapshot != nullptr) + << "Unexpected nullptr for snapshot " << shard; + db_->dbs_[shard]->ReleaseSnapshot(snapshot); + } + } + + void release() { + db_->release_snapshot(this); + } + + snapshot_ptr_t get_snapshot_for_shard(size_t shard) const { + CHECK_LE(shard, shard_snapshots_.size()); + return shard_snapshots_[shard]; + } + + private: + friend class EmbeddingRocksDB; + + EmbeddingRocksDB* db_; + std::vector shard_snapshots_; + }; + explicit EmbeddingRocksDB( std::string path, int64_t num_shards, @@ -152,7 +198,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { max_D, l2_cache_size_gb, tbe_unqiue_id, - row_storage_bitwidth / 8) { + row_storage_bitwidth / 8), + max_D_(max_D) { // TODO: lots of tunables. NNI or something for this? rocksdb::Options options; options.create_if_missing = true; @@ -193,7 +240,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { // causing flush set this to true to make update on the existing key // allow_concurrent_memtable_write is toggled in pair with // inplace_update_support - options.inplace_update_support = true; + options.inplace_update_support = false; options.avoid_unnecessary_blocking_io = true; options.use_direct_reads = true; @@ -341,6 +388,30 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { } } + ~EmbeddingRocksDB() override { + // clear all the snapshots if not released + if (snapshots_.size() > 0) { + LOG(WARNING) + << snapshots_.size() + << " snapshots have not been released when db is closing. Releasing them now."; + } + snapshots_.clear(); + for (auto shard = 0; shard < dbs_.size(); ++shard) { + dbs_[shard]->Close(); + } + } + + folly::coro::Task get_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) override { + co_await get_kv_db_async_impl( + indices, + weights, + count, + /*snapshot_handle=*/nullptr); + } + folly::coro::Task set_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, @@ -416,10 +487,162 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { #endif } - folly::coro::Task get_kv_db_async( + bool is_valid_snapshot(const SnapshotHandle* snapshot_handle) const { + return snapshots_.find(snapshot_handle) != snapshots_.end(); + } + + const SnapshotHandle* create_snapshot() { + const auto num_snapshots = snapshots_.size(); + if (num_snapshots > 0) { + std::cerr << "WARNING: create_snapshot found " << num_snapshots + << " other snapshots" << std::endl; + } + + auto handle = std::make_unique(this); + auto handlePtr = handle.get(); + snapshots_[handlePtr] = std::move(handle); + return handlePtr; + } + + void release_snapshot(const SnapshotHandle* snapshot_handle) { + snapshots_.erase(snapshot_handle); + } + + void get_range_from_snapshot( + const at::Tensor& weights, + const int64_t start, + const int64_t length, + const SnapshotHandle* snapshot_handle) { + const auto seq_indices = + at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); + int64_t* count_ = new int64_t[1]; + count_[0] = length; + const auto count = at::from_blob(count_, {1}, at::kLong); + folly::coro::blockingWait( + get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle)); + } + + int64_t get_max_D() { + return max_D_; + } + + // collect mem usage on all db shards, checkout rocks_db_mem_properties + std::vector get_mem_usage() { + int num_mem_component = rocks_db_mem_properties.size(); + std::vector mem_usages(num_mem_component); + for (auto& db : dbs_) { + for (int i = 0; i < num_mem_component; i++) { + std::string property = rocks_db_mem_properties[i]; + std::string val; + db->GetProperty(property, &val); + if (val != "") { + if (i != 0) { + mem_usages[i] += folly::to(val); + } else { + mem_usages[i] = folly::to(val); + } + } + } + } + return mem_usages; + } + + std::vector get_rocksdb_io_duration( + const int64_t step, + const int64_t interval) { + std::vector ret; + ret.reserve(5); + if (step > 0 && step % interval == 0) { + int64_t reset_val = 0; + auto read_dur = read_total_duration_.exchange(reset_val); + + auto fwd_rocksdb_read_dur = fwd_rocksdb_read_dur_.exchange(reset_val); + auto fwd_l1_eviction_dur = fwd_l1_eviction_dur_.exchange(reset_val); + auto bwd_l1_cnflct_miss_write_back_dur = + bwd_l1_cnflct_miss_write_back_dur_.exchange(reset_val); + auto flush_write_dur = flush_write_dur_.exchange(reset_val); + + ret.push_back(double(read_dur) / interval); + ret.push_back(double(fwd_rocksdb_read_dur) / interval); + ret.push_back(double(fwd_l1_eviction_dur) / interval); + ret.push_back(double(bwd_l1_cnflct_miss_write_back_dur) / interval); + ret.push_back(double(flush_write_dur) / interval); + } + return ret; + } + + void compact() override { + for (auto& db : dbs_) { + db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, nullptr); + } + } + + void flush() { + kv_db::EmbeddingKVDB::flush(); + for (auto& db : dbs_) { + db->Flush(rocksdb::FlushOptions()); + } + } + + int64_t num_shards() const { + return dbs_.size(); + } + + private: + void flush_or_compact(const int64_t timestep) override { + // Only do manual Flush/Compactions if enabled + if (memtable_flush_period_ > 0) { + { + RECORD_USER_SCOPE("FlushCompactIfNecessary"); + if (!done_staggered_flushes_) { + flush_if_necessary(timestep); + } else { + compact_if_necessary(timestep); + } + } + } + } + + void flush_if_necessary(const int64_t timestep) { + for (int64_t i = 0; i < dbs_.size(); i++) { + if (shard_flush_compaction_deadlines_[i] == timestep) { + rocksdb::FlushOptions fo; + fo.wait = false; + fo.allow_write_stall = false; + dbs_[i]->Flush(fo); + if (i == dbs_.size() - 1) { + done_staggered_flushes_ = true; + int64_t period_per_shard = compaction_period_ / dbs_.size(); + int64_t offset = memtable_flush_offset_ + compaction_period_; + for (int64_t j = 0; j < dbs_.size(); j++) { + shard_flush_compaction_deadlines_[j] = + offset + (j * period_per_shard); + } + } + } + } + } + + void compact_if_necessary(const int64_t timestep) { + for (int64_t i = 0; i < dbs_.size(); i++) { + if (shard_flush_compaction_deadlines_[i] == timestep) { + rocksdb::ColumnFamilyMetaData meta; + dbs_[i]->GetColumnFamilyMetaData(&meta); + int32_t num_level0 = meta.levels[0].files.size(); + if (num_level0 >= l0_files_per_compact_) { + dbs_[i]->CompactRange( + rocksdb::CompactRangeOptions(), nullptr, nullptr); + } + shard_flush_compaction_deadlines_[i] += compaction_period_; + } + } + } + + folly::coro::Task get_kv_db_async_impl( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) override { + const at::Tensor& count, + const SnapshotHandle* snapshot_handle) { RECORD_USER_SCOPE("EmbeddingRocksDB::get"); #ifdef FBGEMM_FBCODE auto start_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -428,9 +651,13 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { auto count_ = count.item().toLong(); for (auto shard = 0; shard < dbs_.size(); ++shard) { + // Get a snapshot for the shard + snapshot_ptr_t snapshot = snapshot_handle == nullptr + ? nullptr + : snapshot_handle->get_snapshot_for_shard(shard); tasks.emplace_back( folly::coro::co_invoke( - [this, &indices, &weights, count_, shard]() mutable + [this, &indices, &weights, count_, shard, snapshot]() mutable -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_get", [&] { @@ -487,6 +714,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { values.resize(keys.size()); statuses.resize(keys.size()); + // Set a snapshot if it is available + ro_.snapshot = snapshot; dbs_[shard]->MultiGet( ro_, keys.size(), @@ -545,114 +774,6 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { #endif } - // collect mem usage on all db shards, checkout rocks_db_mem_properties - std::vector get_mem_usage() { - int num_mem_component = rocks_db_mem_properties.size(); - std::vector mem_usages(num_mem_component); - for (auto& db : dbs_) { - for (int i = 0; i < num_mem_component; i++) { - std::string property = rocks_db_mem_properties[i]; - std::string val; - db->GetProperty(property, &val); - if (val != "") { - if (i != 0) { - mem_usages[i] += folly::to(val); - } else { - mem_usages[i] = folly::to(val); - } - } - } - } - return mem_usages; - } - - std::vector get_rocksdb_io_duration( - const int64_t step, - const int64_t interval) { - std::vector ret; - ret.reserve(5); - if (step > 0 && step % interval == 0) { - int64_t reset_val = 0; - auto read_dur = read_total_duration_.exchange(reset_val); - - auto fwd_rocksdb_read_dur = fwd_rocksdb_read_dur_.exchange(reset_val); - auto fwd_l1_eviction_dur = fwd_l1_eviction_dur_.exchange(reset_val); - auto bwd_l1_cnflct_miss_write_back_dur = - bwd_l1_cnflct_miss_write_back_dur_.exchange(reset_val); - auto flush_write_dur = flush_write_dur_.exchange(reset_val); - - ret.push_back(double(read_dur) / interval); - ret.push_back(double(fwd_rocksdb_read_dur) / interval); - ret.push_back(double(fwd_l1_eviction_dur) / interval); - ret.push_back(double(bwd_l1_cnflct_miss_write_back_dur) / interval); - ret.push_back(double(flush_write_dur) / interval); - } - return ret; - } - - void compact() override { - for (auto& db : dbs_) { - db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, nullptr); - } - } - - void flush() { - kv_db::EmbeddingKVDB::flush(); - for (auto& db : dbs_) { - db->Flush(rocksdb::FlushOptions()); - } - } - - private: - void flush_or_compact(const int64_t timestep) override { - // Only do manual Flush/Compactions if enabled - if (memtable_flush_period_ > 0) { - { - RECORD_USER_SCOPE("FlushCompactIfNecessary"); - if (!done_staggered_flushes_) { - flush_if_necessary(timestep); - } else { - compact_if_necessary(timestep); - } - } - } - } - - void flush_if_necessary(const int64_t timestep) { - for (int64_t i = 0; i < dbs_.size(); i++) { - if (shard_flush_compaction_deadlines_[i] == timestep) { - rocksdb::FlushOptions fo; - fo.wait = false; - fo.allow_write_stall = false; - dbs_[i]->Flush(fo); - if (i == dbs_.size() - 1) { - done_staggered_flushes_ = true; - int64_t period_per_shard = compaction_period_ / dbs_.size(); - int64_t offset = memtable_flush_offset_ + compaction_period_; - for (int64_t j = 0; j < dbs_.size(); j++) { - shard_flush_compaction_deadlines_[j] = - offset + (j * period_per_shard); - } - } - } - } - } - - void compact_if_necessary(const int64_t timestep) { - for (int64_t i = 0; i < dbs_.size(); i++) { - if (shard_flush_compaction_deadlines_[i] == timestep) { - rocksdb::ColumnFamilyMetaData meta; - dbs_[i]->GetColumnFamilyMetaData(&meta); - int32_t num_level0 = meta.levels[0].files.size(); - if (num_level0 >= l0_files_per_compact_) { - dbs_[i]->CompactRange( - rocksdb::CompactRangeOptions(), nullptr, nullptr); - } - shard_flush_compaction_deadlines_[i] += compaction_period_; - } - } - } - std::vector> dbs_; std::vector> initializers_; std::unique_ptr executor_; @@ -672,6 +793,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { std::atomic fwd_l1_eviction_dur_{0}; std::atomic bwd_l1_cnflct_miss_write_back_dur_{0}; std::atomic flush_write_dur_{0}; -}; // class EmbeddingKVDB + + std::unordered_map> + snapshots_; + int64_t max_D_; +}; // class EmbeddingRocksDB } // namespace ssd From c01bbb88a915e63f3b1763d4065382d09aa2e708 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Wed, 18 Sep 2024 16:50:38 -0700 Subject: [PATCH 21/27] Fix logging in benchmarks Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/244 - Fix logging in benchmarks (redo D62973297 bc diff cannot be unlinked) Reviewed By: basilwong, spcyppt Differential Revision: D62983422 fbshipit-source-id: 9551cf15df7c1c0281193758b25aaae177e34b67 --- fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py | 3 ++- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 3 ++- fbgemm_gpu/bench/merge_embeddings_benchmark.py | 3 +++ fbgemm_gpu/bench/quantize_ops_benchmark.py | 4 ++-- fbgemm_gpu/bench/sparse_ops_benchmark.py | 3 ++- fbgemm_gpu/bench/split_embeddings_cache_benchmark.py | 3 ++- fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py | 3 +++ fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py | 5 ++--- fbgemm_gpu/bench/stride_gemm_benchmark.py | 3 ++- 9 files changed, 20 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index 592f11496..c919199ee 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -14,7 +14,8 @@ import torch from torch import Tensor -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 51c231ad0..acbe22fb2 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -16,7 +16,8 @@ import torch from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 2c0b62664..95ce71d27 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -32,6 +32,9 @@ # pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`. from torch.profiler import profile, ProfilerActivity +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index b4e596f4d..9ffbd9911 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -22,8 +22,8 @@ # pyre-ignore[21] from torch.profiler import profile, ProfilerActivity - -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index fdd051909..2ef9abe8f 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -20,7 +20,8 @@ from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py index 432ef3f4d..d3169ca81 100644 --- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py +++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py @@ -26,7 +26,8 @@ from torch import nn, Tensor -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index a809b5e9b..177c79508 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -48,6 +48,9 @@ from torch import Tensor from torch.profiler import profile +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + haveAIBench = False try: from aibench_observer.utils.observer import emitMetric diff --git a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py index 430087c4a..25540c190 100644 --- a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py @@ -40,14 +40,13 @@ from torch.autograd.profiler import record_function from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) load_torch_module( "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings", ) -logging.basicConfig(level=logging.DEBUG) - @click.group() def cli() -> None: diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index 3c70d734f..2609f7fbf 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -13,7 +13,8 @@ import torch from fbgemm_gpu.bench.bench_utils import benchmark_torch_function -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] From bc7c433d9d4c4c5c1173e082cb254af3643621ae Mon Sep 17 00:00:00 2001 From: Ismail Pazarbasi Date: Wed, 18 Sep 2024 21:11:13 -0700 Subject: [PATCH 22/27] Added fake implementation for lengths_range (#3093) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/246 Added fake implementation for `fbgemm::lengths_range` Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3093 Reviewed By: ipazarbasi Differential Revision: D62993442 Pulled By: q10 fbshipit-source-id: 7ea2951c86fdadae61a98f8b2ced4990038a3abb --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 17 +++++++++++++++- fbgemm_gpu/test/sparse/misc_ops_test.py | 26 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 024170a68..21e858e1a 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -7,7 +7,7 @@ # pyre-strict import math -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -1113,3 +1113,18 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None _setup() + + +@torch.library.register_fake("fbgemm::lengths_range") +def lengths_range_abstract( + lengths: Tensor, + output_shape: Optional[Sequence[int]] = None, +) -> Tensor: + torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor") + output_size = 0 + if output_shape is not None: + output_size = math.prod(output_shape) + else: + ctx = torch.library.get_ctx() + output_size = ctx.new_dynamic_size() + return lengths.new_empty([output_size], dtype=lengths.dtype) diff --git a/fbgemm_gpu/test/sparse/misc_ops_test.py b/fbgemm_gpu/test/sparse/misc_ops_test.py index fb9e29c81..41187b502 100644 --- a/fbgemm_gpu/test/sparse/misc_ops_test.py +++ b/fbgemm_gpu/test/sparse/misc_ops_test.py @@ -18,6 +18,7 @@ import numpy as np import torch from hypothesis import given, settings, Verbosity +from torch.fx.experimental.symbolic_shapes import ShapeEnv from .common import extend_test_class, open_source @@ -261,6 +262,31 @@ def test_bottom_unique_k_per_row( all_indices_deduped_ref = torch.as_tensor(all_indices[:, :, :L]) torch.testing.assert_close(all_indices_deduped, all_indices_deduped_ref) + def test_lengths_range(self) -> None: + # When 'output_shape' is None, the function will return a tensor with dynamic shape. + with self.assertRaisesRegex( + torch._subclasses.fake_tensor.DynamicOutputShapeException, + "fbgemm.lengths_range.default", + ): + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_dynamic_output_shape_ops=False, + ), + ): + lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32) + _ = torch.ops.fbgemm.lengths_range(lengths, None) + + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_dynamic_output_shape_ops=False, + ), + ): + lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32) + output_shape = [1, 2, 4, 4] + actual_result = torch.ops.fbgemm.lengths_range(lengths, output_shape) + + self.assertEqual(actual_result.shape, (1 * 2 * 4 * 4,)) + extend_test_class(MiscOpsTest) From 46e309d58b0e2026311480e2763c91591196694e Mon Sep 17 00:00:00 2001 From: Igor Sugak Date: Wed, 18 Sep 2024 21:12:18 -0700 Subject: [PATCH 23/27] use npt.NDArray instead of np.ndarray in type annotations [2/N] (#3150) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3150 X-link: https://github.com/facebookresearch/FBGEMM/pull/245 To facilitate PSS-2 upgrade, this uses `ndt.NDArray` instead of `nd.ndarray` in type annotations. In Numpy-1.19 (PSS-1) it's an alias to `nd.ndarray` -- a noop. In Numpy-1.24, `ndt.NDArray` a proper generic type, and without this change uses of `nd.ndarray` generate this Pyre type error: ```counterexample Invalid type parameters [24]: Generic type `np.ndarray` expects 2 type parameters. ``` Reviewed By: florazzz Differential Revision: D62986280 fbshipit-source-id: a08c3b5fbdc04f5a0100359e6bc18a596ed0c307 --- fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py | 3 +- fbgemm_gpu/test/jagged/common.py | 5 ++-- fbgemm_gpu/test/quantize/common.py | 31 +++++++++++--------- fbgemm_gpu/test/sparse/pack_segments_test.py | 5 ++-- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py index fbccebd9a..19927a5ea 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py @@ -11,6 +11,7 @@ from typing import List, Optional, Tuple import numpy as np +import numpy.typing as npt import torch # pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed). @@ -135,7 +136,7 @@ def generate_int_data_from_stats( sigma: int, size: int, distribution: str, -) -> np.ndarray: +) -> npt.NDArray: """ Generate integer data based on stats """ diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 6cd60d96c..3bdcacf98 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -16,6 +16,7 @@ import fbgemm_gpu import fbgemm_gpu.sparse_ops import numpy as np +import numpy.typing as npt import torch from hypothesis import HealthCheck, settings @@ -122,7 +123,7 @@ def generate_jagged_tensor( # dynamo to mark the input as dynamic shape to make sure symbolic # shape is generated mark_dynamic: bool = False, -) -> Tuple[torch.Tensor, List[torch.LongTensor], np.ndarray]: +) -> Tuple[torch.Tensor, List[torch.LongTensor], npt.NDArray]: max_lengths = np.random.randint(low=1, high=10, size=(num_jagged_dim,)) x_offsets: List[torch.LongTensor] = [] num_lengths = outer_dense_size @@ -167,7 +168,7 @@ def generate_jagged_tensor( def to_padded_dense( values: torch.Tensor, offsets: List[torch.LongTensor], - max_lengths: np.ndarray, + max_lengths: npt.NDArray, padding_value: float = 0, ) -> torch.Tensor: outer_dense_size = len(offsets[0]) - 1 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 392bbfcac..5333cc893 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -12,6 +12,7 @@ import fbgemm_gpu import numpy as np +import numpy.typing as npt import torch # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. @@ -30,17 +31,17 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") # Eigen/Python round 0.5 away from 0, Numpy rounds to even -round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round) +round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) -def bytes_to_floats(byte_matrix: np.ndarray) -> np.ndarray: +def bytes_to_floats(byte_matrix: npt.NDArray) -> npt.NDArray: floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32) for i, byte_values in enumerate(byte_matrix): (floats[i],) = struct.unpack("f", bytearray(byte_values)) return floats -def floats_to_bytes(floats: np.ndarray) -> np.ndarray: +def floats_to_bytes(floats: npt.NDArray) -> npt.NDArray: byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float32), (value, floats) @@ -53,7 +54,7 @@ def floats_to_bytes(floats: np.ndarray) -> np.ndarray: return byte_matrix -def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray: +def bytes_to_half_floats(byte_matrix: npt.NDArray) -> npt.NDArray: floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float16) for i, byte_values in enumerate(byte_matrix): (floats[i],) = np.frombuffer( @@ -62,7 +63,7 @@ def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray: return floats -def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray: +def half_floats_to_bytes(floats: npt.NDArray) -> npt.NDArray: byte_matrix = np.empty([np.shape(floats)[0], 2], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float16), (value, floats) @@ -72,7 +73,7 @@ def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray: return byte_matrix -def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray: +def fused_rowwise_8bit_quantize_reference(data: npt.NDArray) -> npt.NDArray: minimum = np.min(data, axis=-1, keepdims=True) maximum = np.max(data, axis=-1, keepdims=True) span = maximum - minimum @@ -87,7 +88,9 @@ def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray: return np.concatenate([quantized_data, scale_bytes, bias_bytes], axis=-1) -def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.ndarray: +def fused_rowwise_8bit_dequantize_reference( + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_floats(fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4)) scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],)) bias = bytes_to_floats(fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4)) @@ -97,8 +100,8 @@ def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.n def fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference( - fused_quantized: np.ndarray, -) -> np.ndarray: + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_half_floats( fused_quantized[..., 0:2].astype(np.uint8).reshape(-1, 2) ) @@ -112,8 +115,8 @@ def fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference( def fused_rowwise_8bit_dequantize_reference_half( - fused_quantized: np.ndarray, -) -> np.ndarray: + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_half_floats( fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4) ) @@ -126,7 +129,7 @@ def fused_rowwise_8bit_dequantize_reference_half( return quantized_data * scale + bias -def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndarray: +def fused_rowwise_nbit_quantize_reference(data: npt.NDArray, bit: int) -> npt.NDArray: minimum = np.min(data, axis=1).astype(np.float16).astype(np.float32) maximum = np.max(data, axis=1) span = maximum - minimum @@ -165,8 +168,8 @@ def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndar def fused_rowwise_nbit_quantize_dequantize_reference( - data: np.ndarray, bit: int -) -> np.ndarray: + data: npt.NDArray, bit: int +) -> npt.NDArray: fused_quantized = fused_rowwise_nbit_quantize_reference(data, bit) scale = bytes_to_half_floats(fused_quantized[:, -4:-2].astype(np.uint8)).astype( np.float32 diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index c6383017a..d6b40328e 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -15,6 +15,7 @@ import hypothesis.strategies as st import numpy as np +import numpy.typing as npt import torch from hypothesis import given, settings @@ -27,7 +28,7 @@ from fbgemm_gpu.test.test_utils import gpu_available -def get_n_rand_num_summing_to_k(n: int, k: int) -> np.ndarray: +def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: """Get a list of `n` integers which collectively sum to `k`, drawn uniformly from the set of all such lists. @@ -58,7 +59,7 @@ def _pack_segments_ref( lengths: torch.Tensor, tensor: torch.Tensor, max_length: Optional[int] = None, - ) -> np.ndarray: + ) -> npt.NDArray: lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length From 12710ef6135ec940b88942da1d87cdff8549e621 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Thu, 19 Sep 2024 09:56:36 -0700 Subject: [PATCH 24/27] support different index dtype (#3140) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3140 X-link: https://github.com/facebookresearch/FBGEMM/pull/233 before this diff we hack the indices to be int64 after this diff, SSD tbe support int32 indices Reviewed By: q10 Differential Revision: D62761615 fbshipit-source-id: 5a08b022c0ceaa30bc0b7c2ec66e8b8444e15657 --- .../split_embeddings_cache/cachelib_cache.h | 20 +- .../split_embeddings_cache/cachelib_cache.cpp | 201 +++++++------ .../kv_db_table_batched_embeddings.cpp | 265 ++++++++++------- .../kv_db_table_batched_embeddings.h | 6 + .../ssd_table_batched_embeddings.h | 277 ++++++++++-------- .../tbe/ssd/ssd_split_tbe_training_test.py | 12 +- 6 files changed, 446 insertions(+), 335 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h index bb6edff00..80955120f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h @@ -39,7 +39,9 @@ class CacheLibCache { int64_t max_D_; }; - explicit CacheLibCache(const CacheConfig& cache_config); + explicit CacheLibCache( + const CacheConfig& cache_config, + int64_t unique_tbe_id); std::unique_ptr initializeCacheLib(const CacheConfig& config); @@ -48,7 +50,7 @@ class CacheLibCache { /// Find the stored embeddings from a given embedding indices, aka key /// - /// @param key embedding index to look up + /// @param key_tensor embedding index(tensor with only one element) to look up /// /// @return an optional value, return none on cache misses, if cache hit /// return a pointer to the cachelib underlying storage of associated @@ -57,7 +59,7 @@ class CacheLibCache { /// @note that this is not thread safe, caller needs to make sure the data is /// fully processed before doing cache insertion, otherwise the returned space /// might be overwritten if cache is full - std::optional get(int64_t key); + folly::Optional get(const at::Tensor& key_tensor); /// Cachelib wrapper specific hash function /// @@ -84,7 +86,8 @@ class CacheLibCache { /// Add an embedding index and embeddings into cachelib /// - /// @param key embedding index to insert + /// @param key_tensor embedding index(tensor with only one element) to insert + /// @param data embedding weights to insert /// /// @return true on success insertion, false on failure insertion, a failure /// insertion could happen if the refcount of bottom K items in LRU queue @@ -94,11 +97,12 @@ class CacheLibCache { /// bulk read and bluk write sequentially /// /// @note cache_->allocation will trigger eviction callback func - bool put(int64_t key, const at::Tensor& data); + bool put(const at::Tensor& key_tensor, const at::Tensor& data); /// iterate through all items in L2 cache, fill them in indices and weights /// respectively and return indices, weights and count /// + /// @return optional value, if cache is empty return none /// @return indices The 1D embedding index tensor, should skip on negative /// value /// @return weights The 2D tensor that each row(embeddings) is paired up with @@ -108,7 +112,8 @@ class CacheLibCache { /// /// @note this isn't thread safe, caller needs to make sure put isn't called /// while this is executed. - std::tuple get_all_items(); + folly::Optional> + get_all_items(); /// instantiate eviction related indices and weights tensors(size of ) /// for L2 eviction using the same dtype and device from and @@ -141,12 +146,15 @@ class CacheLibCache { private: const CacheConfig cache_config_; + const int64_t unique_tbe_id_; std::unique_ptr cache_; std::vector pool_ids_; std::unique_ptr admin_; folly::Optional evicted_indices_opt_{folly::none}; folly::Optional evicted_weights_opt_{folly::none}; + folly::Optional index_dtype_{folly::none}; + folly::Optional weights_dtype_{folly::none}; std::atomic eviction_row_id{0}; }; diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp index d7eb220d8..f5002375c 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp @@ -15,25 +15,11 @@ namespace l2_cache { using Cache = facebook::cachelib::LruAllocator; -// this is a general predictor for weights data type, might not be general -// enough for all the cases -at::ScalarType bytes_to_dtype(int num_bytes) { - switch (num_bytes) { - case 1: - return at::kByte; - case 2: - return at::kHalf; - case 4: - return at::kFloat; - case 8: - return at::kDouble; - default: - throw std::runtime_error("Unsupported dtype"); - } -} - -CacheLibCache::CacheLibCache(const CacheConfig& cache_config) +CacheLibCache::CacheLibCache( + const CacheConfig& cache_config, + int64_t unique_tbe_id) : cache_config_(cache_config), + unique_tbe_id_(unique_tbe_id), cache_(initializeCacheLib(cache_config_)), admin_(createCacheAdmin(*cache_)) { for (size_t i = 0; i < cache_config_.num_shards; i++) { @@ -50,30 +36,41 @@ CacheLibCache::CacheLibCache(const CacheConfig& cache_config) std::unique_ptr CacheLibCache::initializeCacheLib( const CacheConfig& config) { - auto eviction_cb = - [this](const facebook::cachelib::LruAllocator::RemoveCbData& data) { - FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( - evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] { - if (data.context == - facebook::cachelib::RemoveContext::kEviction) { - auto indices_data_ptr = - evicted_indices_opt_->data_ptr(); - auto weights_data_ptr = - evicted_weights_opt_->data_ptr(); - auto row_id = eviction_row_id++; - auto weight_dim = evicted_weights_opt_->size(1); - const auto key_ptr = - reinterpret_cast(data.item.getKey().data()); - indices_data_ptr[row_id] = *key_ptr; + auto eviction_cb = [this]( + const facebook::cachelib::LruAllocator::RemoveCbData& + data) { + if (evicted_weights_opt_.has_value()) { + FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( + evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] { + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + evicted_indices_opt_->scalar_type(), + "l2_eviction_handling", + [&] { + using index_t = scalar_t; + if (data.context == + facebook::cachelib::RemoveContext::kEviction) { + auto indices_data_ptr = + evicted_indices_opt_->data_ptr(); + auto weights_data_ptr = + evicted_weights_opt_->data_ptr(); + auto row_id = eviction_row_id++; + auto weight_dim = evicted_weights_opt_->size(1); + const auto key_ptr = reinterpret_cast( + data.item.getKey().data()); + indices_data_ptr[row_id] = *key_ptr; - std::copy( - reinterpret_cast(data.item.getMemory()), - reinterpret_cast(data.item.getMemory()) + - weight_dim, - &weights_data_ptr[row_id * weight_dim]); // dst_start - } - }); - }; + std::copy( + reinterpret_cast(data.item.getMemory()), + reinterpret_cast( + data.item.getMemory()) + + weight_dim, + &weights_data_ptr[row_id * weight_dim]); // dst_start + } + }); + }); + } + }; Cache::Config cacheLibConfig; int64_t rough_num_items = cache_config_.cache_size_bytes / cache_config_.item_size_bytes; @@ -82,8 +79,9 @@ std::unique_ptr CacheLibCache::initializeCacheLib( unsigned int lock_power = std::log(cache_config_.num_shards * 15) / std::log(2) + 1; XLOG(INFO) << fmt::format( - "Setting up Cachelib for L2 cache, capacity: {}GB, " + "[TBE_ID{}] Setting up Cachelib for L2 cache, capacity: {}GB, " "item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}", + unique_tbe_id_, config.cache_size_bytes / 1024 / 1024 / 1024, cache_config_.item_size_bytes, rough_num_items, @@ -106,14 +104,21 @@ std::unique_ptr CacheLibCache::createCacheAdmin( cache, std::move(adminConfig)); } -std::optional CacheLibCache::get(int64_t key) { - auto key_str = - folly::StringPiece(reinterpret_cast(&key), sizeof(int64_t)); - auto item = cache_->find(key_str); - if (!item) { - return std::nullopt; - } - return const_cast(item->getMemory()); +folly::Optional CacheLibCache::get(const at::Tensor& key_tensor) { + folly::Optional res; + FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "get", [&] { + using index_t = scalar_t; + auto key = *(key_tensor.data_ptr()); + auto key_str = folly::StringPiece( + reinterpret_cast(&key), sizeof(index_t)); + auto item = cache_->find(key_str); + if (!item) { + res = folly::none; + return; + } + res = const_cast(item->getMemory()); + }); + return res; } size_t CacheLibCache::get_shard_id(int64_t key) { @@ -136,55 +141,79 @@ void CacheLibCache::batchMarkUseful( } } -bool CacheLibCache::put(int64_t key, const at::Tensor& data) { - auto key_str = - folly::StringPiece(reinterpret_cast(&key), sizeof(int64_t)); - auto item = cache_->findToWrite(key_str); - if (!item) { - auto alloc_item = - cache_->allocate(get_pool_id(key), key_str, data.nbytes()); - if (!alloc_item) { - XLOG(ERR) << fmt::format( - "Failed to allocate item {} in cache, skip", key); - return false; - } - std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes()); - cache_->insertOrReplace(std::move(alloc_item)); - } else { - std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes()); +bool CacheLibCache::put(const at::Tensor& key_tensor, const at::Tensor& data) { + if (!index_dtype_.has_value()) { + index_dtype_ = key_tensor.scalar_type(); + } + if (!weights_dtype_.has_value()) { + weights_dtype_ = data.scalar_type(); } - return true; + bool res; + FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "put", [&] { + using index_t = scalar_t; + auto key = *(key_tensor.data_ptr()); + auto key_str = folly::StringPiece( + reinterpret_cast(&key), sizeof(index_t)); + auto item = cache_->findToWrite(key_str); + if (!item) { + auto alloc_item = + cache_->allocate(get_pool_id(key), key_str, data.nbytes()); + if (!alloc_item) { + XLOG(ERR) << fmt::format( + "[TBE_ID{}]Failed to allocate item {} in cache, skip", + unique_tbe_id_, + key); + res = false; + return; + } + std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes()); + cache_->insertOrReplace(std::move(alloc_item)); + } else { + std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes()); + } + res = true; + }); + return res; } -std::tuple CacheLibCache::get_all_items() { +folly::Optional> +CacheLibCache::get_all_items() { + if (!index_dtype_.has_value() || !weights_dtype_.has_value()) { + return folly::none; + } int total_num_items = 0; for (auto& pool_id : pool_ids_) { total_num_items += cache_->getPoolStats(pool_id).numItems(); } auto weight_dim = cache_config_.max_D_; - auto weights_dtype = - bytes_to_dtype(cache_config_.item_size_bytes / weight_dim); auto indices = at::empty( - total_num_items, at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + total_num_items, + at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU)); auto weights = at::empty( {total_num_items, weight_dim}, - at::TensorOptions().dtype(weights_dtype).device(at::kCPU)); + at::TensorOptions().dtype(weights_dtype_.value()).device(at::kCPU)); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "get_all_items", [&] { - auto indices_data_ptr = indices.data_ptr(); - auto weights_data_ptr = weights.data_ptr(); - int64_t item_idx = 0; - for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) { - const auto key_ptr = - reinterpret_cast(itr->getKey().data()); - indices_data_ptr[item_idx] = *key_ptr; - std::copy( - reinterpret_cast(itr->getMemory()), - reinterpret_cast(itr->getMemory()) + weight_dim, - &weights_data_ptr[item_idx * weight_dim]); // dst_start - item_idx++; - } - CHECK_EQ(total_num_items, item_idx); + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "get_all_items", [&] { + using index_t = scalar_t; + auto indices_data_ptr = indices.data_ptr(); + auto weights_data_ptr = weights.data_ptr(); + int64_t item_idx = 0; + for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) { + const auto key_ptr = + reinterpret_cast(itr->getKey().data()); + indices_data_ptr[item_idx] = *key_ptr; + std::copy( + reinterpret_cast(itr->getMemory()), + reinterpret_cast(itr->getMemory()) + + weight_dim, + &weights_data_ptr[item_idx * weight_dim]); // dst_start + item_idx++; + } + CHECK_EQ(total_num_items, item_idx); + }); }); return std::make_tuple( indices, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index fcf04fff0..21aa00290 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -44,19 +44,24 @@ QueueItem tensor_copy( at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "tensor_copy", [&] { - auto indices_addr = indices.data_ptr(); - auto new_indices_addr = new_indices.data_ptr(); - std::copy( - indices_addr, - indices_addr + num_sets, - new_indices_addr); // dst_start - - auto weights_addr = weights.data_ptr(); - auto new_weightss_addr = new_weights.data_ptr(); - std::copy( - weights_addr, - weights_addr + num_sets * weights.size(1), - new_weightss_addr); // dst_start + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "tensor_copy", [&] { + using index_t = scalar_t; + auto indices_addr = indices.data_ptr(); + auto new_indices_addr = new_indices.data_ptr(); + std::copy( + indices_addr, + indices_addr + num_sets, + new_indices_addr); // dst_start + + auto weights_addr = weights.data_ptr(); + auto new_weightss_addr = new_weights.data_ptr(); + std::copy( + weights_addr, + weights_addr + num_sets * weights.size(1), + new_weightss_addr); // dst_start + }); }); *new_count.data_ptr() = num_sets; return QueueItem{new_indices, new_weights, new_count, mode}; @@ -79,7 +84,8 @@ EmbeddingKVDB::EmbeddingKVDB( cache_config.num_shards = num_shards_; cache_config.item_size_bytes = max_D_ * ele_size_bytes; cache_config.max_D_ = max_D_; - l2_cache_ = std::make_unique(cache_config); + l2_cache_ = + std::make_unique(cache_config, unique_id); } else { l2_cache_ = nullptr; } @@ -128,10 +134,15 @@ EmbeddingKVDB::~EmbeddingKVDB() { void EmbeddingKVDB::flush() { wait_util_filling_work_done(); if (l2_cache_) { - auto tensor_tuple = l2_cache_->get_all_items(); - auto& indices = std::get<0>(tensor_tuple); - auto& weights = std::get<1>(tensor_tuple); - auto& count = std::get<2>(tensor_tuple); + auto tensor_tuple_opt = l2_cache_->get_all_items(); + if (!tensor_tuple_opt.has_value()) { + XLOG(INFO) << "[TBE_ID" << unique_id_ + << "]no items exist in L2 cache, flush nothing"; + return; + } + auto& indices = std::get<0>(tensor_tuple_opt.value()); + auto& weights = std::get<1>(tensor_tuple_opt.value()); + auto& count = std::get<2>(tensor_tuple_opt.value()); folly::coro::blockingWait(set_kv_db_async( indices, weights, count, kv_db::RocksdbWriteMode::FLUSH)); } @@ -143,6 +154,7 @@ void EmbeddingKVDB::get_cuda( const at::Tensor& count) { auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::get_cuda ##"); + check_tensor_type_consistency(indices, weights); // take reference to self to avoid lifetime issues. auto self = shared_from_this(); std::function* functor = @@ -163,6 +175,7 @@ void EmbeddingKVDB::set_cuda( const bool is_bwd) { auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::set_cuda ##"); + check_tensor_type_consistency(indices, weights); // take reference to self to avoid lifetime issues. auto self = shared_from_this(); std::function* functor = new std::function([=]() { @@ -229,8 +242,8 @@ void EmbeddingKVDB::set( const bool is_bwd) { if (auto num_evictions = get_maybe_uvm_scalar(count); num_evictions <= 0) { XLOG_EVERY_MS(INFO, 60000) - << "[" << unique_id_ << "]skip set_cuda since number evictions is " - << num_evictions; + << "[TBE_ID" << unique_id_ + << "]skip set_cuda since number evictions is " << num_evictions; return; } auto start_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -254,7 +267,7 @@ void EmbeddingKVDB::get( const at::Tensor& count) { if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) { XLOG_EVERY_MS(INFO, 60000) - << "[" << unique_id_ << "]skip get_cuda since number lookups is " + << "[TBE_ID" << unique_id_ << "]skip get_cuda since number lookups is " << num_lookups; return; } @@ -303,63 +316,67 @@ std::shared_ptr EmbeddingKVDB::get_cache( return nullptr; } auto start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto indices_addr = indices.data_ptr(); + auto num_lookups = get_maybe_uvm_scalar(count); auto cache_context = std::make_shared(num_lookups); - - auto num_shards = executor_tp_->numThreads(); - - std::vector> tasks; - std::vector> row_ids_per_shard(num_shards); - for (int i = 0; i < num_shards; i++) { - row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); - } - for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { - row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] - .emplace_back(row_id); - } - for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { - tasks.emplace_back( - folly::coro::co_invoke( - [this, - &indices_addr, - cache_context, - shard_id, - &row_ids_per_shard]() mutable -> folly::coro::Task { - for (const auto& row_id : row_ids_per_shard[shard_id]) { - auto emb_idx = indices_addr[row_id]; - if (emb_idx < 0) { - continue; - } - auto cached_addr_opt = l2_cache_->get(emb_idx); - if (cached_addr_opt.has_value()) { // cache hit - cache_context->cached_addr_list[row_id] = - cached_addr_opt.value(); - indices_addr[row_id] = -1; // mark to sentinel value - } else { // cache miss - cache_context->num_misses += 1; + FBGEMM_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "get_cache", [&] { + using index_t = scalar_t; + auto indices_addr = indices.data_ptr(); + auto num_shards = executor_tp_->numThreads(); + + std::vector> tasks; + std::vector> row_ids_per_shard(num_shards); + for (int i = 0; i < num_shards; i++) { + row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); + } + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] + .emplace_back(row_id); + } + for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + tasks.emplace_back( + folly::coro::co_invoke( + [this, + &indices_addr, + &indices, + cache_context, + shard_id, + &row_ids_per_shard]() mutable -> folly::coro::Task { + for (const auto& row_id : row_ids_per_shard[shard_id]) { + auto emb_idx = indices_addr[row_id]; + if (emb_idx < 0) { + continue; + } + auto cached_addr_opt = l2_cache_->get(indices[row_id]); + if (cached_addr_opt.has_value()) { // cache hit + cache_context->cached_addr_list[row_id] = + cached_addr_opt.value(); + indices_addr[row_id] = -1; // mark to sentinel value + } else { // cache miss + cache_context->num_misses += 1; + } } - } - co_return; - }) - .scheduleOn(executor_tp_.get())); - } - folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); - - // the following metrics added here as the current assumption is - // get_cache will only be called in get_cuda path, if assumption no longer - // true, we should wrap this up on the caller side - auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts; - get_cache_lookup_total_duration_ += dur; - auto cache_misses = cache_context->num_misses.load(); - if (num_lookups > 0) { - num_cache_misses_ += cache_misses; - num_lookups_ += num_lookups; - } else { - XLOG_EVERY_MS(INFO, 60000) - << "[" << unique_id_ - << "]num_lookups is 0, skip collecting the L2 cache miss stats"; - } + co_return; + }) + .scheduleOn(executor_tp_.get())); + } + folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); + + // the following metrics added here as the current assumption is + // get_cache will only be called in get_cuda path, if assumption no longer + // true, we should wrap this up on the caller side + auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts; + get_cache_lookup_total_duration_ += dur; + auto cache_misses = cache_context->num_misses.load(); + if (num_lookups > 0) { + num_cache_misses_ += cache_misses; + num_lookups_ += num_lookups; + } else { + XLOG_EVERY_MS(INFO, 60000) + << "[TBE_ID" << unique_id_ + << "]num_lookups is 0, skip collecting the L2 cache miss stats"; + } + }); return cache_context; } @@ -373,7 +390,8 @@ void EmbeddingKVDB::wait_util_filling_work_done() { total_wait_time_ms += 1; if (total_wait_time_ms > 100) { XLOG_EVERY_MS(ERR, 1000) - << "get_cache: waiting for L2 caching filling embeddings for " + << "[TBE_ID" << unique_id_ + << "]get_cache: waiting for L2 caching filling embeddings for " << total_wait_time_ms << " ms, somethings is likely wrong"; } } @@ -397,45 +415,50 @@ EmbeddingKVDB::set_cache( auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast(); l2_cache_->init_tensor_for_l2_eviction(indices, weights, count); - auto indices_addr = indices.data_ptr(); - const int64_t num_lookups = get_maybe_uvm_scalar(count); - auto num_shards = executor_tp_->numThreads(); - std::vector> tasks; - std::vector> row_ids_per_shard(num_shards); + FBGEMM_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "set_cache", [&] { + using index_t = scalar_t; + auto indices_addr = indices.data_ptr(); + const int64_t num_lookups = get_maybe_uvm_scalar(count); + auto num_shards = executor_tp_->numThreads(); - for (int i = 0; i < num_shards; i++) { - row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); - } - for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { - row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] - .emplace_back(row_id); - } + std::vector> tasks; + std::vector> row_ids_per_shard(num_shards); - for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { - tasks.emplace_back( - folly::coro::co_invoke( - [this, - &indices_addr, - &weights, - shard_id, - &row_ids_per_shard]() mutable -> folly::coro::Task { - for (const auto& row_id : row_ids_per_shard[shard_id]) { - auto emb_idx = indices_addr[row_id]; - if (emb_idx < 0) { - continue; - } - if (!l2_cache_->put(emb_idx, weights[row_id])) { - XLOG_EVERY_MS(ERR, 1000) - << "[" << unique_id_ - << "]Failed to insert into cache, this shouldn't happen"; + for (int i = 0; i < num_shards; i++) { + row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); + } + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] + .emplace_back(row_id); + } + + for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + tasks.emplace_back( + folly::coro::co_invoke( + [this, + &indices_addr, + &indices, + &weights, + shard_id, + &row_ids_per_shard]() mutable -> folly::coro::Task { + for (const auto& row_id : row_ids_per_shard[shard_id]) { + auto emb_idx = indices_addr[row_id]; + if (emb_idx < 0) { + continue; + } + if (!l2_cache_->put(indices[row_id], weights[row_id])) { + XLOG_EVERY_MS(ERR, 1000) + << "[TBE_ID" << unique_id_ + << "]Failed to insert into cache, this shouldn't happen"; + } } - } - co_return; - }) - .scheduleOn(executor_tp_.get())); - } - folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); + co_return; + }) + .scheduleOn(executor_tp_.get())); + } + folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); + }); total_cache_update_duration_ += facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts; auto tensor_tuple_opt = l2_cache_->get_tensors_and_reset(); @@ -492,4 +515,24 @@ folly::coro::Task EmbeddingKVDB::cache_memcpy( co_return; } +void EmbeddingKVDB::check_tensor_type_consistency( + const at::Tensor& indices, + const at::Tensor& weights) { + if (index_dtype_.has_value()) { + assert(index_dtype_.value() == indices.scalar_type()); + } else { + index_dtype_ = indices.scalar_type(); + XLOG(INFO) << "[TBE_ID" << unique_id_ << "]L2 cache index dtype is " + << index_dtype_.value(); + } + + if (weights_dtype_.has_value()) { + assert(weights_dtype_.value() == weights.scalar_type()); + } else { + weights_dtype_ = weights.scalar_type(); + XLOG(INFO) << "[TBE_ID" << unique_id_ << "]L2 cache weights dtype is " + << weights_dtype_.value(); + } +} + } // namespace kv_db diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index d63ff5418..da98ba9c3 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -279,6 +279,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this { virtual void flush_or_compact(const int64_t timestep) = 0; + void check_tensor_type_consistency( + const at::Tensor& indices, + const at::Tensor& weights); + // waiting for working item queue to be empty, this is called by get_cache() // as embedding read should wait until previous write to be finished void wait_util_filling_work_done(); @@ -287,6 +291,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const int64_t unique_id_; const int64_t num_shards_; const int64_t max_D_; + folly::Optional index_dtype_{folly::none}; + folly::Optional weights_dtype_{folly::none}; std::unique_ptr executor_tp_; std::unique_ptr cache_filling_thread_; std::atomic stop_{false}; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 3d1f2c977..e7d32682e 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -432,36 +432,43 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_set", [&] { - CHECK(indices.is_contiguous()); - CHECK(weights.is_contiguous()); - auto indices_acc = indices.accessor(); - auto D = weights.size(1); - CHECK_EQ(indices.size(0), weights.size(0)); - { - rocksdb::WriteBatch batch( - (2 * (count_ + dbs_.size() - 1) / dbs_.size()) * - (sizeof(int64_t) + sizeof(scalar_t) * D)); - for (auto i = 0; i < count_; ++i) { - if (indices_acc[i] < 0) { - continue; - } - if (kv_db_utils::hash_shard( - indices_acc[i], dbs_.size()) != shard) { - continue; - } - batch.Put( - rocksdb::Slice( - reinterpret_cast( - &(indices.data_ptr()[i])), - sizeof(int64_t)), - rocksdb::Slice( - reinterpret_cast( - &(weights.data_ptr()[i * D])), - D * sizeof(scalar_t))); - } - auto s = dbs_[shard]->Write(wo_, &batch); - CHECK(s.ok()); - } + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "ssd_set", [&] { + using index_t = scalar_t; + CHECK(indices.is_contiguous()); + CHECK(weights.is_contiguous()); + auto indices_acc = indices.accessor(); + auto D = weights.size(1); + CHECK_EQ(indices.size(0), weights.size(0)); + { + rocksdb::WriteBatch batch( + (2 * (count_ + dbs_.size() - 1) / + dbs_.size()) * + (sizeof(index_t) + sizeof(value_t) * D)); + for (auto i = 0; i < count_; ++i) { + if (indices_acc[i] < 0) { + continue; + } + if (kv_db_utils::hash_shard( + indices_acc[i], dbs_.size()) != shard) { + continue; + } + batch.Put( + rocksdb::Slice( + reinterpret_cast( + &(indices.data_ptr()[i])), + sizeof(index_t)), + rocksdb::Slice( + reinterpret_cast( + &(weights + .data_ptr()[i * D])), + D * sizeof(value_t))); + } + auto s = dbs_[shard]->Write(wo_, &batch); + CHECK(s.ok()); + } + }); }); co_return; }) @@ -661,107 +668,117 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_get", [&] { - CHECK(indices.is_contiguous()); - CHECK(weights.is_contiguous()); - auto indices_data_ptr = indices.data_ptr(); - auto D = weights.size(1); - CHECK_EQ(indices.size(0), weights.size(0)); - auto weights_data_ptr = weights.data_ptr(); - FOLLY_DECLARE_REUSED(keys, std::vector); - FOLLY_DECLARE_REUSED(shard_ids, std::vector); - FOLLY_DECLARE_REUSED( - cfs, std::vector); - FOLLY_DECLARE_REUSED( - values, std::vector); - FOLLY_DECLARE_REUSED( - statuses, std::vector); - auto* dcf = dbs_[shard]->DefaultColumnFamily(); - for (auto i = 0; i < count_; ++i) { - // "no-op"/empty evicted tensor - if (indices_data_ptr[i] == -1) { - continue; - } - if (kv_db_utils::hash_shard( - indices_data_ptr[i], dbs_.size()) != shard) { - continue; - } - shard_ids.push_back(i); - } - std::sort( - shard_ids.begin(), - shard_ids.end(), - [&](int32_t lhs, int32_t rhs) { - const auto lhs_key = rocksdb::Slice( - reinterpret_cast( - &(indices_data_ptr[lhs])), - sizeof(int64_t)); - const auto rhs_key = rocksdb::Slice( - reinterpret_cast( - &(indices_data_ptr[rhs])), - sizeof(int64_t)); - return lhs_key.compare(rhs_key) < 0; + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "ssd_get", [&] { + using index_t = scalar_t; + CHECK(indices.is_contiguous()); + CHECK(weights.is_contiguous()); + auto indices_data_ptr = indices.data_ptr(); + auto D = weights.size(1); + CHECK_EQ(indices.size(0), weights.size(0)); + auto weights_data_ptr = weights.data_ptr(); + FOLLY_DECLARE_REUSED( + keys, std::vector); + FOLLY_DECLARE_REUSED( + shard_ids, std::vector); + FOLLY_DECLARE_REUSED( + cfs, std::vector); + FOLLY_DECLARE_REUSED( + values, std::vector); + FOLLY_DECLARE_REUSED( + statuses, std::vector); + auto* dcf = dbs_[shard]->DefaultColumnFamily(); + for (auto i = 0; i < count_; ++i) { + // "no-op"/empty evicted tensor + if (indices_data_ptr[i] == -1) { + continue; + } + if (kv_db_utils::hash_shard( + indices_data_ptr[i], dbs_.size()) != + shard) { + continue; + } + shard_ids.push_back(i); + } + std::sort( + shard_ids.begin(), + shard_ids.end(), + [&](int32_t lhs, int32_t rhs) { + const auto lhs_key = rocksdb::Slice( + reinterpret_cast( + &(indices_data_ptr[lhs])), + sizeof(index_t)); + const auto rhs_key = rocksdb::Slice( + reinterpret_cast( + &(indices_data_ptr[rhs])), + sizeof(index_t)); + return lhs_key.compare(rhs_key) < 0; + }); + for (const auto& i : shard_ids) { + const auto key = rocksdb::Slice( + reinterpret_cast( + &(indices_data_ptr[i])), + sizeof(index_t)); + keys.push_back(key); + cfs.push_back(dcf); + } + CHECK_EQ(shard_ids.size(), keys.size()); + CHECK_EQ(shard_ids.size(), cfs.size()); + + values.resize(keys.size()); + statuses.resize(keys.size()); + // Set a snapshot if it is available + ro_.snapshot = snapshot; + dbs_[shard]->MultiGet( + ro_, + keys.size(), + cfs.data(), + keys.data(), + values.data(), + statuses.data(), + /*sorted_input=*/true); + const auto& init_storage = + initializers_[shard]->row_storage_; + // Sanity check + TORCH_CHECK( + init_storage.scalar_type() == + weights.scalar_type(), + "init_storage (", + toString(init_storage.scalar_type()), + ") and weights scalar (", + toString(weights.scalar_type()), + ") types mismatch"); + auto row_storage_data_ptr = + init_storage.data_ptr(); + for (auto j = 0; j < keys.size(); ++j) { + const auto& s = statuses[j]; + int64_t i = shard_ids[j]; + const auto& value = values[j]; + if (s.ok()) { + if (!std::is_same::value) { + CHECK_EQ(value.size(), D * sizeof(value_t)); + } + std::copy( + reinterpret_cast( + value.data()), + reinterpret_cast( + value.data() + value.size()), + &(weights_data_ptr[i * D])); + } else { + CHECK(s.IsNotFound()); + int64_t row_index; + initializers_[shard]->producer_queue_.dequeue( + row_index); + std::copy( + &(row_storage_data_ptr[row_index * D]), + &(row_storage_data_ptr[row_index * D + D]), + &(weights_data_ptr[i * D])); + initializers_[shard]->consumer_queue_.enqueue( + row_index); + } + } }); - for (const auto& i : shard_ids) { - const auto key = rocksdb::Slice( - reinterpret_cast( - &(indices_data_ptr[i])), - sizeof(int64_t)); - keys.push_back(key); - cfs.push_back(dcf); - } - CHECK_EQ(shard_ids.size(), keys.size()); - CHECK_EQ(shard_ids.size(), cfs.size()); - - values.resize(keys.size()); - statuses.resize(keys.size()); - // Set a snapshot if it is available - ro_.snapshot = snapshot; - dbs_[shard]->MultiGet( - ro_, - keys.size(), - cfs.data(), - keys.data(), - values.data(), - statuses.data(), - /*sorted_input=*/true); - const auto& init_storage = - initializers_[shard]->row_storage_; - // Sanity check - TORCH_CHECK( - init_storage.scalar_type() == weights.scalar_type(), - "init_storage (", - toString(init_storage.scalar_type()), - ") and weights scalar (", - toString(weights.scalar_type()), - ") types mismatch"); - auto row_storage_data_ptr = - init_storage.data_ptr(); - for (auto j = 0; j < keys.size(); ++j) { - const auto& s = statuses[j]; - int64_t i = shard_ids[j]; - const auto& value = values[j]; - if (s.ok()) { - if (!std::is_same::value) { - CHECK_EQ(value.size(), D * sizeof(scalar_t)); - } - std::copy( - reinterpret_cast(value.data()), - reinterpret_cast( - value.data() + value.size()), - &(weights_data_ptr[i * D])); - } else { - CHECK(s.IsNotFound()); - int64_t row_index; - initializers_[shard]->producer_queue_.dequeue( - row_index); - std::copy( - &(row_storage_data_ptr[row_index * D]), - &(row_storage_data_ptr[row_index * D + D]), - &(weights_data_ptr[i * D])); - initializers_[shard]->consumer_queue_.enqueue( - row_index); - } - } }); co_return; }) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 7cc75d9e9..61ac98d6e 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -94,15 +94,23 @@ def get_physical_table_arg_indices_(self, feature_table_map: List[int]): @given( weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), + indice_int64_t=st.sampled_from([True, False]), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_ssd(self, weights_precision: SparseType) -> None: + def test_ssd(self, indice_int64_t: bool, weights_precision: SparseType) -> None: import tempfile E = int(1e4) D = 128 N = 100 - indices = torch.as_tensor(np.random.choice(E, replace=False, size=(N,))) + if indice_int64_t: + indices = torch.as_tensor( + np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64 + ) + else: + indices = torch.as_tensor( + np.random.choice(E, replace=False, size=(N,)), dtype=torch.int32 + ) weights = torch.randn(N, D, dtype=weights_precision.as_dtype()) output_weights = torch.empty_like(weights) count = torch.tensor([N]) From 904a1c6de3f85c9e5693bb12fdb5190905136013 Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Thu, 19 Sep 2024 09:56:36 -0700 Subject: [PATCH 25/27] add lock for l2 cache set/get (#3153) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/248 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3153 In non pipelining mode, the sequence is - get_cuda(): L2 read and insert L2 cache misses into queue for bg L2 write - L1 cache eviction: insert into bg queue for L2 write - ScratchPad update: insert into bg queue for L2 write in non-prefetch pipeline, cuda synchronization guarantee get_cuda() happen after SP update in prefetch pipeline, cuda sync only guarantee get_cuda() happen after L1 cache eviction pipeline case, SP bwd update could happen in parallel with L2 read lock is used for l2 cache to do read / write exclusively add unittest to capture L2 cache functionality and the cases discussed above Reviewed By: q10 Differential Revision: D63010906 fbshipit-source-id: 3951ce138acb53da4f7aba01a03c46409a6fc630 --- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py | 20 +- .../kv_db_table_batched_embeddings.cpp | 32 ++- .../kv_db_table_batched_embeddings.h | 31 ++- .../ssd_split_table_batched_embeddings.cpp | 27 +- fbgemm_gpu/test/tbe/common.py | 10 +- fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py | 231 ++++++++++++++++++ 6 files changed, 337 insertions(+), 14 deletions(-) create mode 100644 fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index f0e16b840..e2d35072f 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -1850,8 +1850,8 @@ def _report_l2_cache_perf_stats(self) -> None: self.step, stats_reporter.report_interval # pyre-ignore ) - if len(l2_cache_perf_stats) != 13: - logging.error("l2 perf stats should have 13 elements") + if len(l2_cache_perf_stats) != 15: + logging.error("l2 perf stats should have 15 elements") return num_cache_misses = l2_cache_perf_stats[0] @@ -1869,6 +1869,9 @@ def _report_l2_cache_perf_stats(self) -> None: l2_cache_free_bytes = l2_cache_perf_stats[11] l2_cache_capacity = l2_cache_perf_stats[12] + set_cache_lock_wait_duration = l2_cache_perf_stats[13] + get_cache_lock_wait_duration = l2_cache_perf_stats[14] + stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.l2_num_cache_misses_stats_name, @@ -1944,6 +1947,19 @@ def _report_l2_cache_perf_stats(self) -> None: time_unit="us", ) + stats_reporter.report_duration( + iteration_step=self.step, + event_name="l2_cache.perf.get.cache_lock_wait_duration_us", + duration_ms=get_cache_lock_wait_duration, + time_unit="us", + ) + stats_reporter.report_duration( + iteration_step=self.step, + event_name="l2_cache.perf.set.cache_lock_wait_duration_us", + duration_ms=set_cache_lock_wait_duration, + time_unit="us", + ) + # pyre-ignore def _recording_to_timer( self, timer: Optional[AsyncSeriesTimer], **kwargs: Any diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 21aa00290..fdc91b0e9 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -193,7 +193,7 @@ void EmbeddingKVDB::set_cuda( std::vector EmbeddingKVDB::get_l2cache_perf( const int64_t step, const int64_t interval) { - std::vector ret(13, 0); // num metrics + std::vector ret(15, 0); // num metrics if (step > 0 && step % interval == 0) { int reset_val = 0; auto num_cache_misses = num_cache_misses_.exchange(reset_val); @@ -215,6 +215,12 @@ std::vector EmbeddingKVDB::get_l2cache_perf( get_tensor_copy_for_cache_update_.exchange(reset_val); auto set_tensor_copy_for_cache_update_dur = set_tensor_copy_for_cache_update_.exchange(reset_val); + + auto set_cache_lock_wait_duration = + set_cache_lock_wait_duration_.exchange(reset_val); + auto get_cache_lock_wait_duration = + get_cache_lock_wait_duration_.exchange(reset_val); + ret[0] = (double(num_cache_misses) / interval); ret[1] = (double(num_lookups) / interval); ret[2] = (double(get_total_duration) / interval); @@ -231,10 +237,16 @@ std::vector EmbeddingKVDB::get_l2cache_perf( ret[11] = (cache_mem_stats[0]); // free cache in bytes ret[12] = (cache_mem_stats[1]); // total cache capacity in bytes } + ret[13] = (double(set_cache_lock_wait_duration) / interval); + ret[14] = (double(get_cache_lock_wait_duration) / interval); } return ret; } +void EmbeddingKVDB::reset_l2_cache() { + l2_cache_ = nullptr; +} + void EmbeddingKVDB::set( const at::Tensor& indices, const at::Tensor& weights, @@ -264,7 +276,8 @@ void EmbeddingKVDB::set( void EmbeddingKVDB::get( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) { + const at::Tensor& count, + int64_t sleep_ms) { if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) { XLOG_EVERY_MS(INFO, 60000) << "[TBE_ID" << unique_id_ << "]skip get_cuda since number lookups is " @@ -274,6 +287,17 @@ void EmbeddingKVDB::get( ASSERT_EQ(max_D_, weights.size(1)); auto start_ts = facebook::WallClockUtil::NowInUsecFast(); wait_util_filling_work_done(); + + std::unique_lock lock(l2_cache_mtx_); + get_cache_lock_wait_duration_ += + facebook::WallClockUtil::NowInUsecFast() - start_ts; + + // this is for unittest to repro synchronization situation deterministically + if (sleep_ms > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + XLOG(INFO) << "get sleep end"; + } + auto cache_context = get_cache(indices, count); if (cache_context != nullptr) { if (cache_context->num_misses > 0) { @@ -407,12 +431,14 @@ EmbeddingKVDB::set_cache( if (l2_cache_ == nullptr) { return folly::none; } - // TODO: consider whether need to reconstruct indices/weights/count and free // the original tensor since most of the tensor elem will be invalid, // this will trade some perf for peak DRAM util saving auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast(); + std::unique_lock lock(l2_cache_mtx_); + set_cache_lock_wait_duration_ += + facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts; l2_cache_->init_tensor_for_l2_eviction(indices, weights, count); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index da98ba9c3..93495f2da 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -171,13 +171,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this { /// relative element in /// @param count A single element tensor that contains the number of indices /// to be processed + /// @param sleep_ms this is used to specifically sleep in get function, this + /// is needed to reproduce synchronization situation deterministicly, in prod + /// case this will be 0 for sure /// /// @return None /// @note weights will be updated from either L2 cache or storage tier void get( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count); + const at::Tensor& count, + int64_t sleep_ms = 0); /// storage tier counterpart of function get() virtual folly::coro::Task get_kv_db_async( @@ -227,6 +231,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const int64_t step, const int64_t interval); + // reset L2 cache, this is used for unittesting to bypass l2 cache + void reset_l2_cache(); + + // block waiting for working items in queue to be finished, this is called by + // get_cache() as embedding read should wait until previous write to be + // finished, it could also be called in unitest to sync + void wait_util_filling_work_done(); + private: /// Find non-negative embedding indices in and shard them into /// #cachelib_pools pieces to be lookedup in parallel @@ -283,10 +295,6 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const at::Tensor& indices, const at::Tensor& weights); - // waiting for working item queue to be empty, this is called by get_cache() - // as embedding read should wait until previous write to be finished - void wait_util_filling_work_done(); - std::unique_ptr l2_cache_; const int64_t unique_id_; const int64_t num_shards_; @@ -299,6 +307,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this { // buffer queue that stores all the needed indices/weights/action_count to // fill up cache folly::USPSCQueue weights_to_fill_queue_; + // In non pipelining mode, the sequence is + // - get_cuda(): L2 read and insert L2 cache misses into queue for + // bg L2 write + // - L1 cache eviction: insert into bg queue for L2 write + // - ScratchPad update: insert into bg queue for L2 write + // in non-prefetch pipeline, cuda synchronization guarantee get_cuda() happen + // after SP update + // in prefetch pipeline, cuda sync only guarantee get_cuda() happen after L1 + // cache eviction pipeline case, SP bwd update could happen in parallel with + // L2 read mutex is used for l2 cache to do read / write exclusively + std::mutex l2_cache_mtx_; // perf stats // -- perf of get() function @@ -313,9 +332,11 @@ class EmbeddingKVDB : public std::enable_shared_from_this { std::atomic get_weights_fillup_total_duration_{0}; std::atomic get_cache_memcpy_duration_{0}; std::atomic get_tensor_copy_for_cache_update_{0}; + std::atomic get_cache_lock_wait_duration_{0}; // -- perf of set() function std::atomic set_tensor_copy_for_cache_update_{0}; + std::atomic set_cache_lock_wait_duration_{0}; // -- commone path std::atomic total_cache_update_duration_{0}; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 1e91cefb5..6238fe1be 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -315,8 +315,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { return impl_->set(indices, weights, count); } - void get(Tensor indices, Tensor weights, Tensor count) { - return impl_->get(indices, weights, count); + void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) { + return impl_->get(indices, weights, count, sleep_ms); } std::vector get_mem_usage() { @@ -343,6 +343,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { return impl_->flush(); } + void reset_l2_cache() { + return impl_->reset_l2_cache(); + } + + void wait_util_filling_work_done() { + return impl_->wait_util_filling_work_done(); + } + private: // shared pointer since we use shared_from_this() in callbacks. std::shared_ptr impl_; @@ -413,7 +421,20 @@ static auto embedding_rocks_db_wrapper = &EmbeddingRocksDBWrapper::get_rocksdb_io_duration) .def("get_l2cache_perf", &EmbeddingRocksDBWrapper::get_l2cache_perf) .def("set", &EmbeddingRocksDBWrapper::set) - .def("get", &EmbeddingRocksDBWrapper::get); + .def( + "get", + &EmbeddingRocksDBWrapper::get, + "", + { + torch::arg("indices"), + torch::arg("weights"), + torch::arg("count"), + torch::arg("sleep_ms") = 0, + }) + .def("reset_l2_cache", &EmbeddingRocksDBWrapper::reset_l2_cache) + .def( + "wait_util_filling_work_done", + &EmbeddingRocksDBWrapper::wait_util_filling_work_done); TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( diff --git a/fbgemm_gpu/test/tbe/common.py b/fbgemm_gpu/test/tbe/common.py index df38e15de..40f1b49e7 100644 --- a/fbgemm_gpu/test/tbe/common.py +++ b/fbgemm_gpu/test/tbe/common.py @@ -17,8 +17,16 @@ # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) -if not open_source: +if open_source: + # pyre-ignore[21] + from test_utils import gpu_unavailable, running_on_github +else: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") + from fbgemm_gpu.test.test_utils import ( # noqa F401 + gpu_unavailable, + running_on_github, + ) + torch.ops.import_module("fbgemm_gpu.sparse_ops") settings.register_profile("derandomize", derandomize=True) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py new file mode 100644 index 000000000..ef5a9f0a2 --- /dev/null +++ b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[3,6,56] + +import tempfile + +import threading +import time +import unittest + +from typing import Any, Dict, List, Tuple + +import hypothesis.strategies as st +import numpy as np +import torch +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags +from fbgemm_gpu.tbe.utils import round_up +from hypothesis import given, settings, Verbosity + +from .. import common # noqa E402 +from ..common import gpu_unavailable, running_on_github + +MAX_EXAMPLES = 20 +default_st: Dict[str, Any] = { + "T": st.integers(min_value=1, max_value=10), + "D": st.integers(min_value=2, max_value=128), + "log_E": st.integers(min_value=2, max_value=3), + "mixed": st.booleans(), + "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]), +} + +default_settings: Dict[str, Any] = { + "verbosity": Verbosity.verbose, + "max_examples": MAX_EXAMPLES, + "deadline": None, +} + + +@unittest.skipIf(*running_on_github) +@unittest.skipIf(*gpu_unavailable) +class SSDCheckpointTest(unittest.TestCase): + def generate_fbgemm_ssd_tbe( + self, + T: int, + D: int, + log_E: int, + weights_precision: SparseType, + mixed: bool, + enable_l2: bool = True, + ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int], int]: + E = int(10**log_E) + D = D * 4 + if not mixed: + Ds = [D] * T + Es = [E] * T + else: + Ds = [ + round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) + for _ in range(T) + ] + Es = [ + np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) + ] + + feature_table_map = list(range(T)) + emb = SSDTableBatchedEmbeddingBags( + embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)], + feature_table_map=feature_table_map, + ssd_storage_directory=tempfile.mkdtemp(), + cache_sets=1, + ssd_uniform_init_lower=-0.1, + ssd_uniform_init_upper=0.1, + weights_precision=weights_precision, + l2_cache_size=1 if enable_l2 else 0, + ) + return emb, Es, Ds, max(Ds) + + # @given(**default_st, do_flush=st.sampled_from([True, False])) + # @settings(**default_settings) + # def test_l2_flush( + # self, + # T: int, + # D: int, + # log_E: int, + # mixed: bool, + # weights_precision: SparseType, + # do_flush: bool, + # ) -> None: + # emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( + # T, D, log_E, weights_precision, mixed + # ) + # indices = torch.arange(start=0, end=sum(Es)) + # weights = torch.randn( + # indices.numel(), max_D, dtype=weights_precision.as_dtype() + # ) + # weights_from_l2 = torch.empty_like(weights) + # count = torch.as_tensor([indices.numel()]) + # emb.ssd_db.set_cuda(indices, weights, count, 1) + # emb.ssd_db.get_cuda(indices.clone(), weights_from_l2, count) + + # torch.cuda.synchronize() + # assert torch.equal(weights, weights_from_l2) + # import logging + + # logging.info(f"wgqtest {do_flush=}") + # weights_from_ssd = torch.empty_like(weights) + # if do_flush: + # emb.ssd_db.flush() + # emb.ssd_db.reset_l2_cache() + # emb.ssd_db.get_cuda(indices, weights_from_ssd, count) + # torch.cuda.synchronize() + # if do_flush: + # assert torch.equal(weights, weights_from_ssd) + # else: + # assert not torch.equal(weights, weights_from_ssd) + + # @given(**default_st, enable_l2=st.sampled_from([True, False])) + # @settings(**default_settings) + # def test_l2_io( + # self, + # T: int, + # D: int, + # log_E: int, + # mixed: bool, + # weights_precision: SparseType, + # enable_l2: bool, + # ) -> None: + # emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( + # T, D, log_E, weights_precision, mixed, enable_l2 + # ) + # E = int(10**log_E) + # num_rounds = 10 + # N = E + # total_indices = torch.tensor([]) + + # indices = torch.as_tensor( + # np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64 + # ) + # weights = torch.randn( + # indices.numel(), max_D, dtype=weights_precision.as_dtype() + # ) + # sub_N = N // num_rounds + + # for _ in range(num_rounds): + # sub_indices = torch.as_tensor( + # np.random.choice(E, replace=False, size=(sub_N,)), dtype=torch.int64 + # ) + # sub_weights = weights[sub_indices, :] + # sub_weights_out = torch.empty_like(sub_weights) + # count = torch.as_tensor([sub_indices.numel()]) + # emb.ssd_db.set_cuda(sub_indices, sub_weights, count, 1) + # emb.ssd_db.get_cuda(sub_indices.clone(), sub_weights_out, count) + # torch.cuda.synchronize() + # assert torch.equal(sub_weights, sub_weights_out) + # total_indices = torch.cat((total_indices, sub_indices)) + # # dedup + # used_unique_indices = torch.tensor( + # list(set(total_indices.tolist())), dtype=torch.int64 + # ) + # stored_weights = weights[used_unique_indices, :] + # weights_out = torch.empty_like(stored_weights) + # count = torch.as_tensor([used_unique_indices.numel()]) + # emb.ssd_db.get_cuda(used_unique_indices.clone(), weights_out, count) + # torch.cuda.synchronize() + # assert torch.equal(stored_weights, weights_out) + + # emb.ssd_db.flush() + # emb.ssd_db.reset_l2_cache() + # weights_out = torch.empty_like(stored_weights) + # count = torch.as_tensor([used_unique_indices.numel()]) + # emb.ssd_db.get_cuda(used_unique_indices.clone(), weights_out, count) + # torch.cuda.synchronize() + # assert torch.equal(stored_weights, weights_out) + + @given(**default_st) + @settings(**default_settings) + def test_l2_prefetch_compatibility( + self, + T: int, + D: int, + log_E: int, + mixed: bool, + weights_precision: SparseType, + ) -> None: + weights_precision: SparseType = SparseType.FP32 + emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( + T, D, log_E, weights_precision, mixed + ) + E = int(10**log_E) + N = E + indices = torch.as_tensor( + np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64 + ) + weights = torch.randn(N, max_D, dtype=weights_precision.as_dtype()) + new_weights = weights + 1 + weights_out = torch.empty_like(weights) + count = torch.as_tensor([E]) + emb.ssd_db.set(indices, weights, count) + emb.ssd_db.wait_util_filling_work_done() + + event = threading.Event() + get_sleep_ms = 50 + + # pyre-ignore + def trigger_get() -> None: + event.set() + emb.ssd_db.get(indices.clone(), weights_out, count, get_sleep_ms) + + # pyre-ignore + def trigger_set() -> None: + event.wait() + time.sleep( + get_sleep_ms / 1000.0 / 2 + ) # sleep half of the sleep time in get, making sure set is trigger after get but before get is done + emb.ssd_db.set(indices, new_weights, count) + + thread1 = threading.Thread(target=trigger_get) + thread2 = threading.Thread(target=trigger_set) + thread1.start() + thread2.start() + thread1.join() + thread2.join() + assert torch.equal(weights, weights_out) + emb.ssd_db.get(indices.clone(), weights_out, count) + assert torch.equal(new_weights, weights_out) From 03773083e5a3f23c9d42f32be16e92734e287a9e Mon Sep 17 00:00:00 2001 From: Supadchaya Puangpontip Date: Thu, 19 Sep 2024 13:00:02 -0700 Subject: [PATCH 26/27] Add schema compatibility test (#3130) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/217 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3130 To ensure that changes to the ops are forward and backward compatible with the stable release, we add unit tests to test schema compatibility. **Usage**: ``` check_schema_compatibility_from_op_name( namespace: Callable, op_name: str ref_schema_str: str, ) check_schema_compatibility( op: Callable, ref_schema_str: str, ) ``` e.g., ``` check_schema_compatibility_from_op_name( torch.ops.fbgemm, "merge_pooled_embeddings", "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor" ) check_schema_compatibility( fbgemm_gpu.sparse_ops.merge_pooled_embeddings, "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", ) ``` Reviewed By: q10 Differential Revision: D61766648 fbshipit-source-id: dac52b88834331a466e7165812def1a3fe4c0804 --- fbgemm_gpu/test/release/__init__.py | 6 + fbgemm_gpu/test/release/example.json | 12 + fbgemm_gpu/test/release/stable_ops.json | 30 +++ .../test/release/stable_release_test.py | 186 +++++++++++++ fbgemm_gpu/test/release/utils.py | 245 ++++++++++++++++++ 5 files changed, 479 insertions(+) create mode 100644 fbgemm_gpu/test/release/__init__.py create mode 100644 fbgemm_gpu/test/release/example.json create mode 100644 fbgemm_gpu/test/release/stable_ops.json create mode 100755 fbgemm_gpu/test/release/stable_release_test.py create mode 100644 fbgemm_gpu/test/release/utils.py diff --git a/fbgemm_gpu/test/release/__init__.py b/fbgemm_gpu/test/release/__init__.py new file mode 100644 index 000000000..a9fdb3b99 --- /dev/null +++ b/fbgemm_gpu/test/release/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fbgemm_gpu/test/release/example.json b/fbgemm_gpu/test/release/example.json new file mode 100644 index 000000000..50324b9a3 --- /dev/null +++ b/fbgemm_gpu/test/release/example.json @@ -0,0 +1,12 @@ +{ + "_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "mx4_to_fp32": + "mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor", + "merge_pooled_embeddings": + "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", + "dummy_func": + "dummy_func(str var1, int var2) -> ()" + } +} diff --git a/fbgemm_gpu/test/release/stable_ops.json b/fbgemm_gpu/test/release/stable_ops.json new file mode 100644 index 000000000..d5fe76a24 --- /dev/null +++ b/fbgemm_gpu/test/release/stable_ops.json @@ -0,0 +1,30 @@ +{ + "_description": "This is a dict containing schema of FBGEMM_GPU ops that are marked as stable. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "torch.ops.fbgemm.jagged_to_padded_dense": + "fbgemm::jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor", + "torch.ops.fbgemm.merge_pooled_embeddings": + "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor", + "torch.ops.fbgemm.permute_pooled_embs_auto_grad": + "fbgemm::permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor", + "torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf": + "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor", + "torch.ops.fbgemm.permute_2D_sparse_data": + "fbgemm::permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + "torch.ops.fbgemm.permute_1D_sparse_data": + "fbgemm::permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + "torch.ops.fbgemm.expand_into_jagged_permute": + "fbgemm::expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor", + "torch.ops.fbgemm.block_bucketize_sparse_features": + "fbgemm::block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)", + "torch.ops.fbgemm.asynchronous_complete_cumsum": + "fbgemm::asynchronous_complete_cumsum(Tensor t_in) -> Tensor", + "torch.ops.fbgemm.offsets_range": + "fbgemm::offsets_range(Tensor offsets, SymInt range_size) -> Tensor", + "torch.ops.fbgemm.segment_sum_csr": + "fbgemm::segment_sum_csr(SymInt batch_size, Tensor csr_seg, Tensor values) -> Tensor", + "torch.ops.fbgemm.keyed_jagged_index_select_dim1": + "fbgemm::keyed_jagged_index_select_dim1(Tensor values, Tensor lengths, Tensor offsets, Tensor indices, SymInt batch_size, Tensor? weights=None, SymInt? selected_lengths_sum=None) -> Tensor[]" + } +} diff --git a/fbgemm_gpu/test/release/stable_release_test.py b/fbgemm_gpu/test/release/stable_release_test.py new file mode 100755 index 000000000..7055fd566 --- /dev/null +++ b/fbgemm_gpu/test/release/stable_release_test.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import json +import os +import unittest +from typing import Callable + +import fbgemm_gpu +import fbgemm_gpu.permute_pooled_embedding_modules +import fbgemm_gpu.sparse_ops + +import torch +from torch._C import FunctionSchema, parse_schema +from torch._utils_internal import get_file_path_2 + +from .utils import infer_schema + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. +open_source: bool = getattr(fbgemm_gpu, "open_source", False) + +if open_source: + from test_utils import TestSuite # pyre-fixme[21] + +else: + # pyre-fixme[21] + from fbgemm_gpu.test.test_utils import TestSuite + + +def _check_schema_compatibility( + schema: FunctionSchema, + ref_schema_str: str, +) -> None: + """ + Check if the schema is forward and backward compatible with the reference schema. + This function will raise an Exception error if the schema is not compatible. + + Args: + schema (FunctionSchema): The schema object to check. + ref_schema_str (str): The reference schema in string format. + Returns: + None + """ + assert isinstance(schema, FunctionSchema) + ref_schema = parse_schema(ref_schema_str) + # pyre-fixme[16] + fwd_compatible = schema.check_forward_compatible_with(ref_schema) + # pyre-fixme[16] + bwd_compatible = schema.is_backward_compatible_with(ref_schema) + msg = "" + if not fwd_compatible: + msg += f"Schema of {schema} is not forward compatible with {ref_schema}\n" + # pyre-fixme[16] + if not bwd_compatible: + msg += f"Schema of {schema} is not backward compatible with {ref_schema}" + assert fwd_compatible and bwd_compatible, msg + + +def check_schema_compatibility( + op: Callable, + ref_schema: str, +) -> None: + """ + Check if the schema of the given operator is forward and backward compatible with the reference schema. + This works with python functions whose schema do NOT have positional-only args, varargs, or varkwargs + For ops registered via torch.ops.fbgemm and ops with *args and **kwargs, please use check_schema_compatibility_from_op_name. + + Args: + op (Callable): The operator to check. + ref_schema (str): The reference schema in string format. + Returns: + None + """ + op_schema = infer_schema(op, mutates_args={}) + # pyre-fixme[16] + op_name = op.__name__ + # Create schema string + schema_str = f"{op_name}{op_schema}" + # Create FunctionalSchema + functional_schema = parse_schema(schema_str) + + # Get stable schema to compare against + return _check_schema_compatibility(functional_schema, ref_schema) + + +def check_schema_compatibility_from_op_name( + namespace: Callable, + op_name: str, + ref_schema_str: str, +) -> None: + """ + Check if the schema of the given operator is forward and backward compatible with the reference schema. + Use this function to check registered ops (via torch.ops.fbgemm). + This function will raise an Exception error if the schema is not compatible. + + Args: + namespace (Callable): The namespace of the operator e.g., torch.ops.fbgemm. + op_name (str): The name of the operator. + ref_schema_str (str): The reference schema in string format. + Returns: + None + """ + op = getattr(namespace, op_name) + schema = op._schemas[""] + + return _check_schema_compatibility(schema, ref_schema_str) + + +class StableRelease(TestSuite): # pyre-ignore[11] + def test_stable_schema(self) -> None: + """ + Test the schema compatibility of the operators against stable schema. + This is to ensure that any changes to the ops' schema do not break compatibility of the stable versions. + This test will fail if the current op schema is not forward or backward compatible with the stable schema. + """ + + # Load stable ops from file into dict + stable_dict_file = open( + get_file_path_2("", os.path.dirname(__file__), "stable_ops.json") + ) + stable_op_dict = json.load(stable_dict_file)["data"] + stable_dict_file.close() + # Get all op names + stable_op_names = set(stable_op_dict.keys()) + + # Check compatibility for all ops that are marked stable + for full_op_name in stable_op_names: + # Test the schema given the op name + ref_schema_str = stable_op_dict[full_op_name] + op_name = full_op_name.split(".")[3] + + check_schema_compatibility_from_op_name( + torch.ops.fbgemm, op_name, ref_schema_str + ) + + def test_example_ops(self) -> None: + """ + Test examples for schema compatibility. + """ + + # Load example ops to dict + stable_dict_file = open( + get_file_path_2("", os.path.dirname(__file__), "example.json") + ) + op_dict = json.load(stable_dict_file)["data"] + stable_dict_file.close() + + # Example op 1 + # Expect to pass + check_schema_compatibility( + fbgemm_gpu.sparse_ops.merge_pooled_embeddings, + op_dict["merge_pooled_embeddings"], + ) + + # Example op 2 + # stable schema is: dummy_func(str var1, int var2) -> ()" + def dummy_func(var1: str, var2: int, var3: torch.Tensor) -> None: + pass + + # Expect to fail + with self.assertRaises(AssertionError): # pyre-fixme[16] + check_schema_compatibility( + dummy_func, + op_dict["dummy_func"], + ) + + # Example op 3 + # stable schema is: dummy_func(str var1, int var2) -> ()" + def dummy_func(var1: str, var2: int, var3: str = "default") -> None: + pass + + # Expect to pass + check_schema_compatibility( + dummy_func, + op_dict["dummy_func"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py new file mode 100644 index 000000000..005cf38b5 --- /dev/null +++ b/fbgemm_gpu/test/release/utils.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import typing +from typing import Iterable, List, Optional, Sequence, Union # noqa: F401 + +import torch +from torch import device, dtype, Tensor, types + +from torch._library.infer_schema import ( + derived_types, + parse_return, + supported_param, + SUPPORTED_PARAM_TYPES, + tuple_to_list, +) + +# Temporary work around for `infer_schema` + +# `get_supported_param_types` and `SUPPORTED_RETURN_TYPES` are modified from torch/_library/infer_schema.py +# as `torch.library.infer_schema` infers any `int` to be `SymInt` in the schema and does not +# support `str` as return type, which may not reflect the actual signature of the function. +# Other modifications are to address linter warning. +# The rest of the code is copied from `torch/_library/infer_schema.py` +# TO DO: clean up and remove this when we implement our own + + +def error_fn(what: str, sig: Optional[inspect.Signature] = None): + raise ValueError(f"infer_schema(func): {what} " f"Got func with signature {sig})") + + +def convert_type_string(annotation_type: str): + try: + return eval(annotation_type) + except Exception: + error_fn(f"Unsupported type annotation {annotation_type}. It is not a type.") + + +# Modified support param types and return types from torch/_library/infer_schema.py +def get_supported_param_types(): + data = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? variant + (Tensor, "Tensor", True, True, False), + (int, "int", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (types.Number, "Scalar", True, False, False), + (dtype, "ScalarType", False, False, False), + (device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + Tensor: "Tensor", + typing.List[Tensor]: "Tensor[]", + int: "int", + float: "float", + bool: "bool", + str: "str", + types.Number: "Scalar", +} + + +def check_param_annotation(name: str, annotation: type, sig: inspect.Signature): + if annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.", sig) + + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type = annotation + if isinstance(annotation_type, str): + annotation_type = convert_type_string(annotation_type) + + if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): + if annotation_type.__origin__ is tuple: + list_type = tuple_to_list(annotation_type) + example_type_str = "\n\n" + # Only suggest the list type if this type is supported. + if list_type in SUPPORTED_PARAM_TYPES.keys(): + example_type_str = f"For example, {list_type}.\n\n" + error_fn( + f"Parameter {name} has unsupported type {annotation}. " + f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " + f"{example_type_str}" + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}.", + sig, + ) + else: + error_fn( + f"Parameter {name} has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}.", + sig, + ) + return annotation_type + + +def get_schema_type( + schema_type: str, + mutates_args: Union[str, Iterable[str]], + name: str, + sig: inspect.Signature, + idx: int, +): + if isinstance(mutates_args, str): + if mutates_args != "unknown": + raise ValueError( + "mutates_args must either be a sequence of the names of " + "the arguments that are mutated or the string 'unknown'. " + ) + if schema_type.startswith("Tensor"): + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + elif name in mutates_args: + if not schema_type.startswith("Tensor"): + error_fn( + f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" + ) + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + return schema_type + + +def check_mutates_args( + mutates_args: Union[str, Iterable[str]], sig: inspect.Signature, seen_args: set +): + if mutates_args != "unknown": + mutates_args_not_seen = set(mutates_args) - seen_args + if len(mutates_args_not_seen) > 0: + error_fn( + f"{mutates_args_not_seen} in mutates_args were not found in " + f"the custom op's signature. " + f"mutates_args should contain the names of all args that the " + f"custom op mutates, or just the string 'unknown' if you don't know.", + sig, + ) + + +def get_return_annonation( + return_annotation: type, +): + if isinstance(return_annotation, str): + return_annotation = convert_type_string(return_annotation) + return parse_return(return_annotation, error_fn) + + +def infer_schema( + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, +) -> str: + r""" + This is modified from torch._library.infer_schema.infer_schema. + + Parses the schema of a given function with type hints. The schema is inferred from the + function's type hints, and can be used to define a new operator. + + We make the following assumptions: + + * None of the outputs alias any of the inputs or each other. + * | String type annotations "device, dtype, Tensor, types" without library specification are + | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + | without library specification are assumed to be typing.*. + * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown", + | it assumes that all inputs to the operator are being mutates. + + Callers (e.g. the custom ops API) are responsible for checking these assumptions. + + Args: + prototype_function: The function from which to infer a schema for from its type annotations. + op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the + name is not included in the inferred schema. Note that the input schema to + ``torch.library.Library.define`` requires a operator name. + mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. + + Returns: + The inferred schema. + + Example: + >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: + >>> return x.sin() + >>> + >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) + foo(Tensor x) -> Tensor + >>> + >>> infer_schema(foo_impl, mutates_args={}) + (Tensor x) -> Tensor + """ + sig = inspect.signature(prototype_function) + + params = [] + seen_args = set() + saw_kwarg_only_arg = False + for idx, (name, param) in enumerate(sig.parameters.items()): + if not supported_param(param): + error_fn( + "We do not support positional-only args, varargs, or varkwargs.", sig + ) + + if param.kind == inspect.Parameter.KEYWORD_ONLY: + # The first time we see a kwarg-only arg, add "*" to the schema. + if not saw_kwarg_only_arg: + params.append("*") + saw_kwarg_only_arg = True + + annotation_type = check_param_annotation(name, param.annotation, sig) + + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] + schema_type = get_schema_type(schema_type, mutates_args, name, sig, idx) + + seen_args.add(name) + if param.default is inspect.Parameter.empty: + params.append(f"{schema_type} {name}") + else: + default_repr = None + if param.default is None or isinstance(param.default, (int, float, bool)): + default_repr = str(param.default) + elif isinstance(param.default, (str, torch.device)): + default_repr = f'"{param.default}"' + elif isinstance(param.default, torch.dtype): + dtype_repr = str(param.default) + torch_dot = "torch." + assert dtype_repr.startswith(torch_dot) + default_repr = dtype_repr[len(torch_dot) :] + else: + error_fn( + f"Parameter {name} has an unsupported default value type {type(param.default)}. " + f"Please file an issue on GitHub so we can prioritize this.", + sig, + ) + params.append(f"{schema_type} {name}={default_repr}") + check_mutates_args(mutates_args, sig, seen_args) + + ret = get_return_annonation(sig.return_annotation) + if op_name is not None: + return f"{op_name}({', '.join(params)}) -> {ret}" + return f"({', '.join(params)}) -> {ret}" From ebbebd4b0e30e91d2fddbeda76b5c45f5e6c88ec Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Thu, 19 Sep 2024 13:36:05 -0700 Subject: [PATCH 27/27] reserve a method during torchscript (#3152) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3152 X-link: https://github.com/facebookresearch/FBGEMM/pull/247 This method is useful to help recompute buffers for torchscripted model in pyhon. Reviewed By: seanx92 Differential Revision: D63000116 fbshipit-source-id: 43ad420b22ac2e06b59e2189dada1a0b30befcad --- .../fbgemm_gpu/split_embedding_configs.py | 24 +++++++++++-------- ..._table_batched_embeddings_ops_inference.py | 13 ++++++---- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py index 365aebbfc..afb931bb3 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py @@ -68,6 +68,19 @@ def get(self, name: str) -> int: return self.config[name] +def sparse_type_to_int(sparse_type: "SparseType") -> int: + return { + SparseType.FP32.value: 0, + SparseType.FP16.value: 1, + SparseType.INT8.value: 2, + SparseType.INT4.value: 3, + SparseType.INT2.value: 4, + SparseType.BF16.value: 5, + SparseType.FP8.value: 6, + SparseType.MX4.value: 7, + }[sparse_type.value] + + @enum.unique class SparseType(enum.Enum): FP32 = "fp32" @@ -104,16 +117,7 @@ def from_int(ty: int) -> "SparseType": raise ValueError(f"Unsupported sparse type: {ty}") def as_int(self) -> int: - return { - SparseType.FP32.value: 0, - SparseType.FP16.value: 1, - SparseType.INT8.value: 2, - SparseType.INT4.value: 3, - SparseType.INT2.value: 4, - SparseType.BF16.value: 5, - SparseType.FP8.value: 6, - SparseType.MX4.value: 7, - }[self.value] + return sparse_type_to_int(self) @staticmethod def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType": diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 721e4f248..d988563ae 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -17,7 +17,7 @@ import torch # usort:skip from torch import nn, Tensor # usort:skip -from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, CacheAlgorithm, @@ -943,6 +943,7 @@ def reset_embedding_spec_location( for spec in self.embedding_specs ] + @torch.jit.export def recompute_module_buffers(self) -> None: """ Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets(). @@ -955,7 +956,7 @@ def recompute_module_buffers(self) -> None: ): return - weights_tys_int = [e[3].as_int() for e in self.embedding_specs] + weights_tys_int = [sparse_type_to_int(e[3]) for e in self.embedding_specs] self.weights_tys = torch.tensor( [weights_tys_int[t] for t in self.feature_table_map], device=self.current_device, @@ -968,8 +969,9 @@ def recompute_module_buffers(self) -> None: dtype=torch.int64, ) dims = [e[2] for e in self.embedding_specs] - D_offsets_list = [dims[t] for t in self.feature_table_map] - D_offsets_list = [0] + list(accumulate(D_offsets_list)) + D_offsets_list = [0] + for t in self.feature_table_map: + D_offsets_list.append(dims[t] + D_offsets_list[-1]) self.D_offsets = torch.tensor( D_offsets_list, device=self.current_device, dtype=torch.int32 ) @@ -999,6 +1001,9 @@ def recompute_module_buffers(self) -> None: self.table_wise_cache_miss = torch.empty_like( self.table_wise_cache_miss, device=self.current_device ) + self.weights_uvm = torch.empty_like( + self.weights_uvm, device=self.current_device + ) def _apply_split( self,