diff --git a/.github/workflows/fbgemm_gpu_pip.yml b/.github/workflows/fbgemm_gpu_pip.yml index 20cac9dc9..33125145f 100644 --- a/.github/workflows/fbgemm_gpu_pip.yml +++ b/.github/workflows/fbgemm_gpu_pip.yml @@ -160,7 +160,7 @@ jobs: run: . $PRELUDE; install_fbgemm_gpu_pip $BUILD_ENV ${{ github.event.inputs.fbgemm_gpu_channel_version || 'nightly' }} cuda/${{ matrix.cuda-version }} - name: Test with PyTest - timeout-minutes: 20 + timeout-minutes: 40 run: . $PRELUDE; test_all_fbgemm_gpu_modules $BUILD_ENV diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index 592f11496..c919199ee 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -14,7 +14,8 @@ import torch from torch import Tensor -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index 51c231ad0..acbe22fb2 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -16,7 +16,8 @@ import torch from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 2c0b62664..95ce71d27 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -32,6 +32,9 @@ # pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`. from torch.profiler import profile, ProfilerActivity +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index b4e596f4d..9ffbd9911 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -22,8 +22,8 @@ # pyre-ignore[21] from torch.profiler import profile, ProfilerActivity - -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index fdd051909..2ef9abe8f 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -20,7 +20,8 @@ from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py index 432ef3f4d..d3169ca81 100644 --- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py +++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py @@ -26,7 +26,8 @@ from torch import nn, Tensor -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index a809b5e9b..177c79508 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -48,6 +48,9 @@ from torch import Tensor from torch.profiler import profile +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + haveAIBench = False try: from aibench_observer.utils.observer import emitMetric diff --git a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py index 430087c4a..25540c190 100644 --- a/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/ssd_table_batched_embeddings_benchmark.py @@ -40,14 +40,13 @@ from torch.autograd.profiler import record_function from torch.profiler import profile -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) load_torch_module( "//deeplearning/fbgemm/fbgemm_gpu:ssd_split_table_batched_embeddings", ) -logging.basicConfig(level=logging.DEBUG) - @click.group() def cli() -> None: diff --git a/fbgemm_gpu/bench/stride_gemm_benchmark.py b/fbgemm_gpu/bench/stride_gemm_benchmark.py index 3c70d734f..2609f7fbf 100644 --- a/fbgemm_gpu/bench/stride_gemm_benchmark.py +++ b/fbgemm_gpu/bench/stride_gemm_benchmark.py @@ -13,7 +13,8 @@ import torch from fbgemm_gpu.bench.bench_utils import benchmark_torch_function -logging.basicConfig(level=logging.DEBUG) +logger: logging.Logger = logging.getLogger() +logger.setLevel(logging.DEBUG) try: # pyre-ignore[21] diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index c7d77e21a..afdcb8b3c 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -163,38 +163,38 @@ def generate_backward_split_gpu(**kwargs: Any) -> None: # has_gpu_support=True if not kwargs.get("dense"): # Generate CUDA autograd - template_filepath = ( - "training/backward/embedding_backward_split_host_template.cpp" - ) + for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]: + template_filepath = ( + "training/backward/embedding_backward_split_host_template.cpp" + ) desc = "ssd" if ssd else "split" + sdesc = "_ssd" if ssd else "" filename = f"gen_embedding_backward_{desc}_{optimizer}.cpp" CodeTemplate.load(template_filepath).write( filename, is_forward=False, ssd=ssd, **kwargs ) - # Generate PT2 unified autograd, and PT2 backward wrapper - for template_filepath, filename in [ - ( - "training/pt2/embedding_split_host_pt2_autograd_template.cpp", - f"gen_embedding_split_{optimizer}_pt2_autograd.cpp", - ), - ( - "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", - f"gen_embedding_backward_split_{optimizer}_pt2_cuda_wrapper.cpp", - ), - ]: - 
CodeTemplate.load(template_filepath).write( - filename, is_forward=False, **kwargs - ) - - if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): - # Generates Python invoker for CUDA + CPU, and PT2 - template = CodeTemplate.load( - "training/python/split_embedding_codegen_lookup_invoker.template" - ) - for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]: - sdesc = "_ssd" if ssd else "" + # Generate PT2 unified autograd, and PT2 backward wrapper for all optimizers + for template_filepath, filename in [ + ( + "training/pt2/embedding_split_host_pt2_autograd_template.cpp", + f"gen_embedding_{desc}_{optimizer}_pt2_autograd.cpp", + ), + ( + "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", + f"gen_embedding_backward_{desc}_{optimizer}_pt2_cuda_wrapper.cpp", + ), + ]: + CodeTemplate.load(template_filepath).write( + filename, is_forward=False, ssd=ssd, **kwargs + ) + + if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"): + # Generates Python invoker for CUDA + CPU, and PT2 + template = CodeTemplate.load( + "training/python/split_embedding_codegen_lookup_invoker.template" + ) for filename in [ f"lookup_{optimizer}{sdesc}.py", f"lookup_{optimizer}{sdesc}_pt2.py", diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index 8844712b0..285cf9a55 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -85,6 +85,17 @@ def generate_pt2_wrappers() -> None: is_forward=True, ) + # Generate PT2 forward wrapper (CUDA) + CodeTemplate.load( + "training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp", + ).write( + f"gen_embedding_forward_ssd_pt2_cuda_wrapper.cpp", + has_gpu_support=True, + is_forward=True, + has_vbe_support=True, + ssd=True, + ) + @staticmethod def generate_small_kernels() -> None: # Generate the small kernels (for nobag only) for the forward splits diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py index bd10fa8d6..dab0bc5fa 100644 --- a/fbgemm_gpu/codegen/genscript/optimizer_args.py +++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py @@ -27,6 +27,23 @@ from torch_type_utils import arg_type_to_tensor_type, ArgType, TensorType +###################################################################### +# Optimizer Args Set Item +###################################################################### + + +@dataclass +class OptimizerArgsSetItem: + ty: ArgType # type + name: str + default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL + ph_tys: Optional[List[ArgType]] = None # placeholder types + + +# Alias b/c the name is too long +OptimItem = OptimizerArgsSetItem + + ###################################################################### ## Helper functions for the code generator script ## ###################################################################### @@ -117,6 +134,11 @@ def int_tensor_arg(name: str, gpu: bool = True, pass_by_ref: bool = False) -> st return _arg("int32_t", name, gpu=gpu, pass_by_ref=pass_by_ref) +def tensor_list_arg_no_default(name: str, pass_by_ref: bool) -> str: + ref = "&" if pass_by_ref else "" + return f"at::TensorList{ref} {name}" + + def tensor_arg(name: str) -> str: return f"Tensor {name}" @@ -165,6 +187,10 @@ def schema_sym_int_arg_no_default(name: str) -> str: return f"SymInt {name}" +def schema_tensor_list_arg_no_default(name: str) -> str: + return f"Tensor[] {name}" + + def make_kernel_arg( # 
pyre-fixme[11]: Annotation `ArgType` is not defined as a type. ty: ArgType, @@ -282,21 +308,69 @@ def make_ivalue_cast(ty: ArgType) -> str: }[ty] -###################################################################### -# Optimizer Args Set Item -###################################################################### - - @dataclass -class OptimizerArgsSetItem: - ty: ArgType # type - name: str - default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL - ph_tys: Optional[List[ArgType]] = None # placeholder types - +class PT2ArgsSet: + split_function_args: List[str] + split_function_arg_names: List[str] + split_function_schemas: List[str] + split_saved_tensor_list: List[str] -# Alias b/c the name is too long -OptimItem = OptimizerArgsSetItem + @staticmethod + # pyre-ignore[3] + def create( + split_arg_spec: List[OptimItem], + ): + """ + PT2ArgsSet.create() is a method that creates different formats given the optimization arguments + to be used in TBE codegen PT2 templates. + + Parameters: + split_arg_spec: List[OptimItem] - list of argument specs + + Returns: + PT2ArgsSet object with the following attributes: + split_function_args: List[str] - List of function arguments + e.g., ['at::TensorList momentum1', 'double eps', 'double weight_decay']. + split_function_arg_names: List[str] - List of argument names + e.g., ['momentum1', 'eps', 'weight_decay']. + split_function_schemas: List[str] - List of arguments in the schema format + e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay']. + split_saved_tensor_list: List[str] - List of saved tensors for the split function + e.g., ['momentum1']. + """ + split_function_arg_names = [] + split_function_args = [] + split_function_schemas = [] + split_saved_tensor_list = [] + for s in split_arg_spec: + if s.ty in ( + ArgType.TENSOR, + ArgType.INT_TENSOR, + ArgType.LONG_TENSOR, + ArgType.PLACEHOLDER_TENSOR, + ): + name = s.name.rsplit("_", 1)[0] + if name not in split_function_arg_names: + split_function_arg_names.append(name) + split_saved_tensor_list.append(name) + split_function_args.append( + tensor_list_arg_no_default(name, pass_by_ref=False) + ) + split_function_schemas.append( + schema_tensor_list_arg_no_default(name) + ) + else: + split_function_arg_names.append(s.name) + split_function_args.append(make_function_arg(s.ty, s.name, s.default)) + split_function_schemas.append( + make_function_schema_arg(s.ty, s.name, s.default) + ) + return PT2ArgsSet( + split_function_args=split_function_args, + split_function_arg_names=split_function_arg_names, + split_function_schemas=split_function_schemas, + split_saved_tensor_list=split_saved_tensor_list, + ) ###################################################################### @@ -324,6 +398,7 @@ class OptimizerArgs: placeholder_tensor_names: List[str] # pyre-fixme[11]: Annotation `TensorType` is not defined as a type. 
placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]] + unified_pt2: PT2ArgsSet @staticmethod # pyre-ignore[3] @@ -419,6 +494,7 @@ def create( ], placeholder_tensor_names=ph_tensor_names, placeholder_type_combos=ph_combos, + unified_pt2=PT2ArgsSet.create(split_arg_spec), ) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index e799120a6..41fd137dd 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -103,6 +103,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( std::optional max_float8_D, std::optional fp8_exponent_bits, std::optional fp8_exponent_bias) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } if (static_cast(pooling_mode) == PoolingMode::NONE) { std::vector max_D_list{ max_int2_D, diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index e7b908cdd..bc4e7ba74 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -62,7 +62,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no {%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name_{{ emb_weight_type }} = func_name_{{ emb_weight_type }}; + const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}"; #endif #ifdef X diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 40245fba2..b623b92d0 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -36,22 +36,468 @@ #include #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/config/feature_gates.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; +{#/* Module description */#} +{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %} +{%- set bwd_mdesc = "ssd" if ssd else "split" %} + +{%- if ssd %} +enum SSDTensor { + {%- for tensor in ssd_tensors %} + {{ tensor | upper }} = {{ loop.index - 1 }}, + {%- endfor %} +}; +{%- endif %} + +//////////////////////////////////////////////////////////////////////////////// +// Macro Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +// TO DO: Refactor +{# +/* This macro generates a code blob for dispatching corresponding weighted and + unweighted forward op from via Pytorch dispatcher +*/ +#} +{%- macro call_forward_op_dispatch(nobag, weighted, vbe, is_gwd) %} + {%- set forward_op = "{}_embedding{}_codegen_forward_{}{}{}_pt2_wrapper".format( + fwd_mdesc, + "_nobag" if nobag else "", + "weighted" if weighted else "unweighted", + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", + ) + %} + {%- set has_experimental = has_experimental_support( + dense, nobag, vbe, is_index_select=False, 
is_rocm=is_rocm, ssd=ssd + ) and not is_gwd + %} + static auto embedding_codegen_forward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::{{ forward_op }}", "") + .typed(); + + return { + embedding_codegen_forward_op.call( + weights_host, + flatten_weights_dev, + weights_uvm, + lxu_cache_weights, + weights_placements, + weights_offsets, + {%- if nobag %} + D, + {%- else %} + D_offsets, + total_D, + max_D, + {%- endif %} + hash_size_cumsum, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + indice_weights_value, + {%- endif %} {# /* if not nobag */ #} + {%- if not dense %} + {{ "ssd_tensors[SSDTensor::ROW_ADDRS]" if ssd else "lxu_cache_locations" }}, + uvm_cache_stats_, + {%- endif %} + {%- if not nobag %} + {%- if vbe %} + vbe_row_output_offsets, + vbe_b_t_map, + vbe_output_size, + info_B_num_bits, + info_B_mask_int64, + {%- endif %} {# /* if vbe */ #} + {%- if is_gwd %} + prev_iter_dev_, + learning_rate, + weight_decay, + iter, + gwd_lower_bound, + {%- endif %} {# /* if is_gwd */ #} + {%- endif %} {# /* if not nobag */ #} + is_experimental, + output_dtype + ) + }; +{%- endmacro %} + +/* This macro generates a code blob for dispatching corresponding weighted and + unweighted backward op via Pytorch dispatcher +*/ +{%- macro call_backward_op_dispatch(nobag, weighted, vbe, is_gwd) %} + {%- set wdesc = "_weighted" if weighted else "_unweighted" %} + {%- set backward_op = "{}_embedding{}_backward_codegen_{}{}{}{}_pt2_wrapper".format( + bwd_mdesc, + "_nobag" if nobag else "", + optimizer, + wdesc, + "_vbe" if vbe else "", + "_gwd" if is_gwd else "", + ) + %} + static auto embedding_codegen{{ wdesc }}_backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") + .typed(); + + grad_weights_dev = embedding_codegen{{ wdesc }}_backward_op.call( + grad_output, + weights_host, + weights_dev, + weights_uvm, + lxu_cache_weights, + weights_placements, + weights_offsets, + {% if nobag %} + D, + {%- else %} + D_offsets, + max_D, + {%- endif %} {# /* if nobag */ #} + hash_size_cumsum, + total_hash_size_bits, + indices, + offsets, + {%- if not nobag %} + pooling_mode, + indice_weights, + {%- endif %} {# /* if not nobag */ #} + {%- if ssd %} + ssd_row_addrs, + {%- else %} + lxu_cache_locations, + {%- endif %} + BT_block_size, + max_segment_length_per_warp, + {%- if optimizer != "none" %} + stochastic_rounding, + {%- endif %} + info_B_num_bits, + info_B_mask_int64, + {%- if vbe %} + B_offsets, + vbe_row_output_offsets, + vbe_b_t_map, + {%- endif %} {# /* if vbe */ #} + {%- if not dense %} + use_uniq_cache_locations_bwd, + use_homogeneous_placements, + {%- endif %} + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + prev_iter_dev, + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + iter, + {%- endif %} + gwd_lower_bound, + {%- endif %} {# /* if is_gwd */ #} + {{ args_pt2.split_function_arg_names | join(", ") }} + {%- if not nobag %} + , output_dtype + {%- endif %} + ); + return { + {%- if not dense %} + Tensor(), // placeholder autograd tensor + {%- endif %} + Variable(), // output_dtype + Variable(), // weights_host + grad_weights_dev, // weights_dev + {%- if not dense %} + Variable(), // weights_uvm + Variable(), // lxu_cache_weights + Variable(), // weights_placements + {%- endif %} + Variable(), // weights_offsets + {%- if nobag %} + Variable(), // D + {%- else %} + Variable(), // D_offsets + Variable(), // total_D + Variable(), // max_D + {%- endif %} + Variable(), // 
hash_size_cumsum + Variable(), //total_hash_size_bits + Variable(), // indices + Variable(), // offsets + {%- if not nobag %} + Variable(), // pooling_mode + grad_indice_weights, // indice_weights + Variable(), // feature_requires_grad + {%- endif %} + {%- if not dense %} + Variable(), // lxu_cache_locations + Variable(), // uvm_cache_stats + {%- endif %} + {%- if optimizer != "none" and not dense %} + Variable(), // gradient_clipping + Variable(), // max_gradient + Variable(), // stochastic_rounding + {%- endif %} + {%- if vbe %} + Variable(), // B_offsets + Variable(), // vbe_output_offsets_feature_rank + Variable(), // vbe_B_offsets_rank_per_feature + Variable(), // max_B + Variable(), // max_B_feature_rank + Variable(), // vbe_output_size + {%- endif %} + {%- if not dense %} + Variable(), // is_experimental + Variable(), // use_uniq_cache_locations_bwd + Variable(), // use_homogeneous_placements + {%- endif %} + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + Variable(), // prev_iter_dev + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + Variable(), // iter + {%- endif %} + Variable(), // gwd_lower_bound + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + Variable(), // {{ tensor }} + {%- endfor %} + {%- endif %} + {{ args_pt2.split_variables | join(", ") }} + }; +{%- endmacro %} + +/* This macro generates a code blob that calls corresponding autograd function + from lookup_function +*/ +{%- macro call_autograd(nobag, vbe, is_gwd) %} + {%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op_pt2".format( + "SSD" if ssd else "Split", + "NoBag" if nobag else "", + "VBE" if vbe else "", + "GWD" if is_gwd else "", + optimizer + ) + %} + return {{ autograd_func }}::apply( + {%- if not dense %} + placeholder_autograd_tensor, + {%- endif %} + output_dtype, + weights, + {%- if not dense %} + lxu_cache_weights, + {%- endif %} + {%- if nobag %} + max_D, + {%- else %} + D_offsets, + total_D, + max_D, + {%- endif %} + hash_size_cumsum, + total_hash_size_bits, + indices, + {%- if not nobag and dense and not vbe %} + offsets, + pooling_mode, + indice_weights, + feature_requires_grad + {%- elif not nobag %} + offsets, + pooling_mode, + indice_weights, + feature_requires_grad, + {%- elif nobag and dense and not vbe %} + offsets + {%- else %} + offsets, + {%- endif %} + {%- if not dense %} + lxu_cache_locations, + uvm_cache_stats, + {%- endif %} + {%- if optimizer != "none" and not dense %} + gradient_clipping, + max_gradient, + stochastic_rounding, + {%- endif %} + {%- if vbe %} + B_offsets, + vbe_output_offsets_feature_rank, + vbe_B_offsets_rank_per_feature, + max_B, + max_B_feature_rank, + vbe_output_size, + {%- endif %} + {%- if not dense %} + is_experimental, + use_uniq_cache_locations_bwd, + use_homogeneous_placements, + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + prev_iter_dev, + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + iter, + {%- endif %} + gwd_lower_bound, + {%- endif %} + {%- if ssd %} + ssd_tensors.value(), + {%- endif %} + {{ args_pt2.unified_pt2.split_function_arg_names | join(", ") }} + {%- endif %} + )[0]; +{%- endmacro %} + +/* This macro generates a code blob for unpacking the tensor list +*/ +{%- macro unpack_tensor_list(tensor_list) %} + const Tensor {{ tensor_list }}_host = {{ tensor_list }}[0]; + const Tensor {{ tensor_list }}_dev = {{ tensor_list }}[1]; + const Tensor {{ tensor_list }}_uvm = {{ tensor_list }}[2]; + const 
Tensor {{ tensor_list }}_placements = {{ tensor_list }}[3]; + const Tensor {{ tensor_list }}_offsets = {{ tensor_list }}[4]; +{%- endmacro %} + + +//////////////////////////////////////////////////////////////////////////////// +// Autograd Function Declarations +//////////////////////////////////////////////////////////////////////////////// + {%- if has_gpu_support or has_cpu_support %} {%- for vbe in ([True, False] if has_vbe_support else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} -{%- for nobag in [True, False] %} -{%- if not nobag or not vbe %} {#-/* nobag does not support vbe */#} -{%- set autograd_func = "Split{}{}LookupFunction_{}_Op_pt2".format( - "NoBag" if nobag else "", - "VBE" if vbe else "", - optimizer +{%- for nobag in ([False] if (weighted or vbe) else [True, False]) %} +{%- set ndesc = "_nobag" if nobag else "" %} + +{%- for is_gwd in ([True, False] + if is_valid_gwd_config( + dense, + nobag, + vbe, + is_index_select, + has_global_weight_decay_support, + ssd) + else [False]) %} +{%- set gwddesc = "_gwd" if is_gwd else "" %} + +{%- set autograd_func = "{}{}{}{}LookupFunction_{}_Op_pt2".format( + "SSD" if ssd else "Split", + "NoBag" if nobag else "", + "VBE" if vbe else "", + "GWD" if is_gwd else "", + optimizer, ) %} @@ -68,18 +514,14 @@ class {{ autograd_func }} : torch::autograd::AutogradContext* ctx, const Tensor& placeholder_autograd_tensor, const int64_t output_dtype, - const Tensor& host_weights, - const Tensor& dev_weights, - const Tensor& uvm_weights, + const at::TensorList weights, const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, {%- if not nobag %} const Tensor& D_offsets, - const int64_t total_D, - const int64_t max_D, + const c10::SymInt total_D, + const c10::SymInt max_D, {%- else %} - const int64_t D, + const c10::SymInt D, {%- endif %} const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, @@ -108,7 +550,25 @@ class {{ autograd_func }} : const bool is_experimental, const bool use_uniq_cache_locations_bwd, const bool use_homogeneous_placements, - {{ args_pt2.split_function_args | join(", ") }}) { + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + const std::optional& prev_iter_dev, + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + const int64_t iter, + {%- endif %} + const double gwd_lower_bound, + {%- endif %} + {%- if ssd %} + const at::TensorList& ssd_tensors, + {%- endif %} + {{ args_pt2.unified_pt2.split_function_args | join(", ") }}) { + + // unpack Tensor lists + {{ unpack_tensor_list("weights") }} + {%- for arg_name in args_pt2.unified_pt2.split_saved_tensor_list %} + {{ unpack_tensor_list(arg_name) }} + {%- endfor %} const auto T = weights_offsets.sym_numel(); {%- if vbe %} @@ -124,7 +584,7 @@ class {{ autograd_func }} : // NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t // TODO: Hook up with frontend code const auto uvm_cache_stats_ = uvm_cache_stats - .value_or(at::empty({0}, uvm_weights.options().dtype(at::kInt))); + .value_or(at::empty({0}, weights_uvm.options().dtype(at::kInt))); // Default values for Dynamo tracing // SymInt does not support bitshifts operator @@ -147,7 +607,16 @@ class {{ autograd_func }} : static auto generate_vbe_metadata_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::generate_vbe_metadata", "") - .typed(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const c10::SymInt, const bool, const c10::SymInt, const int64_t, const 
c10::SymInt)>(); + .typed( + const Tensor&, + const Tensor&, + const Tensor&, + const Tensor&, + const int64_t, + const bool, + const c10::SymInt, + const int64_t, + const c10::SymInt)>(); auto [ vbe_row_output_offsets, @@ -168,18 +637,22 @@ class {{ autograd_func }} : {%- endif %} max_B_feature_rank, info_B_num_bits, - /*total_B=*/offsets.size(0) - 1 + /*total_B=*/offsets.sym_size(0) - 1 ); {%- endif %} // vbe + {%- if is_gwd %} + const auto prev_iter_dev_ = prev_iter_dev.value_or(Tensor()); + {%- endif %} + {%- if not nobag %} const auto indice_weights_value = indice_weights.value_or(Tensor()); {%- endif %} ctx->save_for_backward({ - host_weights, - dev_weights, - uvm_weights, + weights_host, + weights_dev, + weights_uvm, lxu_cache_weights, weights_placements, weights_offsets, @@ -199,6 +672,14 @@ class {{ autograd_func }} : vbe_row_output_offsets, vbe_b_t_map, {%- endif %} + {%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %} + prev_iter_dev_, + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + ssd_tensors[SSDTensor::{{ tensor | upper }}], + {%- endfor %} + {%- endif %} {{ args_pt2.split_saved_tensors | join(", ") }} }); @@ -220,6 +701,12 @@ class {{ autograd_func }} : ctx->saved_data["info_B_mask"] = info_B_mask_int64; ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd; ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements; + {%- if is_gwd %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + ctx->saved_data["iter"] = iter; + {%- endif %} + ctx->saved_data["gwd_lower_bound"] = gwd_lower_bound; + {%- endif %} {%- if not nobag %} ctx->saved_data["output_dtype"] = output_dtype; {%- endif %} @@ -230,136 +717,42 @@ class {{ autograd_func }} : {%- if optimizer == "none" %} // Flatten - const auto& flatten_dev_weights = dev_weights.flatten(); + const auto& flatten_weights_dev = weights_dev.flatten(); {%- else %} - const auto& flatten_dev_weights = dev_weights; + const auto& flatten_weights_dev = weights_dev; {%- endif %} - - {%- if not nobag %} - {%- for weighted in [False, True] %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- if not weighted %} - if (!indice_weights_value.defined()) { + {%- if nobag %} + // nobag + {{ + call_forward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} {%- else %} - else { - {%- endif %} - {%- set forward_op = "split_embedding_codegen_forward_{}{}_pt2_wrapper".format( - wdesc, vdesc - ) - %} - static auto split_embedding_codegen_forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ forward_op }}", "") - .typed(); - return { - split_embedding_codegen_forward_op.call( - host_weights, - flatten_dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D_offsets, - total_D, - max_D, - hash_size_cumsum, - indices, - offsets, - pooling_mode, - indice_weights_value, - lxu_cache_locations, - uvm_cache_stats_, - {%- if vbe %} - vbe_row_output_offsets, - vbe_b_t_map, - vbe_output_size, - info_B_num_bits, - info_B_mask_int64, - {%- endif %} - is_experimental, - output_dtype - ) - }; + if (indice_weights) { + // weighted + {{ + call_forward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, + ) + }} } - {%- endfor %} - {%- else %} - {%- set forward_nobag_op = "split_embedding_nobag_codegen_forward_unweighted_pt2_wrapper" %} - static auto split_embedding_codegen_forward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ 
forward_nobag_op }}", "") - .typed(); - return { - split_embedding_codegen_forward_op.call( - host_weights, - flatten_dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D, - hash_size_cumsum, - indices, - offsets, - lxu_cache_locations, - uvm_cache_stats_, - /*is_experimental=*/false, - output_dtype - ) - }; - {%- endif %} + // unweighted + {{ + call_forward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} + {%- endif %} {#-/* if not nobag */ #} } static torch::autograd::variable_list backward( @@ -367,9 +760,9 @@ static torch::autograd::variable_list backward( torch::autograd::variable_list grad_outputs) { const auto saved = ctx->get_saved_variables(); auto savedItr = std::begin(saved); - auto host_weights = *savedItr++; - auto dev_weights = *savedItr++; - auto uvm_weights = *savedItr++; + auto weights_host = *savedItr++; + auto weights_dev = *savedItr++; + auto weights_uvm = *savedItr++; auto lxu_cache_weights = *savedItr++; auto weights_placements = *savedItr++; auto weights_offsets = *savedItr++; @@ -389,6 +782,14 @@ static torch::autograd::variable_list backward( auto vbe_row_output_offsets = *savedItr++; auto vbe_b_t_map = *savedItr++; {%- endif %} + {%- if is_gwd and "prev_iter_dev" not in args_pt2.split_function_arg_names %} + auto prev_iter_dev = *savedItr++; + {%- endif %} + {%- if ssd %} + {%- for tensor in ssd_tensors %} + auto ssd_{{ tensor }} = *savedItr++; + {%- endfor %} + {%- endif %} {%- for tensor in args_pt2.split_saved_tensors %} auto {{ tensor }} = *savedItr++; @@ -411,6 +812,13 @@ static torch::autograd::variable_list backward( const int64_t info_B_mask_int64 = ctx->saved_data["info_B_mask"].toInt(); const auto use_uniq_cache_locations_bwd = ctx->saved_data["use_uniq_cache_locations_bwd"].toBool(); const auto use_homogeneous_placements = ctx->saved_data["use_homogeneous_placements"].toBool(); + {%- if is_gwd %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + const auto iter = ctx->saved_data["iter"].toInt(); + {%- endif %} + const auto gwd_lower_bound = ctx->saved_data["gwd_lower_bound"].toDouble(); + {%- endif %} + {%- if not nobag %} auto output_dtype = ctx->saved_data["output_dtype"].toInt(); {%- endif %} @@ -438,22 +846,22 @@ static torch::autograd::variable_list backward( {%- if not nobag %} {%- if optimizer == "none" %} - // Flatten (dev_weights is used in - // split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) - dev_weights = dev_weights.flatten(); + // Flatten (weights_dev is used in + // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) + weights_dev = weights_dev.flatten(); {%- endif %} {%- set grad_indice_weights_op = - "split_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(vdesc) + "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc) %} - static auto split_embedding_codegen_grad_indice_weights_op = + static auto embedding_codegen_grad_indice_weights_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ grad_indice_weights_op }}", "") .typed(); + )>(); const auto grad_indice_weights = !indice_weights.defined() ? 
Variable() : - split_embedding_codegen_grad_indice_weights_op.call( + embedding_codegen_grad_indice_weights_op.call( grad_output, - host_weights, - dev_weights, - uvm_weights, + weights_host, + weights_dev, + weights_uvm, lxu_cache_weights, weights_placements, weights_offsets, @@ -487,7 +899,11 @@ static torch::autograd::variable_list backward( max_D, indices, offsets, + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} {%- if vbe %} feature_requires_grad, vbe_row_output_offsets, @@ -498,276 +914,56 @@ static torch::autograd::variable_list backward( feature_requires_grad {%- endif %} ); - - {%- for weighted in [False, True] %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - {%- set backward_op = "split_embedding_backward_codegen_{}_{}_exact{}_pt2_wrapper".format( - optimizer, wdesc, vdesc + + Tensor grad_weights_dev; + if (indice_weights.defined()) + { + // weighted + {{ + call_backward_op_dispatch( + nobag=False, + weighted=True, + vbe=vbe, + is_gwd=is_gwd, ) - %} - static auto split_embedding_codegen_{{ wdesc }}_backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") - .typed(); - {%- endfor %} {#-/* for weighted */#} - - const auto grad_dev_weights = !indice_weights.defined() ? - {%- for weighted in [False, True] %} - {%- set wdesc = "weighted" if weighted else "unweighted" %} - split_embedding_codegen_{{ wdesc }}_backward_op.call( - grad_output, - host_weights, - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D_offsets, - max_D, - hash_size_cumsum, - total_hash_size_bits, - indices, - offsets, - pooling_mode, - indice_weights, - lxu_cache_locations, - BT_block_size, - max_segment_length_per_warp, - {%- if optimizer != "none" %} - stochastic_rounding, - {%- endif %} - info_B_num_bits, - info_B_mask_int64, - {%- if vbe %} - B_offsets, - vbe_row_output_offsets, - vbe_b_t_map, - {%- endif %} - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {{ args_pt2.split_function_arg_names | join(", ") }} - {%- if not nobag %} - , output_dtype - {%- endif %} - ) {{ ":" if not weighted else ";" }} - {%- endfor %} {#-/* for weighted in [False, True] */#} - return { - Tensor(), // placeholder autograd tensor - Variable(), // output_dtype - Variable(), // host_weights - grad_dev_weights, // dev_weights - Variable(), // uvm_weights - Variable(), // lxu_cache_weights - Variable(), // weights_placements - Variable(), // weights_offsets - Variable(), // D_offsets - Variable(), // total_D - Variable(), // max_D - Variable(), // hash_size_cumsum - Variable(), //total_hash_size_bits - Variable(), // indices - Variable(), // offsets - Variable(), // pooling_mode - grad_indice_weights, // indice_weights - Variable(), // feature_requires_grad - Variable(), // lxu_cache_locations - Variable(), // uvm_cache_stats - {%- if optimizer != "none" %} - Variable(), // gradient_clipping - Variable(), // max_gradient - Variable(), // stochastic_rounding - {%- endif %} - {%- if vbe %} - Variable(), // B_offsets - Variable(), // vbe_output_offsets_feature_rank - Variable(), // vbe_B_offsets_rank_per_feature - Variable(), // max_B - Variable(), // max_B_feature_rank - Variable(), // vbe_output_size - {%- endif %} - Variable(), // is_experimental - Variable(), // use_uniq_cache_locations_bwd - Variable(), // use_homogeneous_placements - {{ args_pt2.split_variables | join(", ") }} - }; - {%- else %} - {%- set backward_nobag_op = - 
"split_embedding_nobag_backward_codegen_{}_unweighted_exact_pt2_wrapper".format( - optimizer + }} + } + // unweighted + {{ + call_backward_op_dispatch( + nobag=False, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, + ) + }} + {%- else %} {#-/* if not nobag */#} + // nobag + Tensor grad_weights_dev; + {{ + call_backward_op_dispatch( + nobag=True, + weighted=False, + vbe=vbe, + is_gwd=is_gwd, ) - %} - - static auto split_embedding_nobag_codegen_backward_op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::{{ backward_nobag_op }}", "") - .typed(); + }} + {%- endif %} {#-/* if not nobag */#} - const auto grad_dev_weights = split_embedding_nobag_codegen_backward_op.call( - grad_output, - host_weights, - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - D, - hash_size_cumsum, - total_hash_size_bits, - indices, - offsets, - lxu_cache_locations, - BT_block_size, - max_segment_length_per_warp, - {%- if optimizer != "none" %} - stochastic_rounding, - {%- endif %} - info_B_num_bits, - info_B_mask_int64, - {%- if vbe %} - B_offsets, - vbe_row_output_offsets, - vbe_b_t_map, - {%- endif %} - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {{ args_pt2.split_function_arg_names | join(", ") }} - ); - return { - Tensor(), // placeholder autograd tensor - Variable(), // output_dtype - Variable(), // host_weights - grad_dev_weights, // dev_weights - Variable(), // uvm_weights - Variable(), // lxu_cache_weights - Variable(), // weights_placements - Variable(), // weights_offsets - Variable(), // D - Variable(), // hash_size_cumsum - Variable(), // total_hash_size_bits - Variable(), // indices - Variable(), // offsets - Variable(), // lxu_cache_locations - Variable(), // uvm_cache_stats - {%- if optimizer != "none" %} - Variable(), // gradient_clipping - Variable(), // max_gradient - Variable(), // stochastic_rounding - {%- endif %} - {%- if vbe %} - Variable(), // B_offsets - Variable(), // vbe_output_offsets_feature_rank - Variable(), // vbe_B_offsets_rank_per_feature - Variable(), // max_B - Variable(), // max_B_feature_rank - Variable(), // vbe_output_size - {%- endif %} - Variable(), // is_experimental - Variable(), // use_uniq_cache_locations_bwd - Variable(), // use_homogeneous_placements - {{ args_pt2.split_variables | join(", ") }} - }; - {%- endif %} } }; -{%- endif %} {#-/* if not nobag or not vbe */#} +{%- endfor %} {#-/* for is_gwd */#} {%- endfor %} {#-/* for nobag in [True, False] */#} {%- endfor %} {#-/* for vbe in [True, False] */#} ///@ingroup embedding-cuda -Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_pt2( +Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( const Tensor& placeholder_autograd_tensor, - const Tensor& host_weights, - const Tensor& dev_weights, - const Tensor& uvm_weights, + const at::TensorList weights, const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, const Tensor& D_offsets, - const int64_t total_D, - const int64_t max_D, + const c10::SymInt total_D, + const c10::SymInt max_D, const Tensor& hash_size_cumsum, const int64_t total_hash_size_bits, const Tensor& indices, @@ -781,105 +977,96 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_pt2( const double max_gradient, const bool stochastic_rounding, {%- endif %} - {{ args_pt2.split_function_args | join(", ") }}, + {{ args_pt2.unified_pt2.split_function_args | join(", ") }}, const int64_t output_dtype = static_cast(SparseType::FP32), - const 
std::optional& B_offsets = std::optional(), - const std::optional& vbe_output_offsets_feature_rank = std::optional(), - const std::optional& vbe_B_offsets_rank_per_feature = std::optional(), + const std::optional& B_offsets = c10::nullopt, + const std::optional& vbe_output_offsets_feature_rank = c10::nullopt, + const std::optional& vbe_B_offsets_rank_per_feature = c10::nullopt, const c10::SymInt max_B = -1, const c10::SymInt max_B_feature_rank = -1, const c10::SymInt vbe_output_size = -1, - const bool is_experimental = false, + const bool is_experimental_tbe = false, // formerly named is_experimental const bool use_uniq_cache_locations_bwd = false, const bool use_homogeneous_placements = false, - const std::optional& uvm_cache_stats = std::optional()) { - {%- for vbe in ([True, False] if has_vbe_support else [False]) %} - {%- if has_vbe_support %} - {%- if vbe %} - if (B_offsets.has_value()) { - {%- else %} - else { // if (B_offsets.has_value()) - {%- endif %} - {%- endif %} {#-/* if has_vbe_support */#} - {%- for nobag in [True, False] %} - {%- set vbe = False if nobag else vbe %} - {%- set autograd_func = "Split{}{}LookupFunction_{}_Op_pt2".format( - "NoBag" if nobag else "", - "VBE" if vbe else "", - optimizer - ) - %} - {%- if nobag %} - if (static_cast(pooling_mode) == PoolingMode::NONE) { - {%- else %} - else { + const std::optional& uvm_cache_stats = c10::nullopt, + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + const std::optional& prev_iter_dev = c10::nullopt, {%- endif %} - return {{ autograd_func }}::apply( - placeholder_autograd_tensor, - output_dtype, - host_weights, - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - {%- if nobag %} - max_D, - {%- else %} - D_offsets, - total_D, - max_D, - {%- endif %} - hash_size_cumsum, - total_hash_size_bits, - indices, - offsets, - {%- if not nobag %} - pooling_mode, - indice_weights, - feature_requires_grad, - {%- endif %} - lxu_cache_locations, - uvm_cache_stats, - {%- if optimizer != "none" %} - gradient_clipping, - max_gradient, - stochastic_rounding, - {%- endif %} - {%- if vbe %} - B_offsets, - vbe_output_offsets_feature_rank, - vbe_B_offsets_rank_per_feature, - max_B, - max_B_feature_rank, - vbe_output_size, - {%- endif %} - is_experimental, - use_uniq_cache_locations_bwd, - use_homogeneous_placements, - {{ args_pt2.split_function_arg_names | join(", ") }} - )[0]; - } - {%- endfor %} {#-/* for nobag */#} + {%- if "iter" not in args_pt2.split_function_arg_names %} + const int64_t iter = 0, + {%- endif %} + const bool apply_global_weight_decay = false, + {%- if ssd %} + const c10::optional& ssd_tensors = c10::nullopt, + {%- endif %} + const double gwd_lower_bound = 0 +) { + {%- if has_gpu_support or has_cpu_support %} + + {%- if not dense %} + // Load the config value from JK once + static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2); + + // Set to experimental if either the feature is enabled in JK, or the user specifies to use TBEv2 + const auto is_experimental = is_tbev2_enabled || is_experimental_tbe; + {%- endif %} + + {%- if not ssd %} {%- if has_vbe_support %} + // has vbe support and on gpu + if (B_offsets.has_value() && !(weights[0].numel() > 0)) { + {%- if has_global_weight_decay_support %} + // vbe and has gwd support + if (apply_global_weight_decay && weight_decay > 0) { + {{ call_autograd(nobag=False, vbe=True, is_gwd=True) }} + } + {%- endif %} {#-/* if has_global_weight_decay_support */ #} + // vbe and no gwd support 
+ {{ call_autograd(nobag=False, vbe=True, is_gwd=False) }} + } + {%- endif %} {#-/* if has_vbe_support */ #} + + {%- if has_global_weight_decay_support %} + // has gwd support + if (apply_global_weight_decay && weight_decay > 0) { + // not vbe and gwd + {{ call_autograd(nobag=False, vbe=False, is_gwd=True) }} } + {%- endif %} {#-/* if has_global_weight_decay_support */ #} + {%- endif %} {#-/* if not ssd */#} + + {%- if ssd %} + TORCH_CHECK( + ssd_tensors.value().size() == {{ ssd_tensors | length }}, + "SSD TBE expects {{ ssd_tensors | length }} in ssd_tensors"); {%- endif %} - {%- endfor %} {#-/* vbe */#} + + if (static_cast(pooling_mode) == PoolingMode::NONE) { + // no bag + {{ call_autograd(nobag=True, vbe=False, is_gwd=False) }} + } + else { + {{ call_autograd(nobag=False, vbe=False, is_gwd=False) }} + } + {%- else %} + TORCH_CHECK( + false, + "{{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail." + ); + return Tensor(); + {%- endif %} } TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_pt2(" + {%- set op_name = "{}_embedding_codegen_lookup_{}_function_pt2".format(bwd_mdesc, optimizer) %} + m.def("{{ op_name }}(" " Tensor placeholder_autograd_tensor, " - " Tensor host_weights, " - " Tensor dev_weights, " - " Tensor uvm_weights, " + " Tensor[] weights, " " Tensor lxu_cache_weights, " - " Tensor weights_placements, " - " Tensor weights_offsets, " " Tensor D_offsets, " - " int total_D, " - " int max_D, " + " SymInt total_D, " + " SymInt max_D, " " Tensor hash_size_cumsum, " " int total_hash_size_bits, " " Tensor indices, " @@ -893,7 +1080,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " float max_gradient, " " bool stochastic_rounding, " {%- endif %} - " {{ args_pt2.split_function_schemas | join(", ") }}, " + " {{ args_pt2.unified_pt2.split_function_schemas | join(", ") }}, " " int output_dtype=0, " " Tensor? B_offsets=None, " " Tensor? vbe_output_offsets_feature_rank=None, " @@ -901,10 +1088,21 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " SymInt max_B=-1, " " SymInt max_B_feature_rank=-1, " " SymInt vbe_output_size=-1, " - " bool is_experimental=False, " + " bool is_experimental_tbe=False, " " bool use_uniq_cache_locations_bwd=False, " " bool use_homogeneous_placements=False, " - " Tensor? uvm_cache_stats=None" + " Tensor? uvm_cache_stats=None," + {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} + " Tensor? prev_iter_dev=None, " + {%- endif %} + {%- if "iter" not in args_pt2.split_function_arg_names %} + " int iter=0, " + {%- endif %} + " bool apply_global_weight_decay=False, " + {%- if ssd %} + " Tensor[]? 
ssd_tensors=None," + {%- endif %} + " float gwd_lower_bound=0 " ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd @@ -913,18 +1111,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // no autograd enabled, and all of the internal implementations redispatch // appropriately m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function_pt2", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Autograd, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_pt2))); + TORCH_FN({{ op_name }}))); m.impl( - "split_embedding_codegen_lookup_{{ optimizer }}_function_pt2", + "{{ op_name }}", torch::dispatch( c10::DispatchKey::Meta, - TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_pt2))); + TORCH_FN({{ op_name }}))); DISPATCH_TO_CUDA( - "split_embedding_codegen_lookup_{{ optimizer }}_function_pt2", - split_embedding_codegen_lookup_{{ optimizer }}_function_pt2); + " {{ op_name }} ", + {{ op_name }} ); } {%- endif %} {#-/* if has_gpu_support or has_cpu_support */#} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index 7cafb32cd..c74355207 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -107,7 +107,7 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( } {% else %} {#-/* PT2 wrapper function for backward CPU */#} -Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact_pt2_cpu_wrapper( +Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -208,7 +208,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_exact_pt2".format( + {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format( optimizer, wdesc ) %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp index ded9bf261..623947e2e 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp @@ -30,6 +30,18 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{#/* Module description */#} +{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %} +{%- set bwd_mdesc = "ssd" if ssd else "split" %} + +{%- if ssd %} +enum SSDTensor { + {%- for tensor in ssd_tensors %} + {{ tensor | upper }} = {{ loop.index - 1 }}, + {%- endfor %} +}; +{%- endif %} + {%- for vbe in ([True, False] if has_vbe_support else [False]) %} {%- set vdesc = "_vbe" if vbe else "" %} @@ -37,10 +49,22 @@ using namespace fbgemm_gpu; {%- for nobag in ([False] if (weighted or vbe) else [True, False]) %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- set ndesc = "_nobag" if nobag else "" %} +{%- for is_gwd in ([True, False] + if is_valid_gwd_config( + dense, + nobag, + vbe, + is_index_select, + True, + ssd) + else [False]) %} +{%- set gwddesc = "_gwd" if is_gwd else "" %} +{%- set desc_suffix = 
wdesc + vdesc + gwddesc %} {%- if is_forward %} {#-/* PT2 wrapper function for forward CUDA */#} -Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cuda_wrapper( +{%- for dispatch_type in ["cuda", "meta"] %} +Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt2_{{ dispatch_type }}_wrapper( const Tensor& /*host_weights*/, const Tensor& dev_weights, const Tensor& uvm_weights, @@ -54,14 +78,14 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud const c10::SymInt total_D, const c10::SymInt max_D, {%- endif %} - const Tensor& /*hash_size_cumsum*/, + const Tensor& hash_size_cumsum, const Tensor& indices, const Tensor& offsets, {%- if not nobag %} const int64_t pooling_mode, const Tensor& indice_weights, // CPU always takes indice_weights {%- endif %} - const Tensor& lxu_cache_locations, + const Tensor& {{ "ssd_row_addrs" if ssd else "lxu_cache_locations" }}, const Tensor& uvm_cache_stats, {%- if vbe %} const Tensor& vbe_row_output_offsets, @@ -70,48 +94,61 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud const int64_t info_B_num_bits, const int64_t info_B_mask_int64, {%- endif %} + {%- if is_gwd %} + const Tensor& prev_iter_dev, + const double learning_rate, + const double weight_decay, + const int64_t iter, + const double gwd_lower_bound, + {%- endif %} const bool is_experimental, const int64_t output_dtype ){ - {%- set op = "split_embedding{}_codegen_forward_{}{}_cuda".format( - ndesc, wdesc, vdesc + {%- set op = "{}_embedding{}_codegen_forward_{}_cuda".format( + fwd_mdesc, ndesc, desc_suffix ) %} static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ op }}", "") .typed(); @@ -137,7 +174,11 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud {%- if weighted %} indice_weights, {%- endif %} + {%- if ssd %} + ssd_row_addrs, + {%- else %} lxu_cache_locations, + {%- endif %} uvm_cache_stats, output_dtype, {%- if vbe %} @@ -147,131 +188,22 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cud info_B_num_bits, info_B_mask_int64, {%- endif %} + {%- if is_gwd %} + hash_size_cumsum, + prev_iter_dev, + learning_rate, + weight_decay, + iter, + gwd_lower_bound, + {%- endif %} {# /* if is_gwd */ #} is_experimental ); }; - -{#-/* PT2 wrapper function for forward META */#} -Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_meta_wrapper( - const Tensor& /*host_weights*/, - const Tensor& dev_weights, - const Tensor& uvm_weights, - const Tensor& lxu_cache_weights, - const Tensor& weights_placements, - const Tensor& weights_offsets, - {%- if nobag %} - const c10::SymInt D, - {%- else %} - const Tensor& D_offsets, - const c10::SymInt total_D, - const c10::SymInt max_D, - {%- endif %} - const Tensor& /*hash_size_cumsum*/, - const Tensor& indices, - const Tensor& offsets, - {%- if not nobag %} - const int64_t pooling_mode, - const Tensor& indice_weights, // CPU always takes indice_weights - {%- endif %} - const Tensor& lxu_cache_locations, - const Tensor& uvm_cache_stats, - {%- if vbe %} - const Tensor& vbe_row_output_offsets, - const Tensor& vbe_b_t_map, - const c10::SymInt vbe_output_size, - const int64_t info_B_num_bits, - const int64_t info_B_mask_int64, - {%- endif %} - const bool is_experimental, - const int64_t output_dtype - ){ - {%- set op = "split_embedding{}_codegen_forward_{}{}_cuda".format( - ndesc, wdesc, vdesc - ) - %} - static auto op = - torch::Dispatcher::singleton() 
- .findSchemaOrThrow("fbgemm::{{ op }}", "") - .typed(); - - return op.call( - dev_weights, - uvm_weights, - lxu_cache_weights, - weights_placements, - weights_offsets, - {%- if not nobag %} - D_offsets, - {%- else %} - D, - {%- endif %} - {%- if not nobag %} - total_D, - {%- endif %} - {%- if not nobag %} - max_D, - {%- endif %} - indices, - offsets, - {%- if not nobag %} - pooling_mode, - {%- endif %} - {%- if weighted %} - indice_weights, - {%- endif %} - lxu_cache_locations, - uvm_cache_stats, - output_dtype, - {%- if vbe %} - vbe_row_output_offsets, - vbe_b_t_map, - vbe_output_size, - info_B_num_bits, // int32_t - info_B_mask_int64, // uint32_t - {%- endif %} - is_experimental - ); - } +{%- endfor %} {#-/*for dispatch_type in ["cuda", "meta"]*/#} {%- else %} {#-/* PT2 wrapper function for backward CUDA */#} -Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}_pt2_cuda_wrapper( +Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_pt2_cuda_wrapper( const Tensor& grad_output, const Tensor& /*host_weights*/, const Tensor& dev_weights, @@ -293,7 +225,11 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e const int64_t pooling_mode, const Tensor& indice_weights, // currently supports no bag with unweighted {%- endif %} + {%- if ssd %} + const Tensor& ssd_row_addrs, + {%- elif not dense %} const Tensor& lxu_cache_locations, + {%- endif %} const int64_t BT_block_size, const int64_t max_segment_length_per_warp, {%- if optimizer != "none" %} @@ -308,55 +244,73 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e {%- endif %} const bool use_uniq_cache_locations, const bool use_homogeneous_placements, + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args.split_function_arg_names %} + const Tensor& prev_iter_dev, + {%- endif %} + {%- if "iter" not in args.split_function_arg_names %} + const int64_t iter, + {%- endif %} + const double gwd_lower_bound, + {%- endif %} {{ args_pt2.split_function_args | join(", ") }} {%- if not nobag %} , const int64_t output_dtype = static_cast(SparseType::FP32) {%- endif %}){ - {%- set backward_op = "split_embedding{}_backward_codegen_{}_{}_exact{}_cuda".format( - ndesc, optimizer, wdesc, vdesc + {%- set backward_op = "{}_embedding{}_backward_codegen_{}_{}_exact_cuda".format( + bwd_mdesc, ndesc, optimizer, desc_suffix ) %} static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::{{ backward_op }}", "") .typed Tensor" @@ -547,8 +547,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { ); m.impl("{{ embedding_codegen_forward_op }}_wrapper", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ embedding_codegen_forward_op }}_meta_wrapper))); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding{}_backward_codegen_{}_{}_exact{}_pt2".format( - ndesc, optimizer, wdesc, vdesc + {%- set embedding_codegen_backward_op = "{}_embedding{}_backward_codegen_{}_{}_pt2".format( + bwd_mdesc, ndesc, optimizer, desc_suffix ) %} m.def("{{ embedding_codegen_backward_op }}_wrapper(" @@ -573,7 +573,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int pooling_mode, " " Tensor indice_weights, " {%- endif %} + {%- if ssd %} + " Tensor ssd_row_addrs, " + {%- else %} " Tensor lxu_cache_locations, " + {%- endif %} " int BT_block_size, " " int max_segment_length_per_warp, " {%- if optimizer != "none" %} @@ -588,6 +592,15 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endif %} " bool use_uniq_cache_locations, " " bool 
use_homogeneous_placements," + {%- if is_gwd %} + {%- if "prev_iter_dev" not in args.split_function_arg_names %} + " Tensor prev_iter_dev, " + {%- endif %} + {%- if "iter" not in args.split_function_arg_names %} + " int iter, " + {%- endif %} + " float gwd_lower_bound, " + {%- endif %} " {{ args_pt2.split_function_schemas | join(", ") }} " {%- if not nobag %} " , int output_dtype=0 " @@ -598,12 +611,13 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {{ embedding_codegen_backward_op }}_cuda_wrapper ); {%- endif %} + {%- endfor %} {#-/*for is_gwd*/#} {%- endfor %} {#-/*for nobag*/#} {%- endfor %} {#-/*for weighted*/#} {%- if is_forward %} {%- set embedding_codegen_grad_indice_weights_op = - "split_embedding_codegen_grad_indice_weights{}_pt2".format( - vdesc + "{}_embedding_codegen_grad_indice_weights{}_pt2".format( + fwd_mdesc, vdesc ) %} m.def("{{ embedding_codegen_grad_indice_weights_op }}_wrapper(" @@ -618,7 +632,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " SymInt max_D, " " Tensor indices, " " Tensor offsets, " + {%- if ssd %} + " Tensor ssd_row_addrs, " + {%- else %} " Tensor lxu_cache_locations, " + {%- endif %} {%- if vbe %} " Tensor feature_requires_grad, " " Tensor vbe_row_output_offsets, " diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 85d23cc94..1098378d0 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -49,6 +49,9 @@ void bounds_check_indices_cpu( const std::optional& weights, const std::optional& B_offsets, const int64_t max_B) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } const auto vbe = B_offsets.has_value(); if (vbe) { TENSOR_NDIM_EQUALS(B_offsets.value(), 1); diff --git a/fbgemm_gpu/docs/requirements.txt b/fbgemm_gpu/docs/requirements.txt index 9f3bca439..533232736 100644 --- a/fbgemm_gpu/docs/requirements.txt +++ b/fbgemm_gpu/docs/requirements.txt @@ -14,7 +14,7 @@ sphinx<7 breathe bs4 -docutils +docutils<0.20,>=0.18.1 lxml myst-parser sphinx-lint diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu index f0c689570..4712f620f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cu @@ -67,17 +67,26 @@ DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) { return c; } -template -__global__ void one_shot_all_reduce( - int32_t rank, - int32_t world_size, - int32_t flag, - std::array barriers, - std::array inputs, +template +#if defined(USE_ROCM) +__launch_bounds__(512) +#endif + __global__ void one_shot_all_reduce( + int32_t rank, + int32_t world_size, + int32_t flag, + std::array barriers, + std::array inputs, +#if defined(USE_ROCM) + at::BFloat16* __restrict__ ar_input, + at::BFloat16* __restrict__ acc, + at::BFloat16* __restrict__ output, +#else at::BFloat16* ar_input, at::BFloat16* acc, at::BFloat16* output, - int32_t N) { +#endif + int32_t N) { // It is expensive to launch hipMemcpyAsync on ROCm // Move data copy here. Each block copies part of input data at::BFloat16* input = inputs[rank]; @@ -143,11 +152,11 @@ __global__ void one_shot_all_reduce( // Sum the values from the different ranks. 
bf16x8 sums; - if (acc) { + if constexpr (has_acc) { *reinterpret_cast(&sums) = *reinterpret_cast(&acc[i]); } else { - memset(reinterpret_cast(&sums), 0, sizeof(sums)); + *reinterpret_cast(&sums) = uint4{0}; } #pragma unroll kWorldSize @@ -336,15 +345,24 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) { #endif } -template +template +#if defined(USE_ROCM) +__launch_bounds__(512) __global__ void two_shot_all_reduce( +#else __launch_bounds__(1024) __global__ void two_shot_all_reduce( +#endif int32_t rank, int32_t world_size, int32_t flag, std::array barriers, std::array inputs, +#if defined(USE_ROCM) + at::BFloat16* __restrict__ acc, + at::BFloat16* __restrict__ output, +#else at::BFloat16* acc, at::BFloat16* output, +#endif int32_t N) { int32_t N_per_rank = N / kWorldSize; int32_t N_start = N_per_rank * rank; @@ -374,13 +392,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( __syncthreads(); at::BFloat16* src_d[kWorldSize]; - int dst_rank[kWorldSize]; #pragma unroll kWorldSize for (int ii = 0; ii < kWorldSize; ++ii) { int d_rank = (rank + ii) % kWorldSize; src_d[ii] = inputs[d_rank]; - dst_rank[ii] = d_rank; } // Each block accumulates the values from the different GPUs on the same @@ -395,11 +411,12 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( } bf16x8 sums; - if (acc) { + + if constexpr (has_acc) { *reinterpret_cast(&sums) = *reinterpret_cast(&acc[i + N_start]); } else { - memset(reinterpret_cast(&sums), 0, sizeof(sums)); + *reinterpret_cast(&sums) = uint4{0}; } #pragma unroll kWorldSize @@ -433,11 +450,19 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce( // Gather all needed elts from other intra-node ranks for (size_t i = threadIdx.x * 8 + blockIdx.x * blockDim.x * 8; i < N_per_rank; i += gridDim.x * blockDim.x * 8) { + uint4 temp[kWorldSize]; +#pragma unroll kWorldSize + for (int ii = 0; ii < kWorldSize; ++ii) { + int d_rank = (rank + ii) % kWorldSize; + int i_r = N_start + i + (d_rank - rank) * N_per_rank; + temp[ii] = reinterpret_cast(&src_d[ii][i_r])[0]; + } + #pragma unroll kWorldSize for (int ii = 0; ii < kWorldSize; ++ii) { - int i_r = N_start + i + (dst_rank[ii] - rank) * N_per_rank; - *reinterpret_cast(&output[i_r]) = - reinterpret_cast(&src_d[ii][i_r])[0]; + int d_rank = (rank + ii) % kWorldSize; + int i_r = N_start + i + (d_rank - rank) * N_per_rank; + *reinterpret_cast(&output[i_r]) = temp[ii]; } } } @@ -474,7 +499,11 @@ void one_shot_car_allreduce( constexpr int32_t N_per_thread = 8; constexpr int32_t N_per_warp = N_per_thread * kThreadsPerWarp; TORCH_CHECK(N % N_per_warp == 0); +#if defined(USE_ROCM) + constexpr int32_t kThreadsPerBlock = 512; +#else constexpr int32_t kThreadsPerBlock = 1024; +#endif constexpr int32_t kMaxBlocks = 24; dim3 threads(0, 1, 1); @@ -494,21 +523,37 @@ void one_shot_car_allreduce( threads.x = threads_per_block; } -#define X(kWorldSize) \ - if (state->world_size_ == kWorldSize) { \ - one_shot_all_reduce \ - <<>>( \ - state->rank_, \ - state->world_size_, \ - state->flag_ * state->world_size_, \ - barriers, \ - inputs, \ - y.data_ptr(), \ - z ? 
z->data_ptr() : nullptr, \ - y_allreduce.data_ptr(), \ - N); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - return; \ +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + if (z) { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + y.data_ptr(), \ + z->data_ptr(), \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } else { \ + one_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + y.data_ptr(), \ + nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } \ } TORCH_CHECK( @@ -520,7 +565,7 @@ void one_shot_car_allreduce( #undef X return; -} +} // namespace fbgemm_gpu void two_shot_car_allreduce( at::Tensor y_allreduce, @@ -565,26 +610,46 @@ void two_shot_car_allreduce( TORCH_CHECK(N_per_rank % N_per_thread == 0); auto threads_per_rank = N_per_rank / N_per_thread; +#if defined(USE_ROCM) + constexpr int32_t kThreadsPerBlock = 512; +#else constexpr int32_t kThreadsPerBlock = 1024; +#endif + constexpr int32_t kMaxBlocks = 24; auto blocks = std::min( cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks); -#define X(kWorldSize) \ - if (state->world_size_ == kWorldSize) { \ - two_shot_all_reduce \ - <<>>( \ - state->rank_, \ - state->world_size_, \ - state->flag_ * state->world_size_, \ - barriers, \ - inputs, \ - z ? z->data_ptr() : nullptr, \ - y_allreduce.data_ptr(), \ - N); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - return; \ +#define X(kWorldSize) \ + if (state->world_size_ == kWorldSize) { \ + if (z) { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + z->data_ptr(), \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } else { \ + two_shot_all_reduce \ + <<>>( \ + state->rank_, \ + state->world_size_, \ + state->flag_ * state->world_size_, \ + barriers, \ + inputs, \ + nullptr, \ + y_allreduce.data_ptr(), \ + N); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + return; \ + } \ } TORCH_CHECK( diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp index 3d5d337ec..38ed3ec6b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp @@ -133,7 +133,9 @@ std::tuple dequantize_fp8_cache( at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v); + std::optional qparam_v, + std::optional block_tables, + int64_t page_size); at::Tensor mqa_attn( at::Tensor XQ, @@ -162,7 +164,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)"); m.impl("dequantize_int4_cache", dequantize_int4_cache); m.def( - "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None) -> (Tensor, Tensor)"); + "dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None, Tensor? 
block_tables=None, int page_size=" STRING( + DEFAULT_PAGE_SIZE) ") -> (Tensor, Tensor)"); m.impl("dequantize_fp8_cache", dequantize_fp8_cache); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 0728e6b9a..787c0547c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -795,11 +795,27 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_( } else { __half2* qparam_row = nullptr; auto T = cache_K.size(1); - auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h; - if (qkv == QKV::K) { - qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + if (block_tables == nullptr) { + auto idx = b * (T * N_KVH) + (size_t)cache_loc_t * N_KVH + h; + if (qkv == QKV::K) { + qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + } else { + qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } } else { - qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + // This is duplicate computation with get_dst_row above. + // TODO: Maybe clean up and merge later. + int page_logical_idx = cache_loc_t / page_size; + int page_offset = cache_loc_t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + auto idx = physical_t * N_KVH + h; + if (qkv == QKV::K) { + qparam_row = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + } else { + qparam_row = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } } quantize_fp8_kv(dst, dst_row_q, qparam_row); } @@ -1477,16 +1493,113 @@ __global__ void dequantize_fp8_cache_kernel( *reinterpret_cast(&kv_dq.vals[2]); } } + +// Cloned from dequantize_fp8_cache_kernel because +// branching inside the original kernel runs into +// "too many resources requested for launch" which +// necessitates decreasing the number of warps per block, +// which might have performance implications. Also we +// might have more diverging behaviors for paged kernel +// as noted in the comment below so we will keep a separate +// kernel for now. +__global__ void dequantize_fp8_cache_kernel_paged( + // This code currently represents FP8 version not int4 + at::PackedTensorAccessor64 + cache_K, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V, // [1][MAX_PAGE * PAGE_SIZE][N_KVH][D_H // G] + at::PackedTensorAccessor32 kv_seqlen, + at::PackedTensorAccessor64 + cache_K_dq, // [1][MAX_T][N_KVH][D_H] + at::PackedTensorAccessor64 + cache_V_dq, // [1][MAX_T][N_KVH][D_H] + int32_t* qparam_k_ptr, + int32_t* qparam_v_ptr, + int32_t* block_tables, + int32_t block_tables_b_stride, + int32_t page_size) { + auto N_KVH = cache_K.size(2); + auto MAX_T = cache_K.size(1); + auto D_H = cache_K_dq.size(3); + auto D_H_q = cache_K.size(3); + CUDA_KERNEL_ASSERT(D_H == 128); + + auto b = blockIdx.x; + // only need to dequantize this far. 
+ auto max_t = kv_seqlen[b]; + + // one warp per T/H + for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH; + t_h += blockDim.y * gridDim.y) { + auto h = t_h % N_KVH; + auto t = t_h / N_KVH; + + int page_logical_idx = t / page_size; + int page_offset = t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + + uint8_t* row_k = &cache_K[0][physical_t][h][0]; + uint8_t* row_v = &cache_V[0][physical_t][h][0]; + + bfx8 kv_dq; + uint8_t qparam_offset_bytes; + __half2* qparam_k_src; + __half2* qparam_v_src; + if (qparam_k_ptr) { + // read from standalone qparam tensor + qparam_offset_bytes = 0; + auto idx = physical_t * N_KVH + h; + qparam_k_src = reinterpret_cast<__half2*>(&qparam_k_ptr[idx]); + qparam_v_src = reinterpret_cast<__half2*>(&qparam_v_ptr[idx]); + } else { + // read from first row + qparam_offset_bytes = 4; + qparam_k_src = reinterpret_cast<__half2*>(&row_k[0]); + qparam_v_src = reinterpret_cast<__half2*>(&row_v[0]); + } + // Assert the quantized row dim is as expected + CUDA_KERNEL_ASSERT(D_H_q - D_H == qparam_offset_bytes); + if (4 * threadIdx.x >= D_H) { + continue; + } + // each thread reads 4 x 8 bits + + uint64_t kq = *reinterpret_cast( + &row_k[threadIdx.x * 4 + qparam_offset_bytes]); + uint64_t vq = *reinterpret_cast( + &row_v[threadIdx.x * 4 + qparam_offset_bytes]); + + uint64_t packed = kq | (vq << 32); + + kv_dq = dequantize_packed_fp8(packed, *qparam_k_src, *qparam_v_src); + + // now, write our outputs + auto* row_k_dq = &cache_K_dq[0][physical_t][h][0]; + auto* row_v_dq = &cache_V_dq[0][physical_t][h][0]; + // each thread writes 4 elements of type bf16 + *reinterpret_cast(&row_k_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[0]); + *reinterpret_cast(&row_v_dq[4 * threadIdx.x]) = + *reinterpret_cast(&kv_dq.vals[2]); + } +} std::tuple dequantize_fp8_cache( at::Tensor cache_K, at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v) { + std::optional qparam_v, + std::optional block_tables, + int64_t page_size) { TORCH_CHECK(cache_K.is_cuda()); TORCH_CHECK(cache_V.is_cuda()); TORCH_CHECK(kv_seqlen.is_cuda()); - auto B = cache_K.size(0); + auto B = kv_seqlen.size(0); + // vanilla: B_KV = B, paged: B_KV = 1 + auto B_KV = cache_K.size(0); + // vanilla: MAX_T = MAX_T, paged: MAX_T = MAX_PAGE * PAGE_SIZE auto MAX_T = cache_K.size(1); auto N_KVH = cache_K.size(2); auto D_HQ = cache_K.size(3); @@ -1500,31 +1613,72 @@ std::tuple dequantize_fp8_cache( } auto D_H = (D_HQ - fp8_qparam_offset); - auto cache_K_dq = - at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); - auto cache_V_dq = - at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + // TODO: + // The below allocates Tensors that have the same shape as cache_K and cache_V + // to store their dequantize results. For paged KV cache, this can be a bit + // inefficient because it has the shape of [1 x (MAX_PAGES * PAGE_SIZE) x + // N_KVH x D_H] to accommodate pages globally across batch instances, and + // if we have very large MAX_PAGES then we are essentially allocating a very + // huge Tensor here. The benefit is that the following users of this + // dequantized results can reuse the existing block_tables to access their + // elements. 
If we want to be more efficient, there are two possible + // approaches: (1) Allocate a shorter Tensor here and store the dequantize + // results in a more compact manner, but that requires creating a new + // block_tables here and making sure the following users all use the + // correct block_tables. (2) From outside, keep a persistent buffer that has a + // matching shape with the original paged KV and feed the same buffer + // into this function at every layer to reuse it and prevent allocation. + auto cache_K_dq = at::empty( + {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); + auto cache_V_dq = at::empty( + {B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16)); if (B == 0) { return {cache_K_dq, cache_V_dq}; } + int32_t* block_tables_ptr = nullptr; + int32_t block_tables_b_stride = 0; + if (block_tables.has_value()) { + block_tables_ptr = static_cast(block_tables.value().data_ptr()); + block_tables_b_stride = block_tables.value().stride(0); + } + constexpr int32_t kMaxBlocks = 256; dim3 blocks(B, std::max(1, kMaxBlocks / B)); dim3 threads(kThreadsPerWarp, kWarpsPerBlock); - dequantize_fp8_cache_kernel<<< - blocks, - threads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - cache_K.packed_accessor64(), - cache_V.packed_accessor64(), - kv_seqlen.packed_accessor32(), - cache_K_dq.packed_accessor64(), - cache_V_dq.packed_accessor64(), - qparam_k_ptr, - qparam_v_ptr); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + if (block_tables_ptr == nullptr) { + dequantize_fp8_cache_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64(), + qparam_k_ptr, + qparam_v_ptr); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + dequantize_fp8_cache_kernel_paged<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_K.packed_accessor64(), + cache_V.packed_accessor64(), + kv_seqlen.packed_accessor32(), + cache_K_dq.packed_accessor64(), + cache_V_dq.packed_accessor64(), + qparam_k_ptr, + qparam_v_ptr, + block_tables_ptr, + block_tables_b_stride, + page_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } return {cache_K_dq, cache_V_dq}; } @@ -1606,7 +1760,9 @@ std::tuple dequantize_fp8_cache( at::Tensor cache_V, at::Tensor kv_seqlen, std::optional qparam_k, - std::optional qparam_v) { + std::optional qparam_v, + std::optional block_tables, + int64_t page_size) { throw std::runtime_error( "CUDA version is older than 12.0"); // requires CUDA>=12 } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip index 94c9dbe13..034f66fa6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip @@ -15,6 +15,7 @@ #if !defined(USE_ROCM) #include #endif + #include #if defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip index b57563df5..066a84c6d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip @@ -18,6 +18,7 @@ #if !defined(USE_ROCM) #include #endif + #include #if defined(USE_ROCM) @@ -128,13 +129,33 @@ static const 
std::unordered_map< fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, // EMU 1.6 Shapes. {{1536, 3584, 3584}, - fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{8192, 9728, 3584}, - fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, {{8192, 3584, 9728}, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5}, {{8192, 3584, 3584}, - fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{4096, 3584, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{768, 3584, 3584}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{4096, 9728, 3584}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{4096, 3584, 9728}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{7200, 3584, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{7200, 9728, 3584}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{7200, 3584, 9728}, + fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, + {{3600, 3584, 3584}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + {{3600, 9728, 3584}, + fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + {{3600, 3584, 9728}, + fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, // Pro Shapes. {{32768, 128, 8192}, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip index 67966a43c..4e925aed2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip @@ -15,6 +15,7 @@ #if !defined(USE_ROCM) #include #endif + #include #if defined(USE_ROCM) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..acf927c05 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 128 != 0) || (N % 32 != 0) || (K % 128 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip new file mode 100644 index 000000000..b47c082fd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 256 != 0) || (N % 256 != 0) || (K % 64 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h index db38343dc..146450ec1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h @@ -231,3 +231,21 @@ fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y); + +// Decent mid-size kernel. +at::Tensor +fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +// Kernel for small but not too small batch sizes. +at::Tensor +fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 024170a68..21e858e1a 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -7,7 +7,7 @@ # pyre-strict import math -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Sequence, Tuple import torch @@ -1113,3 +1113,18 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None _setup() + + +@torch.library.register_fake("fbgemm::lengths_range") +def lengths_range_abstract( + lengths: Tensor, + output_shape: Optional[Sequence[int]] = None, +) -> Tensor: + torch._check(lengths.dim() == 1, lambda: "lengths must be a 1D tensor") + output_size = 0 + if output_shape is not None: + output_size = math.prod(output_shape) + else: + ctx = torch.library.get_ctx() + output_size = ctx.new_dynamic_size() + return lengths.new_empty([output_size], dtype=lengths.dtype) diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py index 365aebbfc..afb931bb3 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py @@ -68,6 +68,19 @@ def get(self, name: str) -> int: return self.config[name] +def sparse_type_to_int(sparse_type: "SparseType") -> int: + return { + SparseType.FP32.value: 0, + SparseType.FP16.value: 1, + SparseType.INT8.value: 2, + SparseType.INT4.value: 3, + SparseType.INT2.value: 4, + SparseType.BF16.value: 5, + SparseType.FP8.value: 6, + SparseType.MX4.value: 7, + }[sparse_type.value] + + @enum.unique class SparseType(enum.Enum): FP32 = "fp32" @@ -104,16 +117,7 @@ def from_int(ty: int) -> "SparseType": raise ValueError(f"Unsupported sparse type: {ty}") def as_int(self) -> int: - return { - SparseType.FP32.value: 0, - SparseType.FP16.value: 1, - SparseType.INT8.value: 2, - SparseType.INT4.value: 3, - SparseType.INT2.value: 4, - SparseType.BF16.value: 5, - SparseType.FP8.value: 6, - SparseType.MX4.value: 7, - }[self.value] + return sparse_type_to_int(self) @staticmethod def from_dtype(dtype: torch.dtype, is_mx: bool = False) -> "SparseType": diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index f8b691be4..d988563ae 100644 --- 
a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -17,7 +17,7 @@ import torch # usort:skip from torch import nn, Tensor # usort:skip -from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_embedding_configs import sparse_type_to_int, SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, CacheAlgorithm, @@ -153,6 +153,26 @@ def random_quant_scaled_tensor( ) +@torch.fx.wrap +def inputs_to_device( + indices: torch.Tensor, + offsets: torch.Tensor, + per_sample_weights: Optional[torch.Tensor], + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if device.type == "meta": + return indices, offsets, per_sample_weights + + non_blocking = device.type != "cpu" + if indices.device != device: + indices = indices.to(device, non_blocking=non_blocking) + if offsets.device != device: + offsets = offsets.to(device, non_blocking=non_blocking) + if per_sample_weights is not None and per_sample_weights.device != device: + per_sample_weights = per_sample_weights.to(device, non_blocking=non_blocking) + return indices, offsets, per_sample_weights + + # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized. class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module): """ @@ -756,6 +776,10 @@ def forward( self.weight_initialized ), "weight needs to be initialized before forward function" + indices, offsets, per_sample_weights = inputs_to_device( + indices, offsets, per_sample_weights, self.bounds_check_warning.device + ) + # First bound check: check if the indices/offsets are within the boundary # of the original embedding rows before pruning. # Note that this is only applied when we enable pruning (if the perf becomes @@ -919,6 +943,7 @@ def reset_embedding_spec_location( for spec in self.embedding_specs ] + @torch.jit.export def recompute_module_buffers(self) -> None: """ Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets(). 
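For reference, the `inputs_to_device` helper introduced in the hunk above routes CPU-resident lookup inputs onto the module's device before the bounds check and the lookup kernel run, skipping meta tensors (which carry no storage) and using non-blocking copies for non-CPU targets. A minimal standalone sketch of the same pattern, using only stock PyTorch; the `move_tbe_inputs` name and the sample tensors below are illustrative and not part of this patch:

import torch
from typing import Optional, Tuple

def move_tbe_inputs(
    indices: torch.Tensor,
    offsets: torch.Tensor,
    per_sample_weights: Optional[torch.Tensor],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Meta tensors have no storage, so there is nothing to copy.
    if device.type == "meta":
        return indices, offsets, per_sample_weights
    # Host-to-device copies may overlap with compute; copies onto CPU must block.
    non_blocking = device.type != "cpu"
    if indices.device != device:
        indices = indices.to(device, non_blocking=non_blocking)
    if offsets.device != device:
        offsets = offsets.to(device, non_blocking=non_blocking)
    if per_sample_weights is not None and per_sample_weights.device != device:
        per_sample_weights = per_sample_weights.to(device, non_blocking=non_blocking)
    return indices, offsets, per_sample_weights

# Usage: lookup inputs built on the host are moved once, up front.
indices = torch.tensor([1, 4, 2], dtype=torch.int32)
offsets = torch.tensor([0, 2, 3], dtype=torch.int32)
indices, offsets, _ = move_tbe_inputs(indices, offsets, None, torch.device("cpu"))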
@@ -931,7 +956,7 @@ def recompute_module_buffers(self) -> None: ): return - weights_tys_int = [e[3].as_int() for e in self.embedding_specs] + weights_tys_int = [sparse_type_to_int(e[3]) for e in self.embedding_specs] self.weights_tys = torch.tensor( [weights_tys_int[t] for t in self.feature_table_map], device=self.current_device, @@ -944,8 +969,9 @@ def recompute_module_buffers(self) -> None: dtype=torch.int64, ) dims = [e[2] for e in self.embedding_specs] - D_offsets_list = [dims[t] for t in self.feature_table_map] - D_offsets_list = [0] + list(accumulate(D_offsets_list)) + D_offsets_list = [0] + for t in self.feature_table_map: + D_offsets_list.append(dims[t] + D_offsets_list[-1]) self.D_offsets = torch.tensor( D_offsets_list, device=self.current_device, dtype=torch.int32 ) @@ -975,6 +1001,9 @@ def recompute_module_buffers(self) -> None: self.table_wise_cache_miss = torch.empty_like( self.table_wise_cache_miss, device=self.current_device ) + self.weights_uvm = torch.empty_like( + self.weights_uvm, device=self.current_device + ) def _apply_split( self, diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 04fac4a62..e2d35072f 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -659,6 +659,9 @@ def __init__( self.l2_num_cache_lookups_stats_name: str = ( f"l2_cache.perf.get.tbe_id{tbe_unique_id}.num_lookups" ) + self.l2_num_cache_evictions_stats_name: str = ( + f"l2_cache.perf.tbe_id{tbe_unique_id}.num_l2_cache_evictions" + ) self.l2_cache_free_mem_stats_name: str = ( f"l2_cache.mem.tbe_id{tbe_unique_id}.free_mem_bytes" ) @@ -686,6 +689,7 @@ def __init__( self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name) # pyre-ignore self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name) + self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name) self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name) self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name) @@ -1744,13 +1748,15 @@ def _report_ssd_io_stats(self) -> None: self.step, self.stats_reporter.report_interval # pyre-ignore ) - if len(ssd_io_duration) != 3: - logging.error("ssd io duration should have 3 elements") + if len(ssd_io_duration) != 5: + logging.error("ssd io duration should have 5 elements") return ssd_read_dur_us = ssd_io_duration[0] - fwd_ssd_write_dur_us = ssd_io_duration[1] - bwd_ssd_write_dur_us = ssd_io_duration[2] + fwd_rocksdb_read_dur = ssd_io_duration[1] + fwd_l1_eviction_dur = ssd_io_duration[2] + bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3] + flush_write_dur = ssd_io_duration[4] # pyre-ignore self.stats_reporter.report_duration( @@ -1762,15 +1768,29 @@ def _report_ssd_io_stats(self) -> None: # pyre-ignore self.stats_reporter.report_duration( iteration_step=self.step, - event_name="ssd.io_duration.fwd_write_us", - duration_ms=fwd_ssd_write_dur_us, + event_name="ssd.io_duration.write.fwd_rocksdb_read_us", + duration_ms=fwd_rocksdb_read_dur, + time_unit="us", + ) + # pyre-ignore + self.stats_reporter.report_duration( + iteration_step=self.step, + event_name="ssd.io_duration.write.fwd_l1_eviction_us", + duration_ms=fwd_l1_eviction_dur, + time_unit="us", + ) + # pyre-ignore + self.stats_reporter.report_duration( + iteration_step=self.step, + event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us", + duration_ms=bwd_l1_cnflct_miss_write_back_dur, time_unit="us", ) # pyre-ignore self.stats_reporter.report_duration( 
iteration_step=self.step, - event_name="ssd.io_duration.bwd_write_us", - duration_ms=bwd_ssd_write_dur_us, + event_name="ssd.io_duration.write.flush_write_us", + duration_ms=flush_write_dur, time_unit="us", ) @@ -1830,8 +1850,8 @@ def _report_l2_cache_perf_stats(self) -> None: self.step, stats_reporter.report_interval # pyre-ignore ) - if len(l2_cache_perf_stats) != 11: - logging.error("l2 perf stats should have 11 elements") + if len(l2_cache_perf_stats) != 15: + logging.error("l2 perf stats should have 15 elements") return num_cache_misses = l2_cache_perf_stats[0] @@ -1840,12 +1860,17 @@ def _report_l2_cache_perf_stats(self) -> None: get_cache_lookup_total_duration = l2_cache_perf_stats[3] get_cache_lookup_wait_filling_thread_duration = l2_cache_perf_stats[4] get_weights_fillup_total_duration = l2_cache_perf_stats[5] - total_cache_update_duration = l2_cache_perf_stats[6] - get_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[7] - set_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[8] + get_cache_memcpy_duration = l2_cache_perf_stats[6] + total_cache_update_duration = l2_cache_perf_stats[7] + get_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[8] + set_tensor_copy_for_cache_update_duration = l2_cache_perf_stats[9] + num_l2_evictions = l2_cache_perf_stats[10] - l2_cache_free_bytes = l2_cache_perf_stats[9] - l2_cache_capacity = l2_cache_perf_stats[10] + l2_cache_free_bytes = l2_cache_perf_stats[11] + l2_cache_capacity = l2_cache_perf_stats[12] + + set_cache_lock_wait_duration = l2_cache_perf_stats[13] + get_cache_lock_wait_duration = l2_cache_perf_stats[14] stats_reporter.report_data_amount( iteration_step=self.step, @@ -1857,6 +1882,11 @@ def _report_l2_cache_perf_stats(self) -> None: event_name=self.l2_num_cache_lookups_stats_name, data_bytes=num_lookups, ) + stats_reporter.report_data_amount( + iteration_step=self.step, + event_name=self.l2_num_cache_evictions_stats_name, + data_bytes=num_l2_evictions, + ) stats_reporter.report_data_amount( iteration_step=self.step, event_name=self.l2_cache_capacity_stats_name, @@ -1892,6 +1922,12 @@ def _report_l2_cache_perf_stats(self) -> None: duration_ms=get_weights_fillup_total_duration, time_unit="us", ) + stats_reporter.report_duration( + iteration_step=self.step, + event_name="l2_cache.perf.get.cache_memcpy_duration_us", + duration_ms=get_cache_memcpy_duration, + time_unit="us", + ) stats_reporter.report_duration( iteration_step=self.step, event_name="l2_cache.perf.total.cache_update_duration_us", @@ -1911,6 +1947,19 @@ def _report_l2_cache_perf_stats(self) -> None: time_unit="us", ) + stats_reporter.report_duration( + iteration_step=self.step, + event_name="l2_cache.perf.get.cache_lock_wait_duration_us", + duration_ms=get_cache_lock_wait_duration, + time_unit="us", + ) + stats_reporter.report_duration( + iteration_step=self.step, + event_name="l2_cache.perf.set.cache_lock_wait_duration_us", + duration_ms=set_cache_lock_wait_duration, + time_unit="us", + ) + # pyre-ignore def _recording_to_timer( self, timer: Optional[AsyncSeriesTimer], **kwargs: Any diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py index fbccebd9a..19927a5ea 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/utils/requests.py @@ -11,6 +11,7 @@ from typing import List, Optional, Tuple import numpy as np +import numpy.typing as npt import torch # pyre-fixme[21]: Could not find name `default_rng` in `numpy.random` (stubbed). 
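Regarding the `numpy.typing` import added above: `npt.NDArray` is the public alias for annotating `np.ndarray` values and can optionally be parameterized by dtype. A short, purely illustrative sketch of that annotation style; the function below is hypothetical and not the benchmark helper being retyped in this hunk:

import numpy as np
import numpy.typing as npt

def normal_int_samples(mu: float, sigma: float, size: int) -> npt.NDArray[np.int64]:
    # Draw normally distributed values and round them to integers, mirroring
    # the kind of stats-driven length generation used in TBE request benchmarks.
    rng = np.random.default_rng(0)
    return rng.normal(mu, sigma, size).round().astype(np.int64)

lengths = normal_int_samples(mu=20.0, sigma=5.0, size=8)
assert lengths.dtype == np.int64 and lengths.shape == (8,)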
@@ -135,7 +136,7 @@ def generate_int_data_from_stats( sigma: int, size: int, distribution: str, -) -> np.ndarray: +) -> npt.NDArray: """ Generate integer data based on stats """ diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h index 4bbe877f0..1cfc1a987 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h @@ -48,7 +48,7 @@ std::vector permute_multi_embedding_function_cpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute); std::vector permute_multi_embedding_function_meta( @@ -64,7 +64,7 @@ std::vector permute_multi_embedding_function_gpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute); std::tuple, std::vector, std::vector> diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h index 4965c3731..80955120f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h @@ -39,7 +39,9 @@ class CacheLibCache { int64_t max_D_; }; - explicit CacheLibCache(const CacheConfig& cache_config); + explicit CacheLibCache( + const CacheConfig& cache_config, + int64_t unique_tbe_id); std::unique_ptr initializeCacheLib(const CacheConfig& config); @@ -48,7 +50,7 @@ class CacheLibCache { /// Find the stored embeddings from a given embedding indices, aka key /// - /// @param key embedding index to look up + /// @param key_tensor embedding index(tensor with only one element) to look up /// /// @return an optional value, return none on cache misses, if cache hit /// return a pointer to the cachelib underlying storage of associated @@ -57,7 +59,7 @@ class CacheLibCache { /// @note that this is not thread safe, caller needs to make sure the data is /// fully processed before doing cache insertion, otherwise the returned space /// might be overwritten if cache is full - std::optional get(int64_t key); + folly::Optional get(const at::Tensor& key_tensor); /// Cachelib wrapper specific hash function /// @@ -74,9 +76,18 @@ class CacheLibCache { /// deterministic mapping from a embedding index to a specific pool id facebook::cachelib::PoolId get_pool_id(int64_t key); + /// update the LRU queue in cachelib, this is detached from cache->find() + /// so that we could boost up the lookup perf without worrying about LRU queue + /// contention + /// + /// @param read_handles the read handles that record what cache item has been + /// accessed + void batchMarkUseful(const std::vector& read_handles); + /// Add an embedding index and embeddings into cachelib /// - /// @param key embedding index to insert + /// @param key_tensor embedding index(tensor with only one element) to insert + /// @param data embedding weights to insert /// /// @return true on success insertion, false on failure insertion, a failure /// insertion could happen if the refcount of bottom K items in LRU queue @@ -86,11 +97,12 @@ class CacheLibCache { /// bulk read and bluk write sequentially /// /// @note cache_->allocation will trigger eviction callback func - bool put(int64_t key, const at::Tensor& data); + bool 
put(const at::Tensor& key_tensor, const at::Tensor& data); /// iterate through all items in L2 cache, fill them in indices and weights /// respectively and return indices, weights and count /// + /// @return optional value, if cache is empty return none /// @return indices The 1D embedding index tensor, should skip on negative /// value /// @return weights The 2D tensor that each row(embeddings) is paired up with @@ -100,7 +112,8 @@ class CacheLibCache { /// /// @note this isn't thread safe, caller needs to make sure put isn't called /// while this is executed. - std::tuple get_all_items(); + folly::Optional> + get_all_items(); /// instantiate eviction related indices and weights tensors(size of ) /// for L2 eviction using the same dtype and device from and @@ -120,25 +133,28 @@ class CacheLibCache { const at::Tensor& count); /// reset slot pointer that points to the next available slot in the eviction - /// tensors + /// tensors and returns number of slots filled void reset_eviction_states(); /// get the filled indices and weights tensors from L2 eviction, could be all /// invalid if no eviction happened - folly::Optional> - get_evicted_indices_and_weights(); + folly::Optional> + get_tensors_and_reset(); /// get L2 cache utilization stats std::vector get_cache_usage(); private: const CacheConfig cache_config_; + const int64_t unique_tbe_id_; std::unique_ptr cache_; std::vector pool_ids_; std::unique_ptr admin_; - std::shared_ptr evicted_indices_ptr_{nullptr}; - std::shared_ptr evicted_weights_ptr_{nullptr}; + folly::Optional evicted_indices_opt_{folly::none}; + folly::Optional evicted_weights_opt_{folly::none}; + folly::Optional index_dtype_{folly::none}; + folly::Optional weights_dtype_{folly::none}; std::atomic eviction_row_id{0}; }; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h index 3f5ed08f2..24ad14125 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/tensor_accessor.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include #include #include @@ -472,6 +473,53 @@ template < using PackedTensorAccessor64 = GenericPackedTensorAccessor; +template +inline at::ScalarType scalar_type_for() { +#define TYPE_CASE(U, name) \ + if constexpr (std::is_same_v) { \ + return at::ScalarType::name; \ + } + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_CASE) + +#undef TYPE_CASE + + return at::ScalarType::Undefined; +} + +template +inline void check_scalar_type( + const at::TensorBase& tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + const char* const func_name, + const char* const tensor_name +#endif +) { + const auto expected_type = scalar_type_for(); + + TORCH_CHECK( + tensor.scalar_type() == expected_type || + (isQIntType(tensor.scalar_type()) && + toUnderlying(tensor.scalar_type()) == expected_type), +#ifdef FBGEMM_GPU_MEMCHECK + "[ ", + func_name, + " ]: ", +#endif + "Expected tensor ", +#ifdef FBGEMM_GPU_MEMCHECK + "'", + tensor_name, + "' ", +#endif + "to have scalar type ", + expected_type, + ", but found ", + tensor.scalar_type(), + " instead!"); +} + } // namespace fbgemm_gpu #ifdef FBGEMM_GPU_MEMCHECK @@ -521,10 +569,21 @@ pta::PackedTensorAccessor32 make_packed_tensor_accessor32( #else const at::Tensor& tensor) { #endif + TORCH_CHECK( tensor.numel() <= static_cast(std::numeric_limits::max()), "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64"); + + fbgemm_gpu::check_scalar_type( + tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + func_name, + 
ptr_name +#endif + ); + #ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( tensor, ptr_name, func_name); @@ -542,10 +601,23 @@ pta::PackedTensorAccessor64 make_packed_tensor_accessor64( const at::Tensor& tensor, const char* const ptr_name, const char* const func_name) { +#else + const at::Tensor& tensor) { +#endif + + fbgemm_gpu::check_scalar_type( + tensor +#ifdef FBGEMM_GPU_MEMCHECK + , + func_name, + ptr_name +#endif + ); + +#ifdef FBGEMM_GPU_MEMCHECK return make_generic_packed_tensor_accessor( tensor, ptr_name, func_name); #else - const at::Tensor& tensor) { return tensor.packed_accessor64(); #endif } diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp index ed8daa31b..1eba2be30 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp @@ -42,7 +42,7 @@ variable_list PermuteMultiEmbeddingOp::forward( const auto permute_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") - .typed(); + .typed(); return permute_op.call( pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); @@ -64,7 +64,7 @@ variable_list PermuteMultiEmbeddingOp::backward( const auto permute_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "") - .typed(); + .typed(); auto grad_input = permute_op.call( grad_output, permutes, out_shapes, in_shapes, in_lengths, true); grad_input.push_back(torch::autograd::Variable()); // permutes diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu index 9ef3c62a4..11948bb22 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu @@ -222,7 +222,7 @@ std::vector permute_multi_embedding_function_gpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute) { // we assume that there's at least one input tensor in the list // it should be enforced from the caller side who has the knowledge. 
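The hunks around this point drop the `int32_t` to `c10::SymInt` round-trip and pass `out_lengths` straight through as plain integers. Conceptually, those lengths describe how concatenated per-key embedding widths are split and regrouped; the following is a rough, purely conceptual Python sketch of that kind of regrouping, not the fbgemm operator's exact semantics:

import torch
from typing import List

def regroup_by_lengths(
    pooled: torch.Tensor, in_lengths: List[int], groups: List[List[int]]
) -> List[torch.Tensor]:
    # Split a [B, sum(in_lengths)] tensor into per-key column slices, then
    # concatenate the slices requested by each output group.
    slices = list(torch.split(pooled, in_lengths, dim=1))
    return [torch.cat([slices[i] for i in group], dim=1) for group in groups]

pooled = torch.arange(2 * 6, dtype=torch.float32).reshape(2, 6)
outs = regroup_by_lengths(pooled, in_lengths=[2, 1, 3], groups=[[0, 2], [1]])
assert [o.shape[1] for o in outs] == [5, 1]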
@@ -302,7 +302,7 @@ std::vector permute_multi_embedding_gpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths) { + const c10::IntArrayRef out_lengths) { return permute_multi_embedding_function_gpu( pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } @@ -314,14 +314,8 @@ std::vector regroup_keyed_tensor_gpu( const std::vector>& groups) { auto [permutes, in_shapes, out_shapes, out_lengths] = kt_regroup_arguments_gpu(pooled_embs[0], keys, lengths, groups); - std::vector out; - std::transform( - out_lengths.begin(), - out_lengths.end(), - std::back_inserter(out), - [](const int32_t v) { return c10::SymInt(v); }); return permute_multi_embedding_function_gpu( - pooled_embs, permutes, in_shapes, out_shapes, out, false); + pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp index 9ca8f7d2a..75d37c2d3 100644 --- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp @@ -19,7 +19,7 @@ std::vector permute_multi_embedding_function_cpu( const Tensor& permutes, const Tensor& /* in_shapes */, const Tensor& /* out_shapes */, - const c10::SymIntArrayRef out_lengths, + const c10::IntArrayRef out_lengths, const bool& reverse_permute) { std::vector inputs; inputs.reserve(pooled_embs.size()); @@ -171,7 +171,7 @@ std::vector permute_multi_embedding_cpu( const Tensor& permutes, const Tensor& in_shapes, const Tensor& out_shapes, - const c10::SymIntArrayRef out_lengths) { + const c10::IntArrayRef out_lengths) { return permute_multi_embedding_function_cpu( pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } @@ -321,14 +321,8 @@ std::vector regroup_keyed_tensor_cpu( const std::vector>& groups) { auto [permutes, in_shapes, out_shapes, out_lengths] = kt_regroup_arguments_cpu(pooled_embs[0], keys, lengths, groups); - std::vector out; - std::transform( - out_lengths.begin(), - out_lengths.end(), - std::back_inserter(out), - [](const int32_t v) { return c10::SymInt(v); }); return permute_multi_embedding_function_cpu( - pooled_embs, permutes, in_shapes, out_shapes, out, false); + pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false); } std::vector regroup_keyed_tensor_meta( diff --git a/fbgemm_gpu/src/placeholder.cpp b/fbgemm_gpu/src/placeholder.cpp index 16d443485..82cf3aa52 100644 --- a/fbgemm_gpu/src/placeholder.cpp +++ b/fbgemm_gpu/src/placeholder.cpp @@ -6,7 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -/* - This is placeholder code to force compilation and generation of .so file. -*/ +/// This is a placeholder source file that is used to force compilation and +/// generation of an .SO file. 
namespace fbgemm_gpu {} diff --git a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h index 897672943..0498bda96 100644 --- a/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h @@ -48,7 +48,8 @@ class EmbeddingParameterServer : public kv_db::EmbeddingKVDB { const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - const bool is_bwd = false) override { + const kv_db::RocksdbWriteMode w_mode = + kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ) override { RECORD_USER_SCOPE("EmbeddingParameterServer::set"); co_await tps_client_->set(indices, weights, count.item().toLong()); } diff --git a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp index 2510d4e81..f5002375c 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp +++ b/fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp @@ -7,6 +7,7 @@ */ #include "fbgemm_gpu/split_embeddings_cache/cachelib_cache.h" +#include #include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h" #include "fbgemm_gpu/utils/dispatch_macros.h" @@ -14,65 +15,82 @@ namespace l2_cache { using Cache = facebook::cachelib::LruAllocator; -// this is a general predictor for weights data type, might not be general -// enough for all the cases -at::ScalarType bytes_to_dtype(int num_bytes) { - switch (num_bytes) { - case 1: - return at::kByte; - case 2: - return at::kHalf; - case 4: - return at::kFloat; - case 8: - return at::kDouble; - default: - throw std::runtime_error("Unsupported dtype"); - } -} - -CacheLibCache::CacheLibCache(const CacheConfig& cache_config) +CacheLibCache::CacheLibCache( + const CacheConfig& cache_config, + int64_t unique_tbe_id) : cache_config_(cache_config), + unique_tbe_id_(unique_tbe_id), cache_(initializeCacheLib(cache_config_)), admin_(createCacheAdmin(*cache_)) { for (size_t i = 0; i < cache_config_.num_shards; i++) { pool_ids_.push_back(cache_->addPool( fmt::format("shard_{}", i), - cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards)); + cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards, + std::set{}, + Cache::MMConfig{ + 0, /* promote on every access*/ + true, /*enable promotion on write*/ + true /*enable promotion on read*/})); } } std::unique_ptr CacheLibCache::initializeCacheLib( const CacheConfig& config) { - auto eviction_cb = - [this](const facebook::cachelib::LruAllocator::RemoveCbData& data) { - FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( - evicted_weights_ptr_->scalar_type(), "l2_eviction_handling", [&] { - if (data.context == - facebook::cachelib::RemoveContext::kEviction) { - auto indices_data_ptr = - evicted_indices_ptr_->data_ptr(); - auto weights_data_ptr = - evicted_weights_ptr_->data_ptr(); - auto row_id = eviction_row_id++; - auto weight_dim = evicted_weights_ptr_->size(1); - const auto key_ptr = - reinterpret_cast(data.item.getKey().data()); - indices_data_ptr[row_id] = *key_ptr; + auto eviction_cb = [this]( + const facebook::cachelib::LruAllocator::RemoveCbData& + data) { + if (evicted_weights_opt_.has_value()) { + FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( + evicted_weights_opt_->scalar_type(), "l2_eviction_handling", [&] { + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + evicted_indices_opt_->scalar_type(), + "l2_eviction_handling", + [&] { + using index_t = scalar_t; + if (data.context == + 
facebook::cachelib::RemoveContext::kEviction) { + auto indices_data_ptr = + evicted_indices_opt_->data_ptr(); + auto weights_data_ptr = + evicted_weights_opt_->data_ptr(); + auto row_id = eviction_row_id++; + auto weight_dim = evicted_weights_opt_->size(1); + const auto key_ptr = reinterpret_cast( + data.item.getKey().data()); + indices_data_ptr[row_id] = *key_ptr; - std::copy( - reinterpret_cast(data.item.getMemory()), - reinterpret_cast(data.item.getMemory()) + - weight_dim, - &weights_data_ptr[row_id * weight_dim]); // dst_start - } - }); - }; + std::copy( + reinterpret_cast(data.item.getMemory()), + reinterpret_cast( + data.item.getMemory()) + + weight_dim, + &weights_data_ptr[row_id * weight_dim]); // dst_start + } + }); + }); + } + }; Cache::Config cacheLibConfig; + int64_t rough_num_items = + cache_config_.cache_size_bytes / cache_config_.item_size_bytes; + unsigned int bucket_power = std::log(rough_num_items) / std::log(2) + 1; + // 15 here is a magic number between 10 and 20 + unsigned int lock_power = + std::log(cache_config_.num_shards * 15) / std::log(2) + 1; + XLOG(INFO) << fmt::format( + "[TBE_ID{}] Setting up Cachelib for L2 cache, capacity: {}GB, " + "item_size: {}B, max_num_items: {}, bucket_power: {}, lock_power: {}", + unique_tbe_id_, + config.cache_size_bytes / 1024 / 1024 / 1024, + cache_config_.item_size_bytes, + rough_num_items, + bucket_power, + lock_power); cacheLibConfig.setCacheSize(static_cast(config.cache_size_bytes)) .setRemoveCallback(eviction_cb) .setCacheName("TBEL2Cache") - .setAccessConfig({25 /* bucket power */, 10 /* lock power */}) + .setAccessConfig({bucket_power, lock_power}) .setFullCoredump(false) .validate(); return std::make_unique(cacheLibConfig); @@ -86,14 +104,21 @@ std::unique_ptr CacheLibCache::createCacheAdmin( cache, std::move(adminConfig)); } -std::optional CacheLibCache::get(int64_t key) { - auto key_str = - folly::StringPiece(reinterpret_cast(&key), sizeof(int64_t)); - auto item = cache_->find(key_str); - if (!item) { - return std::nullopt; - } - return const_cast(item->getMemory()); +folly::Optional CacheLibCache::get(const at::Tensor& key_tensor) { + folly::Optional res; + FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "get", [&] { + using index_t = scalar_t; + auto key = *(key_tensor.data_ptr()); + auto key_str = folly::StringPiece( + reinterpret_cast(&key), sizeof(index_t)); + auto item = cache_->find(key_str); + if (!item) { + res = folly::none; + return; + } + res = const_cast(item->getMemory()); + }); + return res; } size_t CacheLibCache::get_shard_id(int64_t key) { @@ -104,48 +129,91 @@ facebook::cachelib::PoolId CacheLibCache::get_pool_id(int64_t key) { return pool_ids_[get_shard_id(key)]; } -bool CacheLibCache::put(int64_t key, const at::Tensor& data) { - auto key_str = - folly::StringPiece(reinterpret_cast(&key), sizeof(int64_t)); - auto item = cache_->allocate(get_pool_id(key), key_str, data.nbytes()); - if (!item) { - XLOG(ERR) << fmt::format("Failed to allocate item {} in cache, skip", key); - return false; +void CacheLibCache::batchMarkUseful( + const std::vector& read_handles) { + if (read_handles.empty()) { + return; + } + for (auto& handle : read_handles) { + if (handle) { + cache_->markUseful(handle, facebook::cachelib::AccessMode::kRead); + } + } +} + +bool CacheLibCache::put(const at::Tensor& key_tensor, const at::Tensor& data) { + if (!index_dtype_.has_value()) { + index_dtype_ = key_tensor.scalar_type(); + } + if (!weights_dtype_.has_value()) { + weights_dtype_ = data.scalar_type(); } - 
std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes()); - cache_->insertOrReplace(std::move(item)); - return true; + bool res; + FBGEMM_DISPATCH_INTEGRAL_TYPES(key_tensor.scalar_type(), "put", [&] { + using index_t = scalar_t; + auto key = *(key_tensor.data_ptr()); + auto key_str = folly::StringPiece( + reinterpret_cast(&key), sizeof(index_t)); + auto item = cache_->findToWrite(key_str); + if (!item) { + auto alloc_item = + cache_->allocate(get_pool_id(key), key_str, data.nbytes()); + if (!alloc_item) { + XLOG(ERR) << fmt::format( + "[TBE_ID{}]Failed to allocate item {} in cache, skip", + unique_tbe_id_, + key); + res = false; + return; + } + std::memcpy(alloc_item->getMemory(), data.data_ptr(), data.nbytes()); + cache_->insertOrReplace(std::move(alloc_item)); + } else { + std::memcpy(item->getMemory(), data.data_ptr(), data.nbytes()); + } + res = true; + }); + return res; } -std::tuple CacheLibCache::get_all_items() { +folly::Optional> +CacheLibCache::get_all_items() { + if (!index_dtype_.has_value() || !weights_dtype_.has_value()) { + return folly::none; + } int total_num_items = 0; for (auto& pool_id : pool_ids_) { total_num_items += cache_->getPoolStats(pool_id).numItems(); } auto weight_dim = cache_config_.max_D_; - auto weights_dtype = - bytes_to_dtype(cache_config_.item_size_bytes / weight_dim); auto indices = at::empty( - total_num_items, at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + total_num_items, + at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU)); auto weights = at::empty( {total_num_items, weight_dim}, - at::TensorOptions().dtype(weights_dtype).device(at::kCPU)); + at::TensorOptions().dtype(weights_dtype_.value()).device(at::kCPU)); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "get_all_items", [&] { - auto indices_data_ptr = indices.data_ptr(); - auto weights_data_ptr = weights.data_ptr(); - int64_t item_idx = 0; - for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) { - const auto key_ptr = - reinterpret_cast(itr->getKey().data()); - indices_data_ptr[item_idx] = *key_ptr; - std::copy( - reinterpret_cast(itr->getMemory()), - reinterpret_cast(itr->getMemory()) + weight_dim, - &weights_data_ptr[item_idx * weight_dim]); // dst_start - item_idx++; - } - CHECK_EQ(total_num_items, item_idx); + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "get_all_items", [&] { + using index_t = scalar_t; + auto indices_data_ptr = indices.data_ptr(); + auto weights_data_ptr = weights.data_ptr(); + int64_t item_idx = 0; + for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) { + const auto key_ptr = + reinterpret_cast(itr->getKey().data()); + indices_data_ptr[item_idx] = *key_ptr; + std::copy( + reinterpret_cast(itr->getMemory()), + reinterpret_cast(itr->getMemory()) + + weight_dim, + &weights_data_ptr[item_idx * weight_dim]); // dst_start + item_idx++; + } + CHECK_EQ(total_num_items, item_idx); + }); }); return std::make_tuple( indices, @@ -160,28 +228,38 @@ void CacheLibCache::init_tensor_for_l2_eviction( const at::Tensor& weights, const at::Tensor& count) { auto num_lookups = count.item(); - evicted_indices_ptr_ = std::make_shared( + evicted_indices_opt_ = at::ones( num_lookups, at::TensorOptions().device(indices.device()).dtype(indices.dtype())) * - -1); - evicted_weights_ptr_ = std::make_shared(at::empty( + -1; + evicted_weights_opt_ = at::empty( {num_lookups, weights.size(1)}, - at::TensorOptions().device(weights.device()).dtype(weights.dtype()))); + 
at::TensorOptions().device(weights.device()).dtype(weights.dtype())); } void CacheLibCache::reset_eviction_states() { + evicted_indices_opt_.reset(); + evicted_weights_opt_.reset(); eviction_row_id = 0; + return; } -folly::Optional> -CacheLibCache::get_evicted_indices_and_weights() { - if (evicted_indices_ptr_) { - assert(evicted_weights_ptr_ != nullptr); - return std::make_pair(*evicted_indices_ptr_, *evicted_weights_ptr_); - } else { - return folly::none; +folly::Optional> +CacheLibCache::get_tensors_and_reset() { + folly::Optional> ret = + folly::none; + if (evicted_indices_opt_.has_value()) { + assert(evicted_weights_opt_.has_value()); + if (eviction_row_id > 0) { + ret = std::make_tuple( + evicted_indices_opt_.value(), + evicted_weights_opt_.value(), + at::tensor(eviction_row_id, evicted_indices_opt_->options())); + } } + reset_eviction_states(); + return ret; } std::vector CacheLibCache::get_cache_usage() { diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index 16d46e2e3..83a06d78a 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -274,10 +274,10 @@ transpose_embedding_input( } AT_DISPATCH_INDEX_TYPES( - infos.scalar_type(), "transpose_embedding_input1", [&] { + infos.scalar_type(), "transpose_embedding_input_1", [&] { using info_t = index_t; AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), "transpose_embedding_input2", [&] { + indices.scalar_type(), "transpose_embedding_input_2", [&] { if (!is_index_select) { if (!nobag) { INVOKE_LINEARIZE_INDEX_KERNEL(int32_t, false); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 6507d2fa6..fdc91b0e9 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -29,10 +29,11 @@ inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) { }; // namespace -std::tuple tensor_copy( +QueueItem tensor_copy( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) { + const at::Tensor& count, + kv_db::RocksdbWriteMode mode) { auto num_sets = get_maybe_uvm_scalar(count); auto new_indices = at::empty( num_sets, at::TensorOptions().device(at::kCPU).dtype(indices.dtype())); @@ -42,23 +43,28 @@ std::tuple tensor_copy( auto new_count = at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( - weights.scalar_type(), "cache_memcpy", [&] { - auto indices_addr = indices.data_ptr(); - auto new_indices_addr = new_indices.data_ptr(); - std::copy( - indices_addr, - indices_addr + num_sets, - new_indices_addr); // dst_start - - auto weights_addr = weights.data_ptr(); - auto new_weightss_addr = new_weights.data_ptr(); - std::copy( - weights_addr, - weights_addr + num_sets * weights.size(1), - new_weightss_addr); // dst_start + weights.scalar_type(), "tensor_copy", [&] { + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "tensor_copy", [&] { + using index_t = scalar_t; + auto indices_addr = indices.data_ptr(); + auto new_indices_addr = new_indices.data_ptr(); + std::copy( + indices_addr, + indices_addr + num_sets, + new_indices_addr); // dst_start + + auto weights_addr = weights.data_ptr(); + auto new_weightss_addr = 
new_weights.data_ptr(); + std::copy( + weights_addr, + weights_addr + num_sets * weights.size(1), + new_weightss_addr); // dst_start + }); }); *new_count.data_ptr() = num_sets; - return std::make_tuple(new_indices, new_weights, new_count); + return QueueItem{new_indices, new_weights, new_count, mode}; } EmbeddingKVDB::EmbeddingKVDB( @@ -78,7 +84,8 @@ EmbeddingKVDB::EmbeddingKVDB( cache_config.num_shards = num_shards_; cache_config.item_size_bytes = max_D_ * ele_size_bytes; cache_config.max_D_ = max_D_; - l2_cache_ = std::make_unique(cache_config); + l2_cache_ = + std::make_unique(cache_config, unique_id); } else { l2_cache_ = nullptr; } @@ -94,21 +101,24 @@ EmbeddingKVDB::EmbeddingKVDB( if (stop_) { return; } - auto& indices = std::get<0>(*filling_item_ptr); - auto& weights = std::get<1>(*filling_item_ptr); - auto& count = std::get<2>(*filling_item_ptr); + auto& indices = filling_item_ptr->indices; + auto& weights = filling_item_ptr->weights; + auto& count = filling_item_ptr->count; + auto& rocksdb_wmode = filling_item_ptr->mode; if (l2_cache_) { - auto evicted_pairs_opt = set_cache(indices, weights, count); - if (evicted_pairs_opt.has_value()) { - auto& evicted_indices = evicted_pairs_opt->first; - auto& evicted_weights = evicted_pairs_opt->second; - - folly::coro::blockingWait( - set_kv_db_async(evicted_indices, evicted_weights, count)); + auto evicted_tuples_opt = set_cache(indices, weights, count); + if (evicted_tuples_opt.has_value()) { + auto& evicted_indices = std::get<0>(evicted_tuples_opt.value()); + auto& evicted_weights = std::get<1>(evicted_tuples_opt.value()); + auto& evicted_count = std::get<2>(evicted_tuples_opt.value()); + + folly::coro::blockingWait(set_kv_db_async( + evicted_indices, evicted_weights, evicted_count, rocksdb_wmode)); } } else { - folly::coro::blockingWait(set_kv_db_async(indices, weights, count)); + folly::coro::blockingWait( + set_kv_db_async(indices, weights, count, rocksdb_wmode)); } weights_to_fill_queue_.dequeue(); @@ -124,11 +134,17 @@ EmbeddingKVDB::~EmbeddingKVDB() { void EmbeddingKVDB::flush() { wait_util_filling_work_done(); if (l2_cache_) { - auto tensor_tuple = l2_cache_->get_all_items(); - auto& indices = std::get<0>(tensor_tuple); - auto& weights = std::get<1>(tensor_tuple); - auto& count = std::get<2>(tensor_tuple); - folly::coro::blockingWait(set_kv_db_async(indices, weights, count)); + auto tensor_tuple_opt = l2_cache_->get_all_items(); + if (!tensor_tuple_opt.has_value()) { + XLOG(INFO) << "[TBE_ID" << unique_id_ + << "]no items exist in L2 cache, flush nothing"; + return; + } + auto& indices = std::get<0>(tensor_tuple_opt.value()); + auto& weights = std::get<1>(tensor_tuple_opt.value()); + auto& count = std::get<2>(tensor_tuple_opt.value()); + folly::coro::blockingWait(set_kv_db_async( + indices, weights, count, kv_db::RocksdbWriteMode::FLUSH)); } } @@ -138,6 +154,7 @@ void EmbeddingKVDB::get_cuda( const at::Tensor& count) { auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::get_cuda ##"); + check_tensor_type_consistency(indices, weights); // take reference to self to avoid lifetime issues. auto self = shared_from_this(); std::function* functor = @@ -158,6 +175,7 @@ void EmbeddingKVDB::set_cuda( const bool is_bwd) { auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::set_cuda ##"); + check_tensor_type_consistency(indices, weights); // take reference to self to avoid lifetime issues. 
auto self = shared_from_this(); std::function* functor = new std::function([=]() { @@ -175,11 +193,12 @@ void EmbeddingKVDB::set_cuda( std::vector EmbeddingKVDB::get_l2cache_perf( const int64_t step, const int64_t interval) { - std::vector ret(11, 0); // num metrics + std::vector ret(15, 0); // num metrics if (step > 0 && step % interval == 0) { int reset_val = 0; auto num_cache_misses = num_cache_misses_.exchange(reset_val); auto num_lookups = num_lookups_.exchange(reset_val); + auto num_evictions = num_evictions_.exchange(reset_val); auto get_total_duration = get_total_duration_.exchange(reset_val); auto get_cache_lookup_total_duration = get_cache_lookup_total_duration_.exchange(reset_val); @@ -187,30 +206,47 @@ std::vector EmbeddingKVDB::get_l2cache_perf( get_cache_lookup_wait_filling_thread_duration_.exchange(reset_val); auto get_weights_fillup_total_duration = get_weights_fillup_total_duration_.exchange(reset_val); + auto get_cache_memcpy_duration = + get_cache_memcpy_duration_.exchange(reset_val); + auto total_cache_update_duration = total_cache_update_duration_.exchange(reset_val); auto get_tensor_copy_for_cache_update_dur = get_tensor_copy_for_cache_update_.exchange(reset_val); auto set_tensor_copy_for_cache_update_dur = set_tensor_copy_for_cache_update_.exchange(reset_val); + + auto set_cache_lock_wait_duration = + set_cache_lock_wait_duration_.exchange(reset_val); + auto get_cache_lock_wait_duration = + get_cache_lock_wait_duration_.exchange(reset_val); + ret[0] = (double(num_cache_misses) / interval); ret[1] = (double(num_lookups) / interval); ret[2] = (double(get_total_duration) / interval); ret[3] = (double(get_cache_lookup_total_duration) / interval); ret[4] = (double(get_cache_lookup_wait_filling_thread_duration) / interval); ret[5] = (double(get_weights_fillup_total_duration) / interval); - ret[6] = (double(total_cache_update_duration) / interval); - ret[7] = (double(get_tensor_copy_for_cache_update_dur) / interval); - ret[8] = (double(set_tensor_copy_for_cache_update_dur) / interval); + ret[6] = (double(get_cache_memcpy_duration) / interval); + ret[7] = (double(total_cache_update_duration) / interval); + ret[8] = (double(get_tensor_copy_for_cache_update_dur) / interval); + ret[9] = (double(set_tensor_copy_for_cache_update_dur) / interval); + ret[10] = (double(num_evictions) / interval); if (l2_cache_) { auto cache_mem_stats = l2_cache_->get_cache_usage(); - ret[9] = (cache_mem_stats[0]); // free cache in bytes - ret[10] = (cache_mem_stats[1]); // total cache capacity in bytes + ret[11] = (cache_mem_stats[0]); // free cache in bytes + ret[12] = (cache_mem_stats[1]); // total cache capacity in bytes } + ret[13] = (double(set_cache_lock_wait_duration) / interval); + ret[14] = (double(get_cache_lock_wait_duration) / interval); } return ret; } +void EmbeddingKVDB::reset_l2_cache() { + l2_cache_ = nullptr; +} + void EmbeddingKVDB::set( const at::Tensor& indices, const at::Tensor& weights, @@ -218,8 +254,8 @@ void EmbeddingKVDB::set( const bool is_bwd) { if (auto num_evictions = get_maybe_uvm_scalar(count); num_evictions <= 0) { XLOG_EVERY_MS(INFO, 60000) - << "[" << unique_id_ << "]skip set_cuda since number evictions is " - << num_evictions; + << "[TBE_ID" << unique_id_ + << "]skip set_cuda since number evictions is " << num_evictions; return; } auto start_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -228,8 +264,11 @@ void EmbeddingKVDB::set( // parallelized with other cuda kernels, as long as all updates are finished // before the next L2 cache lookup auto 
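// A minimal sketch of consuming the stats vector populated above. Index meanings
// follow the ret[] assignments in get_l2cache_perf(); the logging sink and the
// calling cadence here are assumptions for illustration only.
#include <cstdint>
#include <cstdio>
#include <vector>

void log_l2_perf(const std::vector<double>& stats, int64_t step) {
  if (stats.size() < 15) {
    return; // older builds report fewer metrics
  }
  const double lookups_per_step = stats[1];
  const double miss_rate =
      lookups_per_step > 0 ? stats[0] / lookups_per_step : 0.0; // misses / lookups
  const double free_bytes = stats[11];     // free L2 cache, bytes
  const double capacity_bytes = stats[12]; // total L2 cache capacity, bytes
  std::printf(
      "step=%ld miss_rate=%.4f l2_free_ratio=%.4f evictions_per_step=%.1f\n",
      static_cast<long>(step),
      miss_rate,
      capacity_bytes > 0 ? free_bytes / capacity_bytes : 0.0,
      stats[10]); // num_evictions averaged over the reporting interval
}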
tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto new_tuple = tensor_copy(indices, weights, count); - weights_to_fill_queue_.enqueue(new_tuple); + kv_db::RocksdbWriteMode write_mode = is_bwd + ? kv_db::RocksdbWriteMode::BWD_L1_CNFLCT_MISS_WRITE_BACK + : kv_db::RocksdbWriteMode::FWD_L1_EVICTION; + auto new_item = tensor_copy(indices, weights, count, write_mode); + weights_to_fill_queue_.enqueue(new_item); set_tensor_copy_for_cache_update_ += facebook::WallClockUtil::NowInUsecFast() - tensor_copy_start_ts; } @@ -237,16 +276,28 @@ void EmbeddingKVDB::set( void EmbeddingKVDB::get( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) { + const at::Tensor& count, + int64_t sleep_ms) { if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) { XLOG_EVERY_MS(INFO, 60000) - << "[" << unique_id_ << "]skip get_cuda since number lookups is " + << "[TBE_ID" << unique_id_ << "]skip get_cuda since number lookups is " << num_lookups; return; } ASSERT_EQ(max_D_, weights.size(1)); auto start_ts = facebook::WallClockUtil::NowInUsecFast(); wait_util_filling_work_done(); + + std::unique_lock lock(l2_cache_mtx_); + get_cache_lock_wait_duration_ += + facebook::WallClockUtil::NowInUsecFast() - start_ts; + + // this is for unittest to repro synchronization situation deterministically + if (sleep_ms > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + XLOG(INFO) << "get sleep end"; + } + auto cache_context = get_cache(indices, count); if (cache_context != nullptr) { if (cache_context->num_misses > 0) { @@ -264,8 +315,9 @@ void EmbeddingKVDB::get( // be parallelized with other cuda kernels, as long as all updates are // finished before the next L2 cache lookup auto tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto new_tuple = tensor_copy(indices, weights, count); - weights_to_fill_queue_.enqueue(new_tuple); + auto new_item = tensor_copy( + indices, weights, count, kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ); + weights_to_fill_queue_.enqueue(new_item); get_tensor_copy_for_cache_update_ += facebook::WallClockUtil::NowInUsecFast() - tensor_copy_start_ts; } else { @@ -288,63 +340,67 @@ std::shared_ptr EmbeddingKVDB::get_cache( return nullptr; } auto start_ts = facebook::WallClockUtil::NowInUsecFast(); - auto indices_addr = indices.data_ptr(); + auto num_lookups = get_maybe_uvm_scalar(count); auto cache_context = std::make_shared(num_lookups); - - auto num_shards = executor_tp_->numThreads(); - - std::vector> tasks; - std::vector> row_ids_per_shard(num_shards); - for (int i = 0; i < num_shards; i++) { - row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); - } - for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { - row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] - .emplace_back(row_id); - } - for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { - tasks.emplace_back( - folly::coro::co_invoke( - [this, - &indices_addr, - cache_context, - shard_id, - &row_ids_per_shard]() mutable -> folly::coro::Task { - for (const auto& row_id : row_ids_per_shard[shard_id]) { - auto emb_idx = indices_addr[row_id]; - if (emb_idx < 0) { - continue; - } - auto cached_addr_opt = l2_cache_->get(emb_idx); - if (cached_addr_opt.has_value()) { // cache hit - cache_context->cached_addr_list[row_id] = - cached_addr_opt.value(); - indices_addr[row_id] = -1; // mark to sentinel value - } else { // cache miss - cache_context->num_misses += 1; + FBGEMM_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), 
"get_cache", [&] { + using index_t = scalar_t; + auto indices_addr = indices.data_ptr(); + auto num_shards = executor_tp_->numThreads(); + + std::vector> tasks; + std::vector> row_ids_per_shard(num_shards); + for (int i = 0; i < num_shards; i++) { + row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); + } + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] + .emplace_back(row_id); + } + for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + tasks.emplace_back( + folly::coro::co_invoke( + [this, + &indices_addr, + &indices, + cache_context, + shard_id, + &row_ids_per_shard]() mutable -> folly::coro::Task { + for (const auto& row_id : row_ids_per_shard[shard_id]) { + auto emb_idx = indices_addr[row_id]; + if (emb_idx < 0) { + continue; + } + auto cached_addr_opt = l2_cache_->get(indices[row_id]); + if (cached_addr_opt.has_value()) { // cache hit + cache_context->cached_addr_list[row_id] = + cached_addr_opt.value(); + indices_addr[row_id] = -1; // mark to sentinel value + } else { // cache miss + cache_context->num_misses += 1; + } } - } - co_return; - }) - .scheduleOn(executor_tp_.get())); - } - folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); - - // the following metrics added here as the current assumption is - // get_cache will only be called in get_cuda path, if assumption no longer - // true, we should wrap this up on the caller side - auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts; - get_cache_lookup_total_duration_ += dur; - auto cache_misses = cache_context->num_misses.load(); - if (num_lookups > 0) { - num_cache_misses_ += cache_misses; - num_lookups_ += num_lookups; - } else { - XLOG_EVERY_MS(INFO, 60000) - << "[" << unique_id_ - << "]num_lookups is 0, skip collecting the L2 cache miss stats"; - } + co_return; + }) + .scheduleOn(executor_tp_.get())); + } + folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); + + // the following metrics added here as the current assumption is + // get_cache will only be called in get_cuda path, if assumption no longer + // true, we should wrap this up on the caller side + auto dur = facebook::WallClockUtil::NowInUsecFast() - start_ts; + get_cache_lookup_total_duration_ += dur; + auto cache_misses = cache_context->num_misses.load(); + if (num_lookups > 0) { + num_cache_misses_ += cache_misses; + num_lookups_ += num_lookups; + } else { + XLOG_EVERY_MS(INFO, 60000) + << "[TBE_ID" << unique_id_ + << "]num_lookups is 0, skip collecting the L2 cache miss stats"; + } + }); return cache_context; } @@ -358,7 +414,8 @@ void EmbeddingKVDB::wait_util_filling_work_done() { total_wait_time_ms += 1; if (total_wait_time_ms > 100) { XLOG_EVERY_MS(ERR, 1000) - << "get_cache: waiting for L2 caching filling embeddings for " + << "[TBE_ID" << unique_id_ + << "]get_cache: waiting for L2 caching filling embeddings for " << total_wait_time_ms << " ms, somethings is likely wrong"; } } @@ -366,84 +423,142 @@ void EmbeddingKVDB::wait_util_filling_work_done() { facebook::WallClockUtil::NowInUsecFast() - start_ts; } -folly::Optional> EmbeddingKVDB::set_cache( +folly::Optional> +EmbeddingKVDB::set_cache( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count) { if (l2_cache_ == nullptr) { return folly::none; } - // TODO: consider whether need to reconstruct indices/weights/count and free // the original tensor since most of the tensor elem will be invalid, // this will trade some perf for peak DRAM 
util saving auto cache_update_start_ts = facebook::WallClockUtil::NowInUsecFast(); + std::unique_lock lock(l2_cache_mtx_); + set_cache_lock_wait_duration_ += + facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts; l2_cache_->init_tensor_for_l2_eviction(indices, weights, count); - auto indices_addr = indices.data_ptr(); - const int64_t num_lookups = get_maybe_uvm_scalar(count); - auto num_shards = executor_tp_->numThreads(); - std::vector> tasks; - std::vector> row_ids_per_shard(num_shards); + FBGEMM_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "set_cache", [&] { + using index_t = scalar_t; + auto indices_addr = indices.data_ptr(); + const int64_t num_lookups = get_maybe_uvm_scalar(count); + auto num_shards = executor_tp_->numThreads(); - for (int i = 0; i < num_shards; i++) { - row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); - } - for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { - row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] - .emplace_back(row_id); - } + std::vector> tasks; + std::vector> row_ids_per_shard(num_shards); - for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { - tasks.emplace_back( - folly::coro::co_invoke( - [this, - &indices_addr, - &weights, - shard_id, - &row_ids_per_shard]() mutable -> folly::coro::Task { - for (const auto& row_id : row_ids_per_shard[shard_id]) { - auto emb_idx = indices_addr[row_id]; - if (emb_idx < 0) { - continue; - } - if (!l2_cache_->put(emb_idx, weights[row_id])) { - XLOG_EVERY_MS(ERR, 1000) - << "[" << unique_id_ - << "]Failed to insert into cache, this shouldn't happen"; + for (int i = 0; i < num_shards; i++) { + row_ids_per_shard[i].reserve(num_lookups / num_shards * 2); + } + for (uint32_t row_id = 0; row_id < num_lookups; ++row_id) { + row_ids_per_shard[l2_cache_->get_shard_id(indices_addr[row_id])] + .emplace_back(row_id); + } + + for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + tasks.emplace_back( + folly::coro::co_invoke( + [this, + &indices_addr, + &indices, + &weights, + shard_id, + &row_ids_per_shard]() mutable -> folly::coro::Task { + for (const auto& row_id : row_ids_per_shard[shard_id]) { + auto emb_idx = indices_addr[row_id]; + if (emb_idx < 0) { + continue; + } + if (!l2_cache_->put(indices[row_id], weights[row_id])) { + XLOG_EVERY_MS(ERR, 1000) + << "[TBE_ID" << unique_id_ + << "]Failed to insert into cache, this shouldn't happen"; + } } - } - co_return; - }) - .scheduleOn(executor_tp_.get())); - } - folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); - l2_cache_->reset_eviction_states(); + co_return; + }) + .scheduleOn(executor_tp_.get())); + } + folly::coro::blockingWait(folly::coro::collectAllRange(std::move(tasks))); + }); total_cache_update_duration_ += facebook::WallClockUtil::NowInUsecFast() - cache_update_start_ts; - return l2_cache_->get_evicted_indices_and_weights(); + auto tensor_tuple_opt = l2_cache_->get_tensors_and_reset(); + if (tensor_tuple_opt.has_value()) { + auto& num_evictions_tensor = std::get<2>(tensor_tuple_opt.value()); + auto num_evictions = get_maybe_uvm_scalar(num_evictions_tensor); + num_evictions_ += num_evictions; + } + return tensor_tuple_opt; } folly::coro::Task EmbeddingKVDB::cache_memcpy( const at::Tensor& weights, const std::vector& cached_addr_list) { + auto cache_memcpy_start_ts = facebook::WallClockUtil::NowInUsecFast(); FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "cache_memcpy", [&] { + std::vector> tasks; auto weights_data_ptr = weights.data_ptr(); - for (int row_id = 
0; row_id < cached_addr_list.size(); row_id++) { - if (cached_addr_list[row_id] == nullptr) { - continue; - } - std::copy( - reinterpret_cast(cached_addr_list[row_id]), - reinterpret_cast(cached_addr_list[row_id]) + - max_D_, - &weights_data_ptr[row_id * max_D_]); // dst_start + auto num_shards = executor_tp_->numThreads(); + for (uint32_t shard_id = 0; shard_id < num_shards; ++shard_id) { + tasks.emplace_back( + folly::coro::co_invoke( + [this, + &cached_addr_list, + &weights_data_ptr, + num_shards, + shard_id]() mutable -> folly::coro::Task { + for (int row_id = 0; row_id < cached_addr_list.size(); + row_id++) { + if (row_id % num_shards != shard_id) { + continue; + } + if (cached_addr_list[row_id] == nullptr) { + continue; + } + std::copy( + reinterpret_cast( + cached_addr_list[row_id]), + reinterpret_cast( + cached_addr_list[row_id]) + + max_D_, + &weights_data_ptr[row_id * max_D_]); // dst_start + } + co_return; + }) + .scheduleOn(executor_tp_.get())); } + folly::coro::blockingWait( + folly::coro::collectAllRange(std::move(tasks))); }); + get_cache_memcpy_duration_ += + facebook::WallClockUtil::NowInUsecFast() - cache_memcpy_start_ts; co_return; } +void EmbeddingKVDB::check_tensor_type_consistency( + const at::Tensor& indices, + const at::Tensor& weights) { + if (index_dtype_.has_value()) { + assert(index_dtype_.value() == indices.scalar_type()); + } else { + index_dtype_ = indices.scalar_type(); + XLOG(INFO) << "[TBE_ID" << unique_id_ << "]L2 cache index dtype is " + << index_dtype_.value(); + } + + if (weights_dtype_.has_value()) { + assert(weights_dtype_.value() == weights.scalar_type()); + } else { + weights_dtype_ = weights.scalar_type(); + XLOG(INFO) << "[TBE_ID" << unique_id_ << "]L2 cache weights dtype is " + << weights_dtype_.value(); + } +} + } // namespace kv_db diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 8970f2f33..93495f2da 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -62,6 +62,59 @@ class CacheContext { std::vector cached_addr_list; }; +/// @ingroup embedding-ssd +/// +/// @brief rocksdb write mode +/// +/// In SSD offloading there are 3 writes in each train iteration +/// FWD_ROCKSDB_READ: cache lookup will move uncached data from rocksdb into L2 +/// cache on fwd path +/// +/// FWD_L1_EVICTION: L1 cache eviciton will evict data into L2 cache on fwd path +/// +/// BWD_L1_CNFLCT_MISS_WRITE_BACK: L1 conflict miss will insert into L2 for +/// embedding update on bwd path +/// +/// All the L2 cache filling above will potentially trigger rocksdb write once +/// L2 cache is full +/// +/// Additionally we will do ssd io on L2 flush +enum RocksdbWriteMode { + FWD_ROCKSDB_READ = 0, + FWD_L1_EVICTION = 1, + BWD_L1_CNFLCT_MISS_WRITE_BACK = 2, + FLUSH = 3, +}; + +/// @ingroup embedding-ssd +/// +/// @brief queue item for background L2/rocksdb update +/// +/// indices/weights/count are the corresponding set() params +/// +/// read_handles is cachelib abstracted indices/embedding pair metadata, will be +/// later used on updating cachelib LRU queue as we separate it from +/// EmbeddingKVDB::get_cache() +/// +/// mode is used for monitoring rocksdb write, checkout RocksdbWriteMode for +/// detailed explanation +struct QueueItem { + at::Tensor indices; + at::Tensor weights; + at::Tensor count; + RocksdbWriteMode mode; + QueueItem( + 
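// A minimal sketch of building a background-update item tagged with the write path
// it came from, mirroring the QueueItem and RocksdbWriteMode types documented above.
// The tensor shapes, dtypes, and the kv_db:: qualification are illustrative assumptions.
#include <ATen/ATen.h>

kv_db::QueueItem make_bwd_write_back_item() {
  auto indices = at::arange(4, at::TensorOptions().dtype(at::kLong).device(at::kCPU));
  auto weights =
      at::zeros({4, 128}, at::TensorOptions().dtype(at::kFloat).device(at::kCPU));
  auto count = at::full({1}, 4, at::TensorOptions().dtype(at::kLong).device(at::kCPU));
  // The conflict-miss write-back is the backward-path flavor of the three
  // per-iteration writes described above; FLUSH is only used on explicit flushes.
  return kv_db::QueueItem(
      indices,
      weights,
      count,
      kv_db::RocksdbWriteMode::BWD_L1_CNFLCT_MISS_WRITE_BACK);
}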
at::Tensor src_indices, + at::Tensor src_weights, + at::Tensor src_count, + RocksdbWriteMode src_mode) { + indices = src_indices; + weights = src_weights; + count = src_count; + mode = src_mode; + } +}; + /// @ingroup embedding-ssd /// /// @brief A class for interacting with different cache layers and storage @@ -118,13 +171,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this { /// relative element in /// @param count A single element tensor that contains the number of indices /// to be processed + /// @param sleep_ms this is used to specifically sleep in get function, this + /// is needed to reproduce synchronization situation deterministicly, in prod + /// case this will be 0 for sure /// /// @return None /// @note weights will be updated from either L2 cache or storage tier void get( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count); + const at::Tensor& count, + int64_t sleep_ms = 0); /// storage tier counterpart of function get() virtual folly::coro::Task get_kv_db_async( @@ -137,7 +194,7 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - const bool is_bwd = false) = 0; + const RocksdbWriteMode w_mode = RocksdbWriteMode::FWD_ROCKSDB_READ) = 0; virtual void compact() = 0; @@ -174,6 +231,14 @@ class EmbeddingKVDB : public std::enable_shared_from_this { const int64_t step, const int64_t interval); + // reset L2 cache, this is used for unittesting to bypass l2 cache + void reset_l2_cache(); + + // block waiting for working items in queue to be finished, this is called by + // get_cache() as embedding read should wait until previous write to be + // finished, it could also be called in unitest to sync + void wait_util_filling_work_done(); + private: /// Find non-negative embedding indices in and shard them into /// #cachelib_pools pieces to be lookedup in parallel @@ -201,10 +266,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this { /// @param count A single element tensor that contains the number of indices /// to be processed /// - /// @return None if L2 is missing, other wise return pair of tensors with - /// length of containing L2 evicted embedding indices and embeddings, - /// invalid pairs will have sentinel value(-1) on - folly::Optional> set_cache( + /// @return None if L2 is missing or no eviction, other wise return tuple of + /// tensors with length of containing L2 evicted embedding indices and + /// embeddings, invalid pairs will have sentinel value(-1) on + folly::Optional> set_cache( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count); @@ -226,21 +291,33 @@ class EmbeddingKVDB : public std::enable_shared_from_this { virtual void flush_or_compact(const int64_t timestep) = 0; - // waiting for working item queue to be empty, this is called by get_cache() - // as embedding read should wait until previous write to be finished - void wait_util_filling_work_done(); + void check_tensor_type_consistency( + const at::Tensor& indices, + const at::Tensor& weights); std::unique_ptr l2_cache_; const int64_t unique_id_; const int64_t num_shards_; const int64_t max_D_; + folly::Optional index_dtype_{folly::none}; + folly::Optional weights_dtype_{folly::none}; std::unique_ptr executor_tp_; std::unique_ptr cache_filling_thread_; std::atomic stop_{false}; // buffer queue that stores all the needed indices/weights/action_count to // fill up cache - folly::USPSCQueue, true> - weights_to_fill_queue_; + folly::USPSCQueue 
weights_to_fill_queue_; + // In non pipelining mode, the sequence is + // - get_cuda(): L2 read and insert L2 cache misses into queue for + // bg L2 write + // - L1 cache eviction: insert into bg queue for L2 write + // - ScratchPad update: insert into bg queue for L2 write + // in non-prefetch pipeline, cuda synchronization guarantee get_cuda() happen + // after SP update + // in prefetch pipeline, cuda sync only guarantee get_cuda() happen after L1 + // cache eviction pipeline case, SP bwd update could happen in parallel with + // L2 read mutex is used for l2 cache to do read / write exclusively + std::mutex l2_cache_mtx_; // perf stats // -- perf of get() function @@ -248,14 +325,18 @@ class EmbeddingKVDB : public std::enable_shared_from_this { // instead of SUM(cache miss per interval) / SUM(lookups per interval) std::atomic num_cache_misses_{0}; std::atomic num_lookups_{0}; + std::atomic num_evictions_{0}; std::atomic get_total_duration_{0}; std::atomic get_cache_lookup_total_duration_{0}; std::atomic get_cache_lookup_wait_filling_thread_duration_{0}; std::atomic get_weights_fillup_total_duration_{0}; + std::atomic get_cache_memcpy_duration_{0}; std::atomic get_tensor_copy_for_cache_update_{0}; + std::atomic get_cache_lock_wait_duration_{0}; // -- perf of set() function std::atomic set_tensor_copy_for_cache_update_{0}; + std::atomic set_cache_lock_wait_duration_{0}; // -- commone path std::atomic total_cache_update_duration_{0}; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu index d6d926a96..d371c2845 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu @@ -22,12 +22,27 @@ #include "fbgemm_gpu/utils/tensor_utils.h" #include "fbgemm_gpu/utils/vec4.cuh" -constexpr int ALL_TO_PREFETCH_SM_RATIO = 8; - using Tensor = at::Tensor; using namespace fbgemm_gpu; +int get_masked_index_default_pipeline_sms(int device) { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device); + + // The default number of SMs for use_pipeline=true is set based on an + // empirical study + if (prop.major == 8) { + // Assume A100 + return 4; + } else if (prop.major == 9) { + // Assume H100 + return 16; + } + constexpr int ALL_TO_PREFETCH_SM_RATIO = 8; + return div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO); +} + template DEVICE_INLINE void vec4_copy(scalar_t* dst, const scalar_t* src, const int32_t D) { @@ -103,26 +118,11 @@ Tensor masked_index_impl( const auto full_grid_size = div_round_up(N, kMaxThreads / tx); - // The default number of SMs for use_pipeline=true is set based on an - // empirical study - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, at::cuda::current_device()); - - int DEFAULT_PIPELINE_SMS; - if (prop.major == 8) { - // Assume A100 - DEFAULT_PIPELINE_SMS = 4; - } else if (prop.major == 9) { - // Assume H100 - DEFAULT_PIPELINE_SMS = 16; - } else { - DEFAULT_PIPELINE_SMS = - div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO); - } + static int masked_index_default_pipeline_sms = + get_masked_index_default_pipeline_sms(at::cuda::current_device()); const int pipeline_grid_size = - preferred_sms == -1 ? DEFAULT_PIPELINE_SMS : preferred_sms; + preferred_sms == -1 ? 
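// A minimal sketch of the SM-count heuristic above for use_pipeline=true, with the
// device properties passed in explicitly instead of queried via cudaGetDeviceProperties.
// The sample compute capabilities and SM counts below are assumptions.
#include <cstdio>

int default_pipeline_sms(int sm_major, int total_sm_count) {
  if (sm_major == 8) {
    return 4; // A100-class devices
  }
  if (sm_major == 9) {
    return 16; // H100-class devices
  }
  constexpr int kAllToPrefetchSmRatio = 8;
  // div_round_up(total SMs, ratio), matching the fallback path above
  return (total_sm_count + kAllToPrefetchSmRatio - 1) / kAllToPrefetchSmRatio;
}

int main() {
  std::printf("sm80: %d SMs\n", default_pipeline_sms(8, 108)); // 4
  std::printf("sm90: %d SMs\n", default_pipeline_sms(9, 132)); // 16
  std::printf("sm70: %d SMs\n", default_pipeline_sms(7, 80));  // 80/8 = 10
  return 0;
}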
masked_index_default_pipeline_sms : preferred_sms; TORCH_CHECK( !use_pipeline || pipeline_grid_size >= 1, "preferred_sms must >= 1"); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 1e91cefb5..6238fe1be 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -315,8 +315,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { return impl_->set(indices, weights, count); } - void get(Tensor indices, Tensor weights, Tensor count) { - return impl_->get(indices, weights, count); + void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) { + return impl_->get(indices, weights, count, sleep_ms); } std::vector get_mem_usage() { @@ -343,6 +343,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { return impl_->flush(); } + void reset_l2_cache() { + return impl_->reset_l2_cache(); + } + + void wait_util_filling_work_done() { + return impl_->wait_util_filling_work_done(); + } + private: // shared pointer since we use shared_from_this() in callbacks. std::shared_ptr impl_; @@ -413,7 +421,20 @@ static auto embedding_rocks_db_wrapper = &EmbeddingRocksDBWrapper::get_rocksdb_io_duration) .def("get_l2cache_perf", &EmbeddingRocksDBWrapper::get_l2cache_perf) .def("set", &EmbeddingRocksDBWrapper::set) - .def("get", &EmbeddingRocksDBWrapper::get); + .def( + "get", + &EmbeddingRocksDBWrapper::get, + "", + { + torch::arg("indices"), + torch::arg("weights"), + torch::arg("count"), + torch::arg("sleep_ms") = 0, + }) + .def("reset_l2_cache", &EmbeddingRocksDBWrapper::reset_l2_cache) + .def( + "wait_util_filling_work_done", + &EmbeddingRocksDBWrapper::wait_util_filling_work_done); TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h index 096745314..e7d32682e 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h @@ -8,11 +8,14 @@ #pragma once +#include +#include + #include +#include #include #include #include -#include #ifdef FBGEMM_FBCODE #include "common/strings/UUID.h" #include "common/time/Time.h" @@ -126,7 +129,50 @@ class Initializer { /// @brief An implementation of EmbeddingKVDB for RocksDB /// class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { + using snapshot_ptr_t = const rocksdb::Snapshot*; + public: + class SnapshotHandle { + public: + explicit SnapshotHandle(EmbeddingRocksDB* db) : db_(db) { + auto num_shards = db->num_shards(); + CHECK_GT(num_shards, 0); + shard_snapshots_.reserve(num_shards); + for (auto shard = 0; shard < num_shards; ++shard) { + const auto* snapshot = db->dbs_[shard]->GetSnapshot(); + CHECK(snapshot != nullptr) + << "ERROR: create_snapshot fails to create a snapshot " + << "for db shard " << shard << ". 
Please make sure that " + << "inplace_update_support is set to false" << std::endl; + shard_snapshots_.push_back(snapshot); + } + } + + ~SnapshotHandle() { + for (auto shard = 0; shard < db_->dbs_.size(); ++shard) { + snapshot_ptr_t snapshot = shard_snapshots_[shard]; + CHECK(snapshot != nullptr) + << "Unexpected nullptr for snapshot " << shard; + db_->dbs_[shard]->ReleaseSnapshot(snapshot); + } + } + + void release() { + db_->release_snapshot(this); + } + + snapshot_ptr_t get_snapshot_for_shard(size_t shard) const { + CHECK_LE(shard, shard_snapshots_.size()); + return shard_snapshots_[shard]; + } + + private: + friend class EmbeddingRocksDB; + + EmbeddingRocksDB* db_; + std::vector shard_snapshots_; + }; + explicit EmbeddingRocksDB( std::string path, int64_t num_shards, @@ -152,7 +198,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { max_D, l2_cache_size_gb, tbe_unqiue_id, - row_storage_bitwidth / 8) { + row_storage_bitwidth / 8), + max_D_(max_D) { // TODO: lots of tunables. NNI or something for this? rocksdb::Options options; options.create_if_missing = true; @@ -193,7 +240,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { // causing flush set this to true to make update on the existing key // allow_concurrent_memtable_write is toggled in pair with // inplace_update_support - options.inplace_update_support = true; + options.inplace_update_support = false; options.avoid_unnecessary_blocking_io = true; options.use_direct_reads = true; @@ -341,11 +388,36 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { } } + ~EmbeddingRocksDB() override { + // clear all the snapshots if not released + if (snapshots_.size() > 0) { + LOG(WARNING) + << snapshots_.size() + << " snapshots have not been released when db is closing. Releasing them now."; + } + snapshots_.clear(); + for (auto shard = 0; shard < dbs_.size(); ++shard) { + dbs_[shard]->Close(); + } + } + + folly::coro::Task get_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) override { + co_await get_kv_db_async_impl( + indices, + weights, + count, + /*snapshot_handle=*/nullptr); + } + folly::coro::Task set_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - const bool is_bwd = false) override { + const kv_db::RocksdbWriteMode w_mode = + kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ) override { RECORD_USER_SCOPE("EmbeddingRocksDB::set"); #ifdef FBGEMM_FBCODE auto start_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -360,36 +432,43 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { -> folly::coro::Task { FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( weights.scalar_type(), "ssd_set", [&] { - CHECK(indices.is_contiguous()); - CHECK(weights.is_contiguous()); - auto indices_acc = indices.accessor(); - auto D = weights.size(1); - CHECK_EQ(indices.size(0), weights.size(0)); - { - rocksdb::WriteBatch batch( - (2 * (count_ + dbs_.size() - 1) / dbs_.size()) * - (sizeof(int64_t) + sizeof(scalar_t) * D)); - for (auto i = 0; i < count_; ++i) { - if (indices_acc[i] < 0) { - continue; - } - if (kv_db_utils::hash_shard( - indices_acc[i], dbs_.size()) != shard) { - continue; - } - batch.Put( - rocksdb::Slice( - reinterpret_cast( - &(indices.data_ptr()[i])), - sizeof(int64_t)), - rocksdb::Slice( - reinterpret_cast( - &(weights.data_ptr()[i * D])), - D * sizeof(scalar_t))); - } - auto s = dbs_[shard]->Write(wo_, &batch); - CHECK(s.ok()); - } + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "ssd_set", [&] { + using 
index_t = scalar_t; + CHECK(indices.is_contiguous()); + CHECK(weights.is_contiguous()); + auto indices_acc = indices.accessor(); + auto D = weights.size(1); + CHECK_EQ(indices.size(0), weights.size(0)); + { + rocksdb::WriteBatch batch( + (2 * (count_ + dbs_.size() - 1) / + dbs_.size()) * + (sizeof(index_t) + sizeof(value_t) * D)); + for (auto i = 0; i < count_; ++i) { + if (indices_acc[i] < 0) { + continue; + } + if (kv_db_utils::hash_shard( + indices_acc[i], dbs_.size()) != shard) { + continue; + } + batch.Put( + rocksdb::Slice( + reinterpret_cast( + &(indices.data_ptr()[i])), + sizeof(index_t)), + rocksdb::Slice( + reinterpret_cast( + &(weights + .data_ptr()[i * D])), + D * sizeof(value_t))); + } + auto s = dbs_[shard]->Write(wo_, &batch); + CHECK(s.ok()); + } + }); }); co_return; }) @@ -398,141 +477,60 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { co_await folly::coro::collectAllRange(std::move(tasks)); #ifdef FBGEMM_FBCODE auto duration = facebook::WallClockUtil::NowInUsecFast() - start_ts; - if (is_bwd) { - bwd_write_total_duration_ += duration; - } else { - fwd_write_total_duration_ += duration; + switch (w_mode) { + case kv_db::RocksdbWriteMode::BWD_L1_CNFLCT_MISS_WRITE_BACK: + bwd_l1_cnflct_miss_write_back_dur_ += duration; + break; + case kv_db::RocksdbWriteMode::FWD_L1_EVICTION: + fwd_l1_eviction_dur_ += duration; + break; + case kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ: + fwd_rocksdb_read_dur_ += duration; + break; + case kv_db::RocksdbWriteMode::FLUSH: + flush_write_dur_ += duration; + break; } #endif } - folly::coro::Task get_kv_db_async( - const at::Tensor& indices, - const at::Tensor& weights, - const at::Tensor& count) override { - RECORD_USER_SCOPE("EmbeddingRocksDB::get"); -#ifdef FBGEMM_FBCODE - auto start_ts = facebook::WallClockUtil::NowInUsecFast(); -#endif - std::vector> tasks; - auto count_ = count.item().toLong(); + bool is_valid_snapshot(const SnapshotHandle* snapshot_handle) const { + return snapshots_.find(snapshot_handle) != snapshots_.end(); + } - for (auto shard = 0; shard < dbs_.size(); ++shard) { - tasks.emplace_back( - folly::coro::co_invoke( - [this, &indices, &weights, count_, shard]() mutable - -> folly::coro::Task { - FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( - weights.scalar_type(), "ssd_get", [&] { - CHECK(indices.is_contiguous()); - CHECK(weights.is_contiguous()); - auto indices_data_ptr = indices.data_ptr(); - auto D = weights.size(1); - CHECK_EQ(indices.size(0), weights.size(0)); - auto weights_data_ptr = weights.data_ptr(); - FOLLY_DECLARE_REUSED(keys, std::vector); - FOLLY_DECLARE_REUSED(shard_ids, std::vector); - FOLLY_DECLARE_REUSED( - cfs, std::vector); - FOLLY_DECLARE_REUSED( - values, std::vector); - FOLLY_DECLARE_REUSED( - statuses, std::vector); - auto* dcf = dbs_[shard]->DefaultColumnFamily(); - for (auto i = 0; i < count_; ++i) { - // "no-op"/empty evicted tensor - if (indices_data_ptr[i] == -1) { - continue; - } - if (kv_db_utils::hash_shard( - indices_data_ptr[i], dbs_.size()) != shard) { - continue; - } - shard_ids.push_back(i); - } - std::sort( - shard_ids.begin(), - shard_ids.end(), - [&](int32_t lhs, int32_t rhs) { - const auto lhs_key = rocksdb::Slice( - reinterpret_cast( - &(indices_data_ptr[lhs])), - sizeof(int64_t)); - const auto rhs_key = rocksdb::Slice( - reinterpret_cast( - &(indices_data_ptr[rhs])), - sizeof(int64_t)); - return lhs_key.compare(rhs_key) < 0; - }); - for (const auto& i : shard_ids) { - const auto key = rocksdb::Slice( - reinterpret_cast( - &(indices_data_ptr[i])), - sizeof(int64_t)); - 
keys.push_back(key); - cfs.push_back(dcf); - } - CHECK_EQ(shard_ids.size(), keys.size()); - CHECK_EQ(shard_ids.size(), cfs.size()); - - values.resize(keys.size()); - statuses.resize(keys.size()); - dbs_[shard]->MultiGet( - ro_, - keys.size(), - cfs.data(), - keys.data(), - values.data(), - statuses.data(), - /*sorted_input=*/true); - const auto& init_storage = - initializers_[shard]->row_storage_; - // Sanity check - TORCH_CHECK( - init_storage.scalar_type() == weights.scalar_type(), - "init_storage (", - toString(init_storage.scalar_type()), - ") and weights scalar (", - toString(weights.scalar_type()), - ") types mismatch"); - auto row_storage_data_ptr = - init_storage.data_ptr(); - for (auto j = 0; j < keys.size(); ++j) { - const auto& s = statuses[j]; - int64_t i = shard_ids[j]; - const auto& value = values[j]; - if (s.ok()) { - if (!std::is_same::value) { - CHECK_EQ(value.size(), D * sizeof(scalar_t)); - } - std::copy( - reinterpret_cast(value.data()), - reinterpret_cast( - value.data() + value.size()), - &(weights_data_ptr[i * D])); - } else { - CHECK(s.IsNotFound()); - int64_t row_index; - initializers_[shard]->producer_queue_.dequeue( - row_index); - std::copy( - &(row_storage_data_ptr[row_index * D]), - &(row_storage_data_ptr[row_index * D + D]), - &(weights_data_ptr[i * D])); - initializers_[shard]->consumer_queue_.enqueue( - row_index); - } - } - }); - co_return; - }) - .scheduleOn(executor_.get())); + const SnapshotHandle* create_snapshot() { + const auto num_snapshots = snapshots_.size(); + if (num_snapshots > 0) { + std::cerr << "WARNING: create_snapshot found " << num_snapshots + << " other snapshots" << std::endl; } - co_await folly::coro::collectAllRange(std::move(tasks)); -#ifdef FBGEMM_FBCODE - auto duration = facebook::WallClockUtil::NowInUsecFast() - start_ts; - read_total_duration_ += duration; -#endif + + auto handle = std::make_unique(this); + auto handlePtr = handle.get(); + snapshots_[handlePtr] = std::move(handle); + return handlePtr; + } + + void release_snapshot(const SnapshotHandle* snapshot_handle) { + snapshots_.erase(snapshot_handle); + } + + void get_range_from_snapshot( + const at::Tensor& weights, + const int64_t start, + const int64_t length, + const SnapshotHandle* snapshot_handle) { + const auto seq_indices = + at::arange(start, start + length, at::TensorOptions().dtype(at::kLong)); + int64_t* count_ = new int64_t[1]; + count_[0] = length; + const auto count = at::from_blob(count_, {1}, at::kLong); + folly::coro::blockingWait( + get_kv_db_async_impl(seq_indices, weights, count, snapshot_handle)); + } + + int64_t get_max_D() { + return max_D_; } // collect mem usage on all db shards, checkout rocks_db_mem_properties @@ -560,17 +558,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { const int64_t step, const int64_t interval) { std::vector ret; - ret.reserve(3); + ret.reserve(5); if (step > 0 && step % interval == 0) { - auto read_dur = read_total_duration_.load(); - auto fwd_write_dur = fwd_write_total_duration_.load(); - auto bwd_write_dur = bwd_write_total_duration_.load(); + int64_t reset_val = 0; + auto read_dur = read_total_duration_.exchange(reset_val); + + auto fwd_rocksdb_read_dur = fwd_rocksdb_read_dur_.exchange(reset_val); + auto fwd_l1_eviction_dur = fwd_l1_eviction_dur_.exchange(reset_val); + auto bwd_l1_cnflct_miss_write_back_dur = + bwd_l1_cnflct_miss_write_back_dur_.exchange(reset_val); + auto flush_write_dur = flush_write_dur_.exchange(reset_val); + ret.push_back(double(read_dur) / interval); - 
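// A minimal sketch of the snapshot read path added above: create per-shard RocksDB
// snapshots, read a contiguous index range against them, then release. The weights
// dtype and the ssd:: qualification are assumptions for illustration.
#include <ATen/ATen.h>

void read_first_rows(ssd::EmbeddingRocksDB& db, int64_t length) {
  auto weights = at::empty(
      {length, db.get_max_D()},
      at::TensorOptions().dtype(at::kFloat).device(at::kCPU));
  const auto* snapshot = db.create_snapshot(); // freezes every shard's view
  db.get_range_from_snapshot(weights, /*start=*/0, /*length=*/length, snapshot);
  db.release_snapshot(snapshot); // or snapshot->release(), which forwards here
}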
ret.push_back(double(fwd_write_dur) / interval); - ret.push_back(double(bwd_write_dur) / interval); - read_total_duration_ = 0; - fwd_write_total_duration_ = 0; - bwd_write_total_duration_ = 0; + ret.push_back(double(fwd_rocksdb_read_dur) / interval); + ret.push_back(double(fwd_l1_eviction_dur) / interval); + ret.push_back(double(bwd_l1_cnflct_miss_write_back_dur) / interval); + ret.push_back(double(flush_write_dur) / interval); } return ret; } @@ -588,6 +591,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { } } + int64_t num_shards() const { + return dbs_.size(); + } + private: void flush_or_compact(const int64_t timestep) override { // Only do manual Flush/Compactions if enabled @@ -638,6 +645,152 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { } } + folly::coro::Task get_kv_db_async_impl( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count, + const SnapshotHandle* snapshot_handle) { + RECORD_USER_SCOPE("EmbeddingRocksDB::get"); +#ifdef FBGEMM_FBCODE + auto start_ts = facebook::WallClockUtil::NowInUsecFast(); +#endif + std::vector> tasks; + auto count_ = count.item().toLong(); + + for (auto shard = 0; shard < dbs_.size(); ++shard) { + // Get a snapshot for the shard + snapshot_ptr_t snapshot = snapshot_handle == nullptr + ? nullptr + : snapshot_handle->get_snapshot_for_shard(shard); + tasks.emplace_back( + folly::coro::co_invoke( + [this, &indices, &weights, count_, shard, snapshot]() mutable + -> folly::coro::Task { + FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE( + weights.scalar_type(), "ssd_get", [&] { + using value_t = scalar_t; + FBGEMM_DISPATCH_INTEGRAL_TYPES( + indices.scalar_type(), "ssd_get", [&] { + using index_t = scalar_t; + CHECK(indices.is_contiguous()); + CHECK(weights.is_contiguous()); + auto indices_data_ptr = indices.data_ptr(); + auto D = weights.size(1); + CHECK_EQ(indices.size(0), weights.size(0)); + auto weights_data_ptr = weights.data_ptr(); + FOLLY_DECLARE_REUSED( + keys, std::vector); + FOLLY_DECLARE_REUSED( + shard_ids, std::vector); + FOLLY_DECLARE_REUSED( + cfs, std::vector); + FOLLY_DECLARE_REUSED( + values, std::vector); + FOLLY_DECLARE_REUSED( + statuses, std::vector); + auto* dcf = dbs_[shard]->DefaultColumnFamily(); + for (auto i = 0; i < count_; ++i) { + // "no-op"/empty evicted tensor + if (indices_data_ptr[i] == -1) { + continue; + } + if (kv_db_utils::hash_shard( + indices_data_ptr[i], dbs_.size()) != + shard) { + continue; + } + shard_ids.push_back(i); + } + std::sort( + shard_ids.begin(), + shard_ids.end(), + [&](int32_t lhs, int32_t rhs) { + const auto lhs_key = rocksdb::Slice( + reinterpret_cast( + &(indices_data_ptr[lhs])), + sizeof(index_t)); + const auto rhs_key = rocksdb::Slice( + reinterpret_cast( + &(indices_data_ptr[rhs])), + sizeof(index_t)); + return lhs_key.compare(rhs_key) < 0; + }); + for (const auto& i : shard_ids) { + const auto key = rocksdb::Slice( + reinterpret_cast( + &(indices_data_ptr[i])), + sizeof(index_t)); + keys.push_back(key); + cfs.push_back(dcf); + } + CHECK_EQ(shard_ids.size(), keys.size()); + CHECK_EQ(shard_ids.size(), cfs.size()); + + values.resize(keys.size()); + statuses.resize(keys.size()); + // Set a snapshot if it is available + ro_.snapshot = snapshot; + dbs_[shard]->MultiGet( + ro_, + keys.size(), + cfs.data(), + keys.data(), + values.data(), + statuses.data(), + /*sorted_input=*/true); + const auto& init_storage = + initializers_[shard]->row_storage_; + // Sanity check + TORCH_CHECK( + init_storage.scalar_type() == + weights.scalar_type(), + "init_storage (", + 
toString(init_storage.scalar_type()), + ") and weights scalar (", + toString(weights.scalar_type()), + ") types mismatch"); + auto row_storage_data_ptr = + init_storage.data_ptr(); + for (auto j = 0; j < keys.size(); ++j) { + const auto& s = statuses[j]; + int64_t i = shard_ids[j]; + const auto& value = values[j]; + if (s.ok()) { + if (!std::is_same::value) { + CHECK_EQ(value.size(), D * sizeof(value_t)); + } + std::copy( + reinterpret_cast( + value.data()), + reinterpret_cast( + value.data() + value.size()), + &(weights_data_ptr[i * D])); + } else { + CHECK(s.IsNotFound()); + int64_t row_index; + initializers_[shard]->producer_queue_.dequeue( + row_index); + std::copy( + &(row_storage_data_ptr[row_index * D]), + &(row_storage_data_ptr[row_index * D + D]), + &(weights_data_ptr[i * D])); + initializers_[shard]->consumer_queue_.enqueue( + row_index); + } + } + }); + }); + co_return; + }) + .scheduleOn(executor_.get())); + } + co_await folly::coro::collectAllRange(std::move(tasks)); +#ifdef FBGEMM_FBCODE + auto duration = facebook::WallClockUtil::NowInUsecFast() - start_ts; + read_total_duration_ += duration; +#endif + } + std::vector> dbs_; std::vector> initializers_; std::unique_ptr executor_; @@ -650,9 +803,17 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB { int64_t memtable_flush_period_; int64_t compaction_period_; int64_t l0_files_per_compact_; + + // break down on rocksdb write duration for details checkout RocksdbWriteMode std::atomic read_total_duration_{0}; - std::atomic fwd_write_total_duration_{0}; - std::atomic bwd_write_total_duration_{0}; -}; // class EmbeddingKVDB + std::atomic fwd_rocksdb_read_dur_{0}; + std::atomic fwd_l1_eviction_dur_{0}; + std::atomic bwd_l1_cnflct_miss_write_back_dur_{0}; + std::atomic flush_write_dur_{0}; + + std::unordered_map> + snapshots_; + int64_t max_D_; +}; // class EmbeddingRocksDB } // namespace ssd diff --git a/fbgemm_gpu/test/combine/common.py b/fbgemm_gpu/test/combine/common.py index a25d7cc9f..75c8bb529 100644 --- a/fbgemm_gpu/test/combine/common.py +++ b/fbgemm_gpu/test/combine/common.py @@ -11,14 +11,16 @@ import fbgemm_gpu import torch +from fbgemm_gpu.utils.loader import load_torch_module + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) if not open_source: - if torch.version.hip: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip") - else: - torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine") + load_torch_module( + "//deeplearning/fbgemm/fbgemm_gpu:input_combine", + hip_path="//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip", + ) torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index 6cd60d96c..3bdcacf98 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -16,6 +16,7 @@ import fbgemm_gpu import fbgemm_gpu.sparse_ops import numpy as np +import numpy.typing as npt import torch from hypothesis import HealthCheck, settings @@ -122,7 +123,7 @@ def generate_jagged_tensor( # dynamo to mark the input as dynamic shape to make sure symbolic # shape is generated mark_dynamic: bool = False, -) -> Tuple[torch.Tensor, List[torch.LongTensor], np.ndarray]: +) -> Tuple[torch.Tensor, List[torch.LongTensor], npt.NDArray]: max_lengths = np.random.randint(low=1, high=10, size=(num_jagged_dim,)) x_offsets: List[torch.LongTensor] = [] num_lengths = outer_dense_size @@ -167,7 +168,7 @@ def generate_jagged_tensor( def to_padded_dense( values: torch.Tensor, offsets: List[torch.LongTensor], - max_lengths: np.ndarray, + max_lengths: npt.NDArray, padding_value: float = 0, ) -> torch.Tensor: outer_dense_size = len(offsets[0]) - 1 diff --git a/fbgemm_gpu/test/quantize/common.py b/fbgemm_gpu/test/quantize/common.py index 392bbfcac..5333cc893 100644 --- a/fbgemm_gpu/test/quantize/common.py +++ b/fbgemm_gpu/test/quantize/common.py @@ -12,6 +12,7 @@ import fbgemm_gpu import numpy as np +import numpy.typing as npt import torch # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
@@ -30,17 +31,17 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") # Eigen/Python round 0.5 away from 0, Numpy rounds to even -round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round) +round_to_nearest: Callable[[npt.NDArray], npt.NDArray] = np.vectorize(round) -def bytes_to_floats(byte_matrix: np.ndarray) -> np.ndarray: +def bytes_to_floats(byte_matrix: npt.NDArray) -> npt.NDArray: floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32) for i, byte_values in enumerate(byte_matrix): (floats[i],) = struct.unpack("f", bytearray(byte_values)) return floats -def floats_to_bytes(floats: np.ndarray) -> np.ndarray: +def floats_to_bytes(floats: npt.NDArray) -> npt.NDArray: byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float32), (value, floats) @@ -53,7 +54,7 @@ def floats_to_bytes(floats: np.ndarray) -> np.ndarray: return byte_matrix -def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray: +def bytes_to_half_floats(byte_matrix: npt.NDArray) -> npt.NDArray: floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float16) for i, byte_values in enumerate(byte_matrix): (floats[i],) = np.frombuffer( @@ -62,7 +63,7 @@ def bytes_to_half_floats(byte_matrix: np.ndarray) -> np.ndarray: return floats -def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray: +def half_floats_to_bytes(floats: npt.NDArray) -> npt.NDArray: byte_matrix = np.empty([np.shape(floats)[0], 2], dtype=np.uint8) for i, value in enumerate(floats): assert isinstance(value, np.float16), (value, floats) @@ -72,7 +73,7 @@ def half_floats_to_bytes(floats: np.ndarray) -> np.ndarray: return byte_matrix -def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray: +def fused_rowwise_8bit_quantize_reference(data: npt.NDArray) -> npt.NDArray: minimum = np.min(data, axis=-1, keepdims=True) maximum = np.max(data, axis=-1, keepdims=True) span = maximum - minimum @@ -87,7 +88,9 @@ def fused_rowwise_8bit_quantize_reference(data: np.ndarray) -> np.ndarray: return np.concatenate([quantized_data, scale_bytes, bias_bytes], axis=-1) -def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.ndarray: +def fused_rowwise_8bit_dequantize_reference( + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_floats(fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4)) scale = scale.reshape(fused_quantized.shape[:-1] + (scale.shape[-1],)) bias = bytes_to_floats(fused_quantized[..., -4:].astype(np.uint8).reshape(-1, 4)) @@ -97,8 +100,8 @@ def fused_rowwise_8bit_dequantize_reference(fused_quantized: np.ndarray) -> np.n def fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference( - fused_quantized: np.ndarray, -) -> np.ndarray: + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_half_floats( fused_quantized[..., 0:2].astype(np.uint8).reshape(-1, 2) ) @@ -112,8 +115,8 @@ def fused_rowwise_8bit_dequantize_2bytes_padding_scale_bias_first_reference( def fused_rowwise_8bit_dequantize_reference_half( - fused_quantized: np.ndarray, -) -> np.ndarray: + fused_quantized: npt.NDArray, +) -> npt.NDArray: scale = bytes_to_half_floats( fused_quantized[..., -8:-4].astype(np.uint8).reshape(-1, 4) ) @@ -126,7 +129,7 @@ def fused_rowwise_8bit_dequantize_reference_half( return quantized_data * scale + bias -def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndarray: +def fused_rowwise_nbit_quantize_reference(data: npt.NDArray, bit: 
int) -> npt.NDArray: minimum = np.min(data, axis=1).astype(np.float16).astype(np.float32) maximum = np.max(data, axis=1) span = maximum - minimum @@ -165,8 +168,8 @@ def fused_rowwise_nbit_quantize_reference(data: np.ndarray, bit: int) -> np.ndar def fused_rowwise_nbit_quantize_dequantize_reference( - data: np.ndarray, bit: int -) -> np.ndarray: + data: npt.NDArray, bit: int +) -> npt.NDArray: fused_quantized = fused_rowwise_nbit_quantize_reference(data, bit) scale = bytes_to_half_floats(fused_quantized[:, -4:-2].astype(np.uint8)).astype( np.float32 diff --git a/fbgemm_gpu/test/release/__init__.py b/fbgemm_gpu/test/release/__init__.py new file mode 100644 index 000000000..a9fdb3b99 --- /dev/null +++ b/fbgemm_gpu/test/release/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fbgemm_gpu/test/release/example.json b/fbgemm_gpu/test/release/example.json new file mode 100644 index 000000000..50324b9a3 --- /dev/null +++ b/fbgemm_gpu/test/release/example.json @@ -0,0 +1,12 @@ +{ + "_description": "This is a dict containing example schemas. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "mx4_to_fp32": + "mx4_to_fp32(Tensor tensor, int group_size=32, bool use_triton=True, int ebits=2, int mbits=1) -> Tensor", + "merge_pooled_embeddings": + "merge_pooled_embeddings(Tensor[] pooled_embeddings, int uncat_dim_size, Device target_device, int cat_dim=1) -> Tensor", + "dummy_func": + "dummy_func(str var1, int var2) -> ()" + } +} diff --git a/fbgemm_gpu/test/release/stable_ops.json b/fbgemm_gpu/test/release/stable_ops.json new file mode 100644 index 000000000..d5fe76a24 --- /dev/null +++ b/fbgemm_gpu/test/release/stable_ops.json @@ -0,0 +1,30 @@ +{ + "_description": "This is a dict containing schema of FBGEMM_GPU ops that are marked as stable. The schema of future releases need to be backward and forward compatible. For more details, please see https://docs.google.com/document/d/18I0lSkyHHqJ5BY30bx8YhpQHAMOg25nAFV2zeO8PIGk/edit#heading=h.y00l3f1ht5u1", + "_version": 1, + "data": { + "torch.ops.fbgemm.jagged_to_padded_dense": + "fbgemm::jagged_to_padded_dense(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value = 0) -> Tensor", + "torch.ops.fbgemm.merge_pooled_embeddings": + "fbgemm::merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor", + "torch.ops.fbgemm.permute_pooled_embs_auto_grad": + "fbgemm::permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor", + "torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf": + "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor", + "torch.ops.fbgemm.permute_2D_sparse_data": + "fbgemm::permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + "torch.ops.fbgemm.permute_1D_sparse_data": + "fbgemm::permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? 
permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)", + "torch.ops.fbgemm.expand_into_jagged_permute": + "fbgemm::expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor", + "torch.ops.fbgemm.block_bucketize_sparse_features": + "fbgemm::block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)", + "torch.ops.fbgemm.asynchronous_complete_cumsum": + "fbgemm::asynchronous_complete_cumsum(Tensor t_in) -> Tensor", + "torch.ops.fbgemm.offsets_range": + "fbgemm::offsets_range(Tensor offsets, SymInt range_size) -> Tensor", + "torch.ops.fbgemm.segment_sum_csr": + "fbgemm::segment_sum_csr(SymInt batch_size, Tensor csr_seg, Tensor values) -> Tensor", + "torch.ops.fbgemm.keyed_jagged_index_select_dim1": + "fbgemm::keyed_jagged_index_select_dim1(Tensor values, Tensor lengths, Tensor offsets, Tensor indices, SymInt batch_size, Tensor? weights=None, SymInt? selected_lengths_sum=None) -> Tensor[]" + } +} diff --git a/fbgemm_gpu/test/release/stable_release_test.py b/fbgemm_gpu/test/release/stable_release_test.py new file mode 100755 index 000000000..7055fd566 --- /dev/null +++ b/fbgemm_gpu/test/release/stable_release_test.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import json +import os +import unittest +from typing import Callable + +import fbgemm_gpu +import fbgemm_gpu.permute_pooled_embedding_modules +import fbgemm_gpu.sparse_ops + +import torch +from torch._C import FunctionSchema, parse_schema +from torch._utils_internal import get_file_path_2 + +from .utils import infer_schema + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. +open_source: bool = getattr(fbgemm_gpu, "open_source", False) + +if open_source: + from test_utils import TestSuite # pyre-fixme[21] + +else: + # pyre-fixme[21] + from fbgemm_gpu.test.test_utils import TestSuite + + +def _check_schema_compatibility( + schema: FunctionSchema, + ref_schema_str: str, +) -> None: + """ + Check if the schema is forward and backward compatible with the reference schema. + This function will raise an Exception error if the schema is not compatible. + + Args: + schema (FunctionSchema): The schema object to check. + ref_schema_str (str): The reference schema in string format. + Returns: + None + """ + assert isinstance(schema, FunctionSchema) + ref_schema = parse_schema(ref_schema_str) + # pyre-fixme[16] + fwd_compatible = schema.check_forward_compatible_with(ref_schema) + # pyre-fixme[16] + bwd_compatible = schema.is_backward_compatible_with(ref_schema) + msg = "" + if not fwd_compatible: + msg += f"Schema of {schema} is not forward compatible with {ref_schema}\n" + # pyre-fixme[16] + if not bwd_compatible: + msg += f"Schema of {schema} is not backward compatible with {ref_schema}" + assert fwd_compatible and bwd_compatible, msg + + +def check_schema_compatibility( + op: Callable, + ref_schema: str, +) -> None: + """ + Check if the schema of the given operator is forward and backward compatible with the reference schema. 
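(Illustrative aside, not part of the patch.) The stable-op check introduced here ultimately reduces to torch's FunctionSchema comparison: parse the reference string from stable_ops.json, parse the current schema, and call the compatibility methods. A minimal sketch, assuming the usual PyTorch rule that appending a defaulted argument keeps a schema backward compatible while appending a required argument does not:

from torch._C import parse_schema

ref = parse_schema("dummy_func(str var1, int var2) -> ()")
# Appending an argument with a default keeps the schema backward compatible.
ok = parse_schema("dummy_func(str var1, int var2, int var3=0) -> ()")
# Appending a required argument breaks compatibility.
bad = parse_schema("dummy_func(str var1, int var2, Tensor var3) -> ()")

assert ok.is_backward_compatible_with(ref)
assert not bad.is_backward_compatible_with(ref)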
+ This works with python functions whose schema do NOT have positional-only args, varargs, or varkwargs + For ops registered via torch.ops.fbgemm and ops with *args and **kwargs, please use check_schema_compatibility_from_op_name. + + Args: + op (Callable): The operator to check. + ref_schema (str): The reference schema in string format. + Returns: + None + """ + op_schema = infer_schema(op, mutates_args={}) + # pyre-fixme[16] + op_name = op.__name__ + # Create schema string + schema_str = f"{op_name}{op_schema}" + # Create FunctionalSchema + functional_schema = parse_schema(schema_str) + + # Get stable schema to compare against + return _check_schema_compatibility(functional_schema, ref_schema) + + +def check_schema_compatibility_from_op_name( + namespace: Callable, + op_name: str, + ref_schema_str: str, +) -> None: + """ + Check if the schema of the given operator is forward and backward compatible with the reference schema. + Use this function to check registered ops (via torch.ops.fbgemm). + This function will raise an Exception error if the schema is not compatible. + + Args: + namespace (Callable): The namespace of the operator e.g., torch.ops.fbgemm. + op_name (str): The name of the operator. + ref_schema_str (str): The reference schema in string format. + Returns: + None + """ + op = getattr(namespace, op_name) + schema = op._schemas[""] + + return _check_schema_compatibility(schema, ref_schema_str) + + +class StableRelease(TestSuite): # pyre-ignore[11] + def test_stable_schema(self) -> None: + """ + Test the schema compatibility of the operators against stable schema. + This is to ensure that any changes to the ops' schema do not break compatibility of the stable versions. + This test will fail if the current op schema is not forward or backward compatible with the stable schema. + """ + + # Load stable ops from file into dict + stable_dict_file = open( + get_file_path_2("", os.path.dirname(__file__), "stable_ops.json") + ) + stable_op_dict = json.load(stable_dict_file)["data"] + stable_dict_file.close() + # Get all op names + stable_op_names = set(stable_op_dict.keys()) + + # Check compatibility for all ops that are marked stable + for full_op_name in stable_op_names: + # Test the schema given the op name + ref_schema_str = stable_op_dict[full_op_name] + op_name = full_op_name.split(".")[3] + + check_schema_compatibility_from_op_name( + torch.ops.fbgemm, op_name, ref_schema_str + ) + + def test_example_ops(self) -> None: + """ + Test examples for schema compatibility. 
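(Illustrative aside, not part of the patch.) check_schema_compatibility essentially rebuilds a schema string from a plain Python signature via the local infer_schema helper (added in utils.py below) and then parses it, so the same FunctionSchema comparison can be reused. A rough sketch of that round trip, reusing the parse_schema and infer_schema imports from this test file:

def dummy_func(var1: str, var2: int) -> None:
    pass

op_schema = infer_schema(dummy_func, mutates_args={})  # "(str var1, int var2) -> ()"
schema = parse_schema(f"{dummy_func.__name__}{op_schema}")
# `schema` is then compared against the reference entry from example.json.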
+ """ + + # Load example ops to dict + stable_dict_file = open( + get_file_path_2("", os.path.dirname(__file__), "example.json") + ) + op_dict = json.load(stable_dict_file)["data"] + stable_dict_file.close() + + # Example op 1 + # Expect to pass + check_schema_compatibility( + fbgemm_gpu.sparse_ops.merge_pooled_embeddings, + op_dict["merge_pooled_embeddings"], + ) + + # Example op 2 + # stable schema is: dummy_func(str var1, int var2) -> ()" + def dummy_func(var1: str, var2: int, var3: torch.Tensor) -> None: + pass + + # Expect to fail + with self.assertRaises(AssertionError): # pyre-fixme[16] + check_schema_compatibility( + dummy_func, + op_dict["dummy_func"], + ) + + # Example op 3 + # stable schema is: dummy_func(str var1, int var2) -> ()" + def dummy_func(var1: str, var2: int, var3: str = "default") -> None: + pass + + # Expect to pass + check_schema_compatibility( + dummy_func, + op_dict["dummy_func"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py new file mode 100644 index 000000000..005cf38b5 --- /dev/null +++ b/fbgemm_gpu/test/release/utils.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import typing +from typing import Iterable, List, Optional, Sequence, Union # noqa: F401 + +import torch +from torch import device, dtype, Tensor, types + +from torch._library.infer_schema import ( + derived_types, + parse_return, + supported_param, + SUPPORTED_PARAM_TYPES, + tuple_to_list, +) + +# Temporary work around for `infer_schema` + +# `get_supported_param_types` and `SUPPORTED_RETURN_TYPES` are modified from torch/_library/infer_schema.py +# as `torch.library.infer_schema` infers any `int` to be `SymInt` in the schema and does not +# support `str` as return type, which may not reflect the actual signature of the function. +# Other modifications are to address linter warning. +# The rest of the code is copied from `torch/_library/infer_schema.py` +# TO DO: clean up and remove this when we implement our own + + +def error_fn(what: str, sig: Optional[inspect.Signature] = None): + raise ValueError(f"infer_schema(func): {what} " f"Got func with signature {sig})") + + +def convert_type_string(annotation_type: str): + try: + return eval(annotation_type) + except Exception: + error_fn(f"Unsupported type annotation {annotation_type}. It is not a type.") + + +# Modified support param types and return types from torch/_library/infer_schema.py +def get_supported_param_types(): + data = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? 
variant + (Tensor, "Tensor", True, True, False), + (int, "int", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (types.Number, "Scalar", True, False, False), + (dtype, "ScalarType", False, False, False), + (device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + Tensor: "Tensor", + typing.List[Tensor]: "Tensor[]", + int: "int", + float: "float", + bool: "bool", + str: "str", + types.Number: "Scalar", +} + + +def check_param_annotation(name: str, annotation: type, sig: inspect.Signature): + if annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.", sig) + + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type = annotation + if isinstance(annotation_type, str): + annotation_type = convert_type_string(annotation_type) + + if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): + if annotation_type.__origin__ is tuple: + list_type = tuple_to_list(annotation_type) + example_type_str = "\n\n" + # Only suggest the list type if this type is supported. + if list_type in SUPPORTED_PARAM_TYPES.keys(): + example_type_str = f"For example, {list_type}.\n\n" + error_fn( + f"Parameter {name} has unsupported type {annotation}. " + f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " + f"{example_type_str}" + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}.", + sig, + ) + else: + error_fn( + f"Parameter {name} has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}.", + sig, + ) + return annotation_type + + +def get_schema_type( + schema_type: str, + mutates_args: Union[str, Iterable[str]], + name: str, + sig: inspect.Signature, + idx: int, +): + if isinstance(mutates_args, str): + if mutates_args != "unknown": + raise ValueError( + "mutates_args must either be a sequence of the names of " + "the arguments that are mutated or the string 'unknown'. " + ) + if schema_type.startswith("Tensor"): + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + elif name in mutates_args: + if not schema_type.startswith("Tensor"): + error_fn( + f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" + ) + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + return schema_type + + +def check_mutates_args( + mutates_args: Union[str, Iterable[str]], sig: inspect.Signature, seen_args: set +): + if mutates_args != "unknown": + mutates_args_not_seen = set(mutates_args) - seen_args + if len(mutates_args_not_seen) > 0: + error_fn( + f"{mutates_args_not_seen} in mutates_args were not found in " + f"the custom op's signature. " + f"mutates_args should contain the names of all args that the " + f"custom op mutates, or just the string 'unknown' if you don't know.", + sig, + ) + + +def get_return_annonation( + return_annotation: type, +): + if isinstance(return_annotation, str): + return_annotation = convert_type_string(return_annotation) + return parse_return(return_annotation, error_fn) + + +def infer_schema( + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, +) -> str: + r""" + This is modified from torch._library.infer_schema.infer_schema. + + Parses the schema of a given function with type hints. 
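(Illustrative aside, not part of the patch.) The mutates_args plumbing in get_schema_type above is what turns a mutated Tensor parameter into an aliased, mutable schema type. A small sketch with a hypothetical function, assuming the infer_schema defined just below:

def inplace_scale(x: Tensor, factor: float) -> None:
    # Hypothetical op used only for illustration.
    x.mul_(factor)

print(infer_schema(inplace_scale, mutates_args={"x"}))
# -> "(Tensor(a0!) x, float factor) -> ()"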
The schema is inferred from the + function's type hints, and can be used to define a new operator. + + We make the following assumptions: + + * None of the outputs alias any of the inputs or each other. + * | String type annotations "device, dtype, Tensor, types" without library specification are + | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + | without library specification are assumed to be typing.*. + * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown", + | it assumes that all inputs to the operator are being mutates. + + Callers (e.g. the custom ops API) are responsible for checking these assumptions. + + Args: + prototype_function: The function from which to infer a schema for from its type annotations. + op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the + name is not included in the inferred schema. Note that the input schema to + ``torch.library.Library.define`` requires a operator name. + mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. + + Returns: + The inferred schema. + + Example: + >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: + >>> return x.sin() + >>> + >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) + foo(Tensor x) -> Tensor + >>> + >>> infer_schema(foo_impl, mutates_args={}) + (Tensor x) -> Tensor + """ + sig = inspect.signature(prototype_function) + + params = [] + seen_args = set() + saw_kwarg_only_arg = False + for idx, (name, param) in enumerate(sig.parameters.items()): + if not supported_param(param): + error_fn( + "We do not support positional-only args, varargs, or varkwargs.", sig + ) + + if param.kind == inspect.Parameter.KEYWORD_ONLY: + # The first time we see a kwarg-only arg, add "*" to the schema. + if not saw_kwarg_only_arg: + params.append("*") + saw_kwarg_only_arg = True + + annotation_type = check_param_annotation(name, param.annotation, sig) + + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] + schema_type = get_schema_type(schema_type, mutates_args, name, sig, idx) + + seen_args.add(name) + if param.default is inspect.Parameter.empty: + params.append(f"{schema_type} {name}") + else: + default_repr = None + if param.default is None or isinstance(param.default, (int, float, bool)): + default_repr = str(param.default) + elif isinstance(param.default, (str, torch.device)): + default_repr = f'"{param.default}"' + elif isinstance(param.default, torch.dtype): + dtype_repr = str(param.default) + torch_dot = "torch." + assert dtype_repr.startswith(torch_dot) + default_repr = dtype_repr[len(torch_dot) :] + else: + error_fn( + f"Parameter {name} has an unsupported default value type {type(param.default)}. 
" + f"Please file an issue on GitHub so we can prioritize this.", + sig, + ) + params.append(f"{schema_type} {name}={default_repr}") + check_mutates_args(mutates_args, sig, seen_args) + + ret = get_return_annonation(sig.return_annotation) + if op_name is not None: + return f"{op_name}({', '.join(params)}) -> {ret}" + return f"({', '.join(params)}) -> {ret}" diff --git a/fbgemm_gpu/test/sparse/misc_ops_test.py b/fbgemm_gpu/test/sparse/misc_ops_test.py index fb9e29c81..41187b502 100644 --- a/fbgemm_gpu/test/sparse/misc_ops_test.py +++ b/fbgemm_gpu/test/sparse/misc_ops_test.py @@ -18,6 +18,7 @@ import numpy as np import torch from hypothesis import given, settings, Verbosity +from torch.fx.experimental.symbolic_shapes import ShapeEnv from .common import extend_test_class, open_source @@ -261,6 +262,31 @@ def test_bottom_unique_k_per_row( all_indices_deduped_ref = torch.as_tensor(all_indices[:, :, :L]) torch.testing.assert_close(all_indices_deduped, all_indices_deduped_ref) + def test_lengths_range(self) -> None: + # When 'output_shape' is None, the function will return a tensor with dynamic shape. + with self.assertRaisesRegex( + torch._subclasses.fake_tensor.DynamicOutputShapeException, + "fbgemm.lengths_range.default", + ): + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_dynamic_output_shape_ops=False, + ), + ): + lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32) + _ = torch.ops.fbgemm.lengths_range(lengths, None) + + with torch._subclasses.fake_tensor.FakeTensorMode( + shape_env=ShapeEnv( + allow_dynamic_output_shape_ops=False, + ), + ): + lengths = torch.tensor([3, 2, 4, 10], dtype=torch.int32) + output_shape = [1, 2, 4, 4] + actual_result = torch.ops.fbgemm.lengths_range(lengths, output_shape) + + self.assertEqual(actual_result.shape, (1 * 2 * 4 * 4,)) + extend_test_class(MiscOpsTest) diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index c6383017a..d6b40328e 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -15,6 +15,7 @@ import hypothesis.strategies as st import numpy as np +import numpy.typing as npt import torch from hypothesis import given, settings @@ -27,7 +28,7 @@ from fbgemm_gpu.test.test_utils import gpu_available -def get_n_rand_num_summing_to_k(n: int, k: int) -> np.ndarray: +def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: """Get a list of `n` integers which collectively sum to `k`, drawn uniformly from the set of all such lists. @@ -58,7 +59,7 @@ def _pack_segments_ref( lengths: torch.Tensor, tensor: torch.Tensor, max_length: Optional[int] = None, - ) -> np.ndarray: + ) -> npt.NDArray: lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length diff --git a/fbgemm_gpu/test/tbe/common.py b/fbgemm_gpu/test/tbe/common.py index df38e15de..40f1b49e7 100644 --- a/fbgemm_gpu/test/tbe/common.py +++ b/fbgemm_gpu/test/tbe/common.py @@ -17,8 +17,16 @@ # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) -if not open_source: +if open_source: + # pyre-ignore[21] + from test_utils import gpu_unavailable, running_on_github +else: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") + from fbgemm_gpu.test.test_utils import ( # noqa F401 + gpu_unavailable, + running_on_github, + ) + torch.ops.import_module("fbgemm_gpu.sparse_ops") settings.register_profile("derandomize", derandomize=True) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index ed4245dbd..439797688 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -347,6 +347,62 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( ) torch.testing.assert_close(output_uvm, output_ref, equal_nan=True) + @given( + weights_ty=st.sampled_from( + [ + SparseType.FP32, + SparseType.FP16, + SparseType.INT8, + SparseType.INT4, + SparseType.INT2, + ] + ), + ) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_int_nbit_split_embedding_cpu_mixed_indices_offsets_dtypes( + self, + weights_ty: SparseType, + ) -> None: + T = random.randint(1, 5) + B = random.randint(1, 128) + L = random.randint(1, 20) + D = random.randint(2, 256) + log_E = random.randint(3, 5) + + iters = 4 + E = int(10**log_E) + + D_alignment = ( + 1 if weights_ty.bit_rate() % 8 == 0 else int(8 / weights_ty.bit_rate()) + ) + D = round_up(D, D_alignment) + + Ds = [D] * T + Es = [E] * T + cpu_locations = [EmbeddingLocation.HOST] * T + + cc = IntNBitTableBatchedEmbeddingBagsCodegen( + [("", E, D, weights_ty, M) for (E, D, M) in zip(Es, Ds, cpu_locations)], + device=torch.device("cpu"), + ) + cc.fill_random_weights() + + requests = generate_requests( + iters, B, T, L, min(Es), reuse=0.1, emulate_pruning=False, use_cpu=True + ) + dtypes_combo = [ + (torch.int64, torch.int64), + (torch.int32, torch.int32), + (torch.int32, torch.int64), + (torch.int64, torch.int32), + ] + for i, req in enumerate(requests): + indices, offsets = req.unpack_2() + indices_dtype, offsets_dtype = dtypes_combo[i] + indices = indices.to(indices_dtype) + offsets = offsets.to(offsets_dtype) + _ = cc(indices, offsets) + if __name__ == "__main__": unittest.main() diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py new file mode 100644 index 000000000..ef5a9f0a2 --- /dev/null +++ b/fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[3,6,56] + +import tempfile + +import threading +import time +import unittest + +from typing import Any, Dict, List, Tuple + +import hypothesis.strategies as st +import numpy as np +import torch +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags +from fbgemm_gpu.tbe.utils import round_up +from hypothesis import given, settings, Verbosity + +from .. 
import common # noqa E402 +from ..common import gpu_unavailable, running_on_github + +MAX_EXAMPLES = 20 +default_st: Dict[str, Any] = { + "T": st.integers(min_value=1, max_value=10), + "D": st.integers(min_value=2, max_value=128), + "log_E": st.integers(min_value=2, max_value=3), + "mixed": st.booleans(), + "weights_precision": st.sampled_from([SparseType.FP32, SparseType.FP16]), +} + +default_settings: Dict[str, Any] = { + "verbosity": Verbosity.verbose, + "max_examples": MAX_EXAMPLES, + "deadline": None, +} + + +@unittest.skipIf(*running_on_github) +@unittest.skipIf(*gpu_unavailable) +class SSDCheckpointTest(unittest.TestCase): + def generate_fbgemm_ssd_tbe( + self, + T: int, + D: int, + log_E: int, + weights_precision: SparseType, + mixed: bool, + enable_l2: bool = True, + ) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int], int]: + E = int(10**log_E) + D = D * 4 + if not mixed: + Ds = [D] * T + Es = [E] * T + else: + Ds = [ + round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) + for _ in range(T) + ] + Es = [ + np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) + ] + + feature_table_map = list(range(T)) + emb = SSDTableBatchedEmbeddingBags( + embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)], + feature_table_map=feature_table_map, + ssd_storage_directory=tempfile.mkdtemp(), + cache_sets=1, + ssd_uniform_init_lower=-0.1, + ssd_uniform_init_upper=0.1, + weights_precision=weights_precision, + l2_cache_size=1 if enable_l2 else 0, + ) + return emb, Es, Ds, max(Ds) + + # @given(**default_st, do_flush=st.sampled_from([True, False])) + # @settings(**default_settings) + # def test_l2_flush( + # self, + # T: int, + # D: int, + # log_E: int, + # mixed: bool, + # weights_precision: SparseType, + # do_flush: bool, + # ) -> None: + # emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( + # T, D, log_E, weights_precision, mixed + # ) + # indices = torch.arange(start=0, end=sum(Es)) + # weights = torch.randn( + # indices.numel(), max_D, dtype=weights_precision.as_dtype() + # ) + # weights_from_l2 = torch.empty_like(weights) + # count = torch.as_tensor([indices.numel()]) + # emb.ssd_db.set_cuda(indices, weights, count, 1) + # emb.ssd_db.get_cuda(indices.clone(), weights_from_l2, count) + + # torch.cuda.synchronize() + # assert torch.equal(weights, weights_from_l2) + # import logging + + # logging.info(f"wgqtest {do_flush=}") + # weights_from_ssd = torch.empty_like(weights) + # if do_flush: + # emb.ssd_db.flush() + # emb.ssd_db.reset_l2_cache() + # emb.ssd_db.get_cuda(indices, weights_from_ssd, count) + # torch.cuda.synchronize() + # if do_flush: + # assert torch.equal(weights, weights_from_ssd) + # else: + # assert not torch.equal(weights, weights_from_ssd) + + # @given(**default_st, enable_l2=st.sampled_from([True, False])) + # @settings(**default_settings) + # def test_l2_io( + # self, + # T: int, + # D: int, + # log_E: int, + # mixed: bool, + # weights_precision: SparseType, + # enable_l2: bool, + # ) -> None: + # emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( + # T, D, log_E, weights_precision, mixed, enable_l2 + # ) + # E = int(10**log_E) + # num_rounds = 10 + # N = E + # total_indices = torch.tensor([]) + + # indices = torch.as_tensor( + # np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64 + # ) + # weights = torch.randn( + # indices.numel(), max_D, dtype=weights_precision.as_dtype() + # ) + # sub_N = N // num_rounds + + # for _ in range(num_rounds): + # sub_indices = torch.as_tensor( + # np.random.choice(E, replace=False, 
size=(sub_N,)), dtype=torch.int64 + # ) + # sub_weights = weights[sub_indices, :] + # sub_weights_out = torch.empty_like(sub_weights) + # count = torch.as_tensor([sub_indices.numel()]) + # emb.ssd_db.set_cuda(sub_indices, sub_weights, count, 1) + # emb.ssd_db.get_cuda(sub_indices.clone(), sub_weights_out, count) + # torch.cuda.synchronize() + # assert torch.equal(sub_weights, sub_weights_out) + # total_indices = torch.cat((total_indices, sub_indices)) + # # dedup + # used_unique_indices = torch.tensor( + # list(set(total_indices.tolist())), dtype=torch.int64 + # ) + # stored_weights = weights[used_unique_indices, :] + # weights_out = torch.empty_like(stored_weights) + # count = torch.as_tensor([used_unique_indices.numel()]) + # emb.ssd_db.get_cuda(used_unique_indices.clone(), weights_out, count) + # torch.cuda.synchronize() + # assert torch.equal(stored_weights, weights_out) + + # emb.ssd_db.flush() + # emb.ssd_db.reset_l2_cache() + # weights_out = torch.empty_like(stored_weights) + # count = torch.as_tensor([used_unique_indices.numel()]) + # emb.ssd_db.get_cuda(used_unique_indices.clone(), weights_out, count) + # torch.cuda.synchronize() + # assert torch.equal(stored_weights, weights_out) + + @given(**default_st) + @settings(**default_settings) + def test_l2_prefetch_compatibility( + self, + T: int, + D: int, + log_E: int, + mixed: bool, + weights_precision: SparseType, + ) -> None: + weights_precision: SparseType = SparseType.FP32 + emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe( + T, D, log_E, weights_precision, mixed + ) + E = int(10**log_E) + N = E + indices = torch.as_tensor( + np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64 + ) + weights = torch.randn(N, max_D, dtype=weights_precision.as_dtype()) + new_weights = weights + 1 + weights_out = torch.empty_like(weights) + count = torch.as_tensor([E]) + emb.ssd_db.set(indices, weights, count) + emb.ssd_db.wait_util_filling_work_done() + + event = threading.Event() + get_sleep_ms = 50 + + # pyre-ignore + def trigger_get() -> None: + event.set() + emb.ssd_db.get(indices.clone(), weights_out, count, get_sleep_ms) + + # pyre-ignore + def trigger_set() -> None: + event.wait() + time.sleep( + get_sleep_ms / 1000.0 / 2 + ) # sleep half of the sleep time in get, making sure set is trigger after get but before get is done + emb.ssd_db.set(indices, new_weights, count) + + thread1 = threading.Thread(target=trigger_get) + thread2 = threading.Thread(target=trigger_set) + thread1.start() + thread2.start() + thread1.join() + thread2.join() + assert torch.equal(weights, weights_out) + emb.ssd_db.get(indices.clone(), weights_out, count) + assert torch.equal(new_weights, weights_out) diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 7cc75d9e9..61ac98d6e 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -94,15 +94,23 @@ def get_physical_table_arg_indices_(self, feature_table_map: List[int]): @given( weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), + indice_int64_t=st.sampled_from([True, False]), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_ssd(self, weights_precision: SparseType) -> None: + def test_ssd(self, indice_int64_t: bool, weights_precision: SparseType) -> None: import tempfile E = int(1e4) D = 128 N = 100 - indices = torch.as_tensor(np.random.choice(E, replace=False, size=(N,))) + if indice_int64_t: 
+ indices = torch.as_tensor( + np.random.choice(E, replace=False, size=(N,)), dtype=torch.int64 + ) + else: + indices = torch.as_tensor( + np.random.choice(E, replace=False, size=(N,)), dtype=torch.int32 + ) weights = torch.randn(N, D, dtype=weights_precision.as_dtype()) output_weights = torch.empty_like(weights) count = torch.tensor([N])
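(Illustrative aside, not part of the patch.) The int64/int32 branch added above could equivalently be collapsed into a single expression; the behaviour is identical and the choice is purely one of readability:

indices = torch.as_tensor(
    np.random.choice(E, replace=False, size=(N,)),
    dtype=torch.int64 if indice_int64_t else torch.int32,
)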