From 7e50cd9de88e481c787c3a5111028604be81b63e Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Fri, 12 Jan 2024 22:44:33 -0800
Subject: [PATCH] Re-organize TBE tests, pt 1 (#2263)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2263

- Re-organize TBE tests, pt 1

Reviewed By: spcyppt

Differential Revision: D52714927

fbshipit-source-id: bd9b80c9db6b33e28770b00f58cae3af16e2b2b7
---
 .github/scripts/fbgemm_gpu_test.bash          |    6 +-
 .../comm_codec_test.py}                       |    0
 fbgemm_gpu/test/tbe/__init__.py               |    6 +
 fbgemm_gpu/test/tbe/backward_adagrad_test.py  |  875 +++++
 fbgemm_gpu/test/tbe/common.py                 |   61 +
 fbgemm_gpu/test/tbe/failures_dict_fast.json   |  428 ++
 fbgemm_gpu/test/tbe/forward_test.py           |  889 +++++
 .../inference_converter_test.py}              |    0
 .../split_table_batched_embeddings_test.py    | 3444 +++++------------
 .../utils_test.py}                            |    0
 10 files changed, 3197 insertions(+), 2512 deletions(-)
 rename fbgemm_gpu/test/{quantize_comm_test.py => quantize/comm_codec_test.py} (100%)
 create mode 100644 fbgemm_gpu/test/tbe/__init__.py
 create mode 100644 fbgemm_gpu/test/tbe/backward_adagrad_test.py
 create mode 100644 fbgemm_gpu/test/tbe/common.py
 create mode 100644 fbgemm_gpu/test/tbe/failures_dict_fast.json
 create mode 100644 fbgemm_gpu/test/tbe/forward_test.py
 rename fbgemm_gpu/test/{split_embedding_inference_converter_test.py => tbe/inference_converter_test.py} (100%)
 rename fbgemm_gpu/test/{ => tbe}/split_table_batched_embeddings_test.py (78%)
 rename fbgemm_gpu/test/{split_embeddings_utils_test.py => tbe/utils_test.py} (100%)

diff --git a/.github/scripts/fbgemm_gpu_test.bash b/.github/scripts/fbgemm_gpu_test.bash
index be17318f0..fae8c90d8 100644
--- a/.github/scripts/fbgemm_gpu_test.bash
+++ b/.github/scripts/fbgemm_gpu_test.bash
@@ -84,19 +84,19 @@ run_fbgemm_gpu_tests () {
 
   # These are either non-tests or currently-broken tests in both FBGEMM_GPU and FBGEMM_GPU-CPU
   local files_to_skip=(
-    ./test_utils.py
-    ./split_table_batched_embeddings_test.py
+    ./tbe/split_table_batched_embeddings_test.py
     ./ssd_split_table_batched_embeddings_test.py
   )
 
   if [ "$fbgemm_variant" == "cpu" ]; then
     # These are tests that are currently broken in FBGEMM_GPU-CPU
     local ignored_tests=(
+      ./tbe/forward_test.py
      ./uvm_test.py
    )
  elif [ "$fbgemm_variant" == "rocm" ]; then
-    # https://github.com/pytorch/FBGEMM/issues/1559
    local ignored_tests=(
+      # https://github.com/pytorch/FBGEMM/issues/1559
      ./batched_unary_embeddings_test.py
    )
  else
diff --git a/fbgemm_gpu/test/quantize_comm_test.py b/fbgemm_gpu/test/quantize/comm_codec_test.py
similarity index 100%
rename from fbgemm_gpu/test/quantize_comm_test.py
rename to fbgemm_gpu/test/quantize/comm_codec_test.py
diff --git a/fbgemm_gpu/test/tbe/__init__.py b/fbgemm_gpu/test/tbe/__init__.py
new file mode 100644
index 000000000..a9fdb3b99
--- /dev/null
+++ b/fbgemm_gpu/test/tbe/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/fbgemm_gpu/test/tbe/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/backward_adagrad_test.py
new file mode 100644
index 000000000..a183b3150
--- /dev/null
+++ b/fbgemm_gpu/test/tbe/backward_adagrad_test.py
@@ -0,0 +1,875 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-ignore-all-errors[56]
+
+import copy
+import unittest
+
+import fbgemm_gpu
+import hypothesis.strategies as st
+import numpy as np
+import torch
+
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
+from fbgemm_gpu.split_embedding_utils import (
+    b_indices,
+    get_table_batched_offsets_from_dense,
+    round_up,
+    to_device,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    CacheAlgorithm,
+    EmbeddingLocation,
+    PoolingMode,
+)
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    ComputeDevice,
+    SplitTableBatchedEmbeddingBagsCodegen,
+    WeightDecayMode,
+)
+
+from hypothesis import assume, given, HealthCheck, settings, Verbosity
+
+from . import common  # noqa E402,F401
+from .common import (  # noqa E402
+    format_ref_tensors_in_mixed_B_layout,
+    gen_mixed_B_batch_sizes,
+    MAX_EXAMPLES_LONG_RUNNING,
+)
+
+torch.ops.import_module("fbgemm_gpu.sparse_ops")
+
+# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
+open_source: bool = getattr(fbgemm_gpu, "open_source", False)
+
+if open_source:
+    # pyre-ignore[21]
+    from test_utils import (
+        gpu_available,
+        gpu_unavailable,
+        gradcheck,
+        optests,
+        TEST_WITH_ROCM,
+    )
+else:
+    from fbgemm_gpu.test.test_utils import (
+        gpu_available,
+        gpu_unavailable,
+        gradcheck,
+        optests,
+        TEST_WITH_ROCM,
+    )
+
+
+VERBOSITY: Verbosity = Verbosity.verbose
+
+
+@optests.generate_opcheck_tests(fast=True)
+@unittest.skipIf(*gpu_unavailable)
+class BackwardAdagradTest(unittest.TestCase):
+    def execute_backward_adagrad_(  # noqa C901
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        D_gradcheck: int,
+        weights_precision: SparseType,
+        stochastic_rounding: bool,
+        weighted: bool,
+        row_wise: bool,
+        mixed: bool,
+        mixed_B: bool,
+        use_cache: bool,
+        cache_algorithm: CacheAlgorithm,
+        pooling_mode: PoolingMode,
+        use_cpu: bool,
+        output_dtype: SparseType,
+        weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE,
+    ) -> None:
+        # NOTE: cache is not applicable to CPU version.
+        assume(not use_cpu or not use_cache)
+
+        # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version
+        # so we have to limit (T * B * L * D)!
+ assume(not use_cpu or T * B * L * D <= 1024) + assume(not (use_cpu and weights_precision == SparseType.FP16)) + + assume( + pooling_mode == PoolingMode.SUM or not weighted + ) # No bag ops only work on GPUs, no mixed, no weighted + assume(not use_cpu or pooling_mode != PoolingMode.NONE) + assume(not mixed or pooling_mode != PoolingMode.NONE) + assume(not weighted or pooling_mode != PoolingMode.NONE) + # TODO: Support these cases + assume( + not mixed_B + or ( + weights_precision != SparseType.INT8 + and output_dtype != SparseType.INT8 + and not use_cpu + and not use_cache + and pooling_mode != PoolingMode.NONE + ) + ) + + emb_op = SplitTableBatchedEmbeddingBagsCodegen + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") + + # stochastic rounding only implemented for rowwise + assume(not stochastic_rounding or row_wise) + # only row-wise supports caching + assume(row_wise or not use_cache) + + E = int(10**log_E) + if use_cpu: + D = (D + 15) // 16 * 4 + else: + D = D * 4 + if not mixed: + Ds = [D] * T + Es = [E] * T + else: + Ds = [ + round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) + for _ in range(T) + ] + Es = [ + np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) + ] + + if not mixed_B: + Bs = [B] * T + else: + low = max(int(0.25 * B), 1) + high = int(B) + if low == high: + Bs = [B] * T + else: + Bs = [np.random.randint(low=low, high=high) for _ in range(T)] + + compute_device = ComputeDevice.CUDA + if use_cpu: + managed = [EmbeddingLocation.HOST] * T + compute_device = ComputeDevice.CPU + elif TEST_WITH_ROCM: + # ROCm managed memory allocation is under development + managed = [EmbeddingLocation.DEVICE] * T + elif use_cache: + managed = [EmbeddingLocation.MANAGED_CACHING] * T + if mixed: + average_D = sum(Ds) // T + for t, d in enumerate(Ds): + managed[t] = ( + EmbeddingLocation.DEVICE if d < average_D else managed[t] + ) + else: + managed = [ + np.random.choice( + [ + EmbeddingLocation.DEVICE, + EmbeddingLocation.MANAGED, + ] + ) + for _ in range(T) + ] + if do_pooling: + bs = [ + to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) + for (E, D) in zip(Es, Ds) + ] + else: + bs = [ + to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) + for (E, D) in zip(Es, Ds) + ] + + if weights_precision == SparseType.FP16: + bs = [b.half() for b in bs] + + feature_table_map = list(range(T)) + # autograd with shared embedding only works for exact + table_to_replicate = T // 2 + # pyre-fixme[6]: For 2nd param expected `Embedding` but got + # `Union[Embedding, EmbeddingBag]`. 
+ bs.insert(table_to_replicate, bs[table_to_replicate]) + feature_table_map.insert(table_to_replicate, table_to_replicate) + + num_features = len(feature_table_map) + if not mixed_B: + Bs = [B] * num_features + Bs_rank_feature = [[0]] + else: + Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, num_features) + + xs = [ + to_device( + torch.from_numpy( + np.random.choice(range(Es[t]), size=(b, L), replace=True).astype( + np.int64 + ) + ), + use_cpu, + ) + for t, b in zip(feature_table_map, Bs) + ] + xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] + xws_acc_type = copy.deepcopy(xws) + + if weights_precision == SparseType.FP16 and not use_cpu: + xws = [xw.half() for xw in xws] + + fs = ( + [ + b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) + for (b, x) in zip(bs, xs) + ] + if not weighted + else [ + b_indices( + b, + x, + per_sample_weights=xw.view(-1), + use_cpu=use_cpu, + do_pooling=do_pooling, + ) + for (b, x, xw) in zip(bs, xs, xws) + ] + ) + gos = [torch.randn_like(f) for f in fs] + [f.backward(go) for (f, go) in zip(fs, gos)] + # do SGD update + lr = 0.5 + eps = 0.2 + + optimizer = ( + OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD + ) + cc = emb_op( + embedding_specs=[ + (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) + ], + feature_table_map=feature_table_map, + optimizer=optimizer, + learning_rate=lr, + eps=eps, + weights_precision=weights_precision, + stochastic_rounding=stochastic_rounding, + pooling_mode=pooling_mode, + output_dtype=output_dtype, + ) + + del bs[table_to_replicate] + for t in range(T): + cc.split_embedding_weights()[t].data.copy_(bs[t].weight) + + x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) + xw = torch.cat([xw.contiguous().flatten() for xw in xws_acc_type], dim=0) + + (indices, offsets) = get_table_batched_offsets_from_dense( + x, L, sum(Bs), use_cpu=use_cpu + ) + + batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None + + fc2 = ( + cc( + indices, + offsets, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + if not weighted + else cc( + indices, + offsets, + to_device(xw.contiguous().view(-1), use_cpu), + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + ) + if do_pooling: + if mixed_B: + goc = format_ref_tensors_in_mixed_B_layout(gos, Bs_rank_feature) + else: + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) + else: + goc = torch.cat(gos, dim=0) + fc2.backward(goc) + cc.flush() + split_optimizer_states = cc.split_optimizer_states() + assert len(split_optimizer_states) == T + + get_optimizer_states = None + if row_wise: + # get_optimizer_state should/must be implemented for rowwise + get_optimizer_states = cc.get_optimizer_state() + assert len(get_optimizer_states) == T + + tolerance = ( + 1.0e-4 + if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32 + else 1.0e-2 + ) + + for t in range(T): + expected_keys = {"sum"} + if row_wise and weight_decay_mode == WeightDecayMode.COUNTER: + (m1, c1, c2) = split_optimizer_states[t] + expected_keys.update( + [ + "prev_iter", + "row_counter", + ] + ) + else: + (m1,) = split_optimizer_states[t] + if get_optimizer_states is not None: + optimizer_states_dict = get_optimizer_states[t] + assert set(optimizer_states_dict.keys()) == expected_keys + # pyre-fixme[16]: `Optional` has no attribute `float`. 
+ ref_optimizer_state = bs[t].weight.grad.float().cpu().to_dense().pow(2) + torch.testing.assert_close( + m1.float().cpu(), + ref_optimizer_state.mean(dim=1) if row_wise else ref_optimizer_state, + atol=tolerance, + rtol=tolerance, + ) + for t in range(T): + # optimizer_state = squares (no row-wise) or sum squares (row-wise) + if row_wise and weight_decay_mode == WeightDecayMode.COUNTER: + (m1, c1, c2) = split_optimizer_states[t] + else: + (m1,) = split_optimizer_states[t] + torch.testing.assert_close( + cc.split_embedding_weights()[t].float().cpu(), + torch.addcdiv( + bs[t].weight.float().cpu(), + value=-lr, + tensor1=bs[t].weight.grad.float().cpu().to_dense(), + tensor2=m1.float() + .sqrt_() + .add_(eps) + .view(Es[t], 1 if row_wise else Ds[t]) + .cpu(), + ), + atol=tolerance, + rtol=tolerance, + ) + if use_cpu: + D_gradcheck = (D_gradcheck + 15) // 16 * 4 + else: + D_gradcheck = D_gradcheck * 4 + cc = emb_op( + embedding_specs=[ + (E, D_gradcheck, M, compute_device) for (E, M) in zip(Es, managed) + ], + feature_table_map=feature_table_map, + optimizer=optimizer, + learning_rate=0.0, + eps=eps, + weights_precision=weights_precision, + stochastic_rounding=stochastic_rounding, + # NOTE: only SUM pooling can work with per_sample_weights! + pooling_mode=PoolingMode.SUM, + output_dtype=output_dtype, + ) + per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) + per_sample_weights.requires_grad = True + indices.requires_grad = False + offsets.requires_grad = False + for param in cc.parameters(): + param.requires_grad = False + gradcheck( + cc, + ( + indices, + offsets, + per_sample_weights, + None, + batch_size_per_feature_per_rank, + ), + ) + + per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) + per_sample_weights.requires_grad = True + indices.requires_grad = False + offsets.requires_grad = False + for param in cc.parameters(): + param.requires_grad = False + y = cc( + indices, + offsets, + per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + y.sum().backward() + # pyre-fixme[16]: `Optional` has no attribute `clone`. 
+ indice_weight_grad_all = per_sample_weights.grad.clone().cpu() + T_ = len(xws) + feature_requires_grad = to_device( + torch.tensor(np.random.choice([0, 1], replace=True, size=(T_,))).int(), + use_cpu, + ) + per_sample_weights = per_sample_weights.detach().clone() + per_sample_weights.requires_grad = True + y = cc( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + y.sum().backward() + indice_weight_grad_mask = per_sample_weights.grad.clone().cpu() + + if gpu_available and not TEST_WITH_ROCM: + torch.cuda.synchronize() + + acc_B = 0 + for t in range(T_): + B = Bs[t] + table_indice_weight_grad_mask = indice_weight_grad_mask[ + acc_B : acc_B + B * L + ] + table_indice_weight_grad_all = indice_weight_grad_all[acc_B : acc_B + B * L] + acc_B += B * L + if feature_requires_grad[t]: + torch.testing.assert_close( + table_indice_weight_grad_mask, + table_indice_weight_grad_all, + ) + else: + torch.testing.assert_close( + table_indice_weight_grad_mask, + torch.zeros_like(table_indice_weight_grad_mask), + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + D_gradcheck=st.integers(min_value=1, max_value=2), + weights_precision=st.just(SparseType.FP16), + stochastic_rounding=st.booleans(), + weighted=st.booleans(), + row_wise=st.booleans(), + mixed=st.booleans(), + mixed_B=st.booleans(), + use_cache=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_backward_adagrad_fp16_pmSUM( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + mixed_B: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + # VBE is supported in rowwise_adagrad only + if not row_wise: + mixed_B = False + self.execute_backward_adagrad_( + T, + D, + B, + log_E, + L, + D_gradcheck, + weights_precision, + stochastic_rounding, + weighted, + row_wise, + mixed, + mixed_B, + use_cache, + cache_algorithm, + PoolingMode.SUM, + use_cpu, + output_dtype, + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + D_gradcheck=st.integers(min_value=1, max_value=2), + weights_precision=st.just(SparseType.FP16), + stochastic_rounding=st.booleans(), + weighted=st.booleans(), + row_wise=st.booleans(), + mixed=st.booleans(), + mixed_B=st.booleans(), + use_cache=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, 
SparseType.FP16]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_backward_adagrad_fp16_pmMEAN( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + mixed_B: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + # VBE is supported in rowwise_adagrad only + if not row_wise: + mixed_B = False + self.execute_backward_adagrad_( + T, + D, + B, + log_E, + L, + D_gradcheck, + weights_precision, + stochastic_rounding, + weighted, + row_wise, + mixed, + mixed_B, + use_cache, + cache_algorithm, + PoolingMode.MEAN, + use_cpu, + output_dtype, + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + D_gradcheck=st.integers(min_value=1, max_value=2), + weights_precision=st.just(SparseType.FP16), + stochastic_rounding=st.booleans(), + weighted=st.booleans(), + row_wise=st.booleans(), + mixed=st.booleans(), + use_cache=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_backward_adagrad_fp16_pmNONE( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + self.execute_backward_adagrad_( + T, + D, + B, + log_E, + L, + D_gradcheck, + weights_precision, + stochastic_rounding, + weighted, + row_wise, + mixed, + False, # mixed_B + use_cache, + cache_algorithm, + PoolingMode.NONE, + use_cpu, + output_dtype, + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + D_gradcheck=st.integers(min_value=1, max_value=2), + weights_precision=st.just(SparseType.FP32), + stochastic_rounding=st.booleans(), + weighted=st.booleans(), + row_wise=st.booleans(), + mixed=st.booleans(), + mixed_B=st.booleans(), + use_cache=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_backward_adagrad_fp32_pmSUM( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + 
weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + mixed_B: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + # VBE is supported in rowwise_adagrad only + if not row_wise: + mixed_B = False + self.execute_backward_adagrad_( + T, + D, + B, + log_E, + L, + D_gradcheck, + weights_precision, + stochastic_rounding, + weighted, + row_wise, + mixed, + mixed_B, + use_cache, + cache_algorithm, + PoolingMode.SUM, + use_cpu, + output_dtype, + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + D_gradcheck=st.integers(min_value=1, max_value=2), + weights_precision=st.just(SparseType.FP32), + stochastic_rounding=st.booleans(), + weighted=st.booleans(), + row_wise=st.booleans(), + mixed=st.booleans(), + mixed_B=st.booleans(), + use_cache=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_backward_adagrad_fp32_pmMEAN( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + mixed_B: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + # VBE is supported in rowwise_adagrad only + if not row_wise: + mixed_B = False + self.execute_backward_adagrad_( + T, + D, + B, + log_E, + L, + D_gradcheck, + weights_precision, + stochastic_rounding, + weighted, + row_wise, + mixed, + mixed_B, + use_cache, + cache_algorithm, + PoolingMode.MEAN, + use_cpu, + output_dtype, + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + D_gradcheck=st.integers(min_value=1, max_value=2), + weights_precision=st.just(SparseType.FP32), + stochastic_rounding=st.booleans(), + weighted=st.booleans(), + row_wise=st.booleans(), + mixed=st.booleans(), + use_cache=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_backward_adagrad_fp32_pmNONE( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + D_gradcheck: int, + weights_precision: SparseType, + stochastic_rounding: bool, + weighted: bool, + row_wise: bool, + mixed: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + 
self.execute_backward_adagrad_( + T, + D, + B, + log_E, + L, + D_gradcheck, + weights_precision, + stochastic_rounding, + weighted, + row_wise, + mixed, + False, # mixed_B + use_cache, + cache_algorithm, + PoolingMode.NONE, + use_cpu, + output_dtype, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/tbe/common.py b/fbgemm_gpu/test/tbe/common.py new file mode 100644 index 000000000..e4c9bb6be --- /dev/null +++ b/fbgemm_gpu/test/tbe/common.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +import numpy as np +import torch + +from hypothesis import settings, Verbosity + +torch.ops.import_module("fbgemm_gpu.sparse_ops") + + +MAX_EXAMPLES = 40 + +# For long running tests reduce the number of iterations to reduce timeout errors. +MAX_EXAMPLES_LONG_RUNNING = 15 + + +VERBOSITY: Verbosity = Verbosity.verbose + + +settings.register_profile("derandomize", derandomize=True) +settings.load_profile("derandomize") + + +def gen_mixed_B_batch_sizes(B: int, T: int) -> Tuple[List[List[int]], List[int]]: + num_ranks = np.random.randint(low=1, high=4) + low = max(int(0.25 * B), 1) + high = int(B) + if low == high: + Bs_rank_feature = [[B] * num_ranks for _ in range(T)] + else: + Bs_rank_feature = [ + np.random.randint(low=low, high=high, size=num_ranks).tolist() + for _ in range(T) + ] + Bs = [sum(Bs_feature) for Bs_feature in Bs_rank_feature] + return Bs_rank_feature, Bs + + +def format_ref_tensors_in_mixed_B_layout( + ref_tensors: List[torch.Tensor], Bs_rank_feature: List[List[int]] +) -> torch.Tensor: + # Relayout the reference tensor + # Jagged dimension: (rank, table, local batch) + num_ranks = len(Bs_rank_feature[0]) + split_tensors = [[] for _ in range(num_ranks)] # shape (rank, table) + for t, ref_tensor in enumerate(ref_tensors): + assert ref_tensor.shape[0] == sum(Bs_rank_feature[t]) + tensors = ref_tensor.split(Bs_rank_feature[t]) + for r, tensor in enumerate(tensors): + split_tensors[r].append(tensor.flatten()) + concat_list = [] + for r in range(num_ranks): + concat_list += split_tensors[r] + return torch.cat(concat_list, dim=0) diff --git a/fbgemm_gpu/test/tbe/failures_dict_fast.json b/fbgemm_gpu/test/tbe/failures_dict_fast.json new file mode 100644 index 000000000..64f89df39 --- /dev/null +++ b/fbgemm_gpu/test/tbe/failures_dict_fast.json @@ -0,0 +1,428 @@ +{ + "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. 
For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", + "_version": 1, + "data": { + "fbgemm::asynchronous_complete_cumsum": {}, + "fbgemm::bounds_check_indices": {}, + "fbgemm::dense_embedding_codegen_lookup_function": { + "SplitTableBatchedEmbeddingsTest.test_autograd_registration__test_backward_dense": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_dense": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::direct_mapped_lru_cache_populate_byte": {}, + "fbgemm::direct_mapped_lxu_cache_lookup": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::emb_inplace_update": {}, + "fbgemm::get_unique_indices": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::int_nbit_split_embedding_codegen_lookup_function": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "comment": "", + "status": "xsuccess" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_cache_miss_counter": { + "comment": "", + "status": "xsuccess" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": { + "comment": "", + "status": "xsuccess" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_fused_pooled_emb_quant": { + "comment": "", + "status": "xsuccess" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xsuccess" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_uvm_cache_stats": { + "comment": "", + "status": "xsuccess" + } + }, + "fbgemm::int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::lfu_cache_populate": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_pipeline": { + "comment": "", + "status": "skip" + } + }, + "fbgemm::lfu_cache_populate_byte": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::linearize_cache_indices": { + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "skip" + }, + 
"BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_1": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_linearize_cache_indices": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_cache_miss_counter": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::linearize_cache_indices_from_row_idx": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_linearize_cache_indices_from_row_idx": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::lru_cache_populate": { + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "skip" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { + "comment": "", + "status": "skip" + }, + 
"SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_1": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_pipeline": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_prefetch_pipeline": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_prefetch_pipeline_stream_1": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_cache_prefetch_pipeline_stream_2": { + "comment": "", + "status": "skip" + } + }, + "fbgemm::lru_cache_populate_byte": {}, + "fbgemm::lxu_cache_flush": {}, + "fbgemm::lxu_cache_locking_counter_decrement": {}, + "fbgemm::lxu_cache_lookup": { + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmMEAN": { + "comment": "", + "status": "xfail" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmNONE": { + "comment": "", + "status": "xfail" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp16_pmSUM": { + "comment": "", + "status": "xfail" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmMEAN": { + "comment": "", + "status": "xfail" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmNONE": { + "comment": "", + "status": "xfail" + }, + "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_pmSUM": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_backward_sgd_really_long_segments": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_miss_counter": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_pipeline": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_1": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_cache_prefetch_pipeline_stream_2": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_lxu_cache_lookup": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_cache_miss_counter": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": 
"xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_stb_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_unique_lxu_cache_lookup": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::new_managed_tensor": {}, + "fbgemm::new_unified_tensor": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_cache_miss_counter": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_cache_update_function": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache_fp8_2048": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_uvm_cache_stats": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::pruned_array_lookup": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_pruning": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::pruned_hashmap_insert": {}, + "fbgemm::pruned_hashmap_lookup": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_uvm_cache": { + "comment": "", + "status": "xfail" + }, + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_pruning": { + "comment": "", + "status": "xfail" + } + }, + "fbgemm::reset_weight_momentum": { + "SplitTableBatchedEmbeddingsTest.test_faketensor__test_reset_embedding_weight_momentum": { + "comment": "", + "status": "skip" + }, + "SplitTableBatchedEmbeddingsTest.test_schema__test_reset_embedding_weight_momentum": { + "comment": "", + "status": "skip" + } + }, + "fbgemm::split_embedding_codegen_forward_unweighted_cuda": {}, + "fbgemm::split_embedding_codegen_forward_weighted_cuda": {}, + "fbgemm::split_embedding_codegen_lookup_adagrad_function": {}, + "fbgemm::split_embedding_codegen_lookup_adagrad_function_cpu": {}, + "fbgemm::split_embedding_codegen_lookup_adam_function": {}, + "fbgemm::split_embedding_codegen_lookup_lamb_function": {}, + "fbgemm::split_embedding_codegen_lookup_lars_sgd_function": {}, + "fbgemm::split_embedding_codegen_lookup_none_function": {}, + "fbgemm::split_embedding_codegen_lookup_partial_rowwise_adam_function": {}, + "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {}, + "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {}, + "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {}, + "fbgemm::split_embedding_codegen_lookup_rowwise_weighted_adagrad_function": {}, + "fbgemm::split_embedding_codegen_lookup_sgd_function": {}, + "fbgemm::split_embedding_codegen_lookup_sgd_function_cpu": {} + } +} diff --git a/fbgemm_gpu/test/tbe/forward_test.py 
b/fbgemm_gpu/test/tbe/forward_test.py new file mode 100644 index 000000000..b007a8205 --- /dev/null +++ b/fbgemm_gpu/test/tbe/forward_test.py @@ -0,0 +1,889 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-ignore-all-errors[56] + +import copy +import random +import unittest + +from typing import Callable, Dict, List + +import fbgemm_gpu +import hypothesis.strategies as st +import numpy as np +import torch + +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType +from fbgemm_gpu.split_embedding_utils import ( + b_indices, + generate_requests, + get_table_batched_offsets_from_dense, + round_up, + to_device, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + CacheAlgorithm, + EmbeddingLocation, + PoolingMode, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + ComputeDevice, + SplitTableBatchedEmbeddingBagsCodegen, +) + +from hypothesis import assume, given, HealthCheck, settings, Verbosity + +torch.ops.import_module("fbgemm_gpu.sparse_ops") + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. +open_source: bool = getattr(fbgemm_gpu, "open_source", False) + +if open_source: + # pyre-ignore[21] + from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM +else: + from fbgemm_gpu.test.test_utils import gpu_unavailable, optests, TEST_WITH_ROCM + + +VERBOSITY: Verbosity = Verbosity.verbose + + +from . import common # noqa E402,F401 +from .common import ( # noqa E402 + format_ref_tensors_in_mixed_B_layout, + gen_mixed_B_batch_sizes, + MAX_EXAMPLES, + MAX_EXAMPLES_LONG_RUNNING, +) + +# pyre-ignore +additional_decorators: Dict[str, List[Callable]] = { + # TODO: Implement the operator registrations later + "test_faketensor__test_forward_cpu_int8": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_forward_fused_pooled_emb_quant": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_forward_gpu_no_cache_int8": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_forward_gpu_uvm_cache_int8": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], +} + + +@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators) +class ForwardTest(unittest.TestCase): + def execute_forward_( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + mixed: bool, + mixed_B: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + pooling_mode: PoolingMode, + use_cpu: bool, + output_dtype: SparseType, + use_experimental_tbe: bool, + ) -> None: + # NOTE: cache is not applicable to CPU version. + assume(not use_cpu or not use_cache) + # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! + assume(not use_cpu or T * B * L * D <= 2048) + # NOTE: CPU does not support FP16. + assume(not (use_cpu and weights_precision == SparseType.FP16)) + + # NOTE: weighted operation can be done only for SUM. 
+ assume(pooling_mode == PoolingMode.SUM or not weighted) + # NOTE: No bag ops only work on GPUs, no mixed + assume(not use_cpu or pooling_mode != PoolingMode.NONE) + assume(not mixed or pooling_mode != PoolingMode.NONE) + # TODO: Support these cases + assume( + not mixed_B + or ( + weights_precision != SparseType.INT8 + and output_dtype != SparseType.INT8 + and not use_cpu + and not use_cache + and pooling_mode != PoolingMode.NONE + ) + ) + + emb_op = SplitTableBatchedEmbeddingBagsCodegen + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") + + E = int(10**log_E) + if use_cpu: + D = (D + 15) // 16 * 4 + else: + D = D * 4 + if not mixed: + Ds = [D] * T + Es = [E] * T + else: + Ds = [ + round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) + for _ in range(T) + ] + Es = [ + np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) + ] + + if not mixed_B: + Bs = [B] * T + Bs_rank_feature = [[0]] + else: + Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, T) + + compute_device = ComputeDevice.CUDA + if use_cpu: + managed = [EmbeddingLocation.HOST] * T + compute_device = ComputeDevice.CPU + elif TEST_WITH_ROCM: + # ROCm managed memory allocation is under development + managed = [EmbeddingLocation.DEVICE] * T + elif use_cache: + managed = [EmbeddingLocation.MANAGED_CACHING] * T + if mixed: + average_D = sum(Ds) // T + for t, d in enumerate(Ds): + managed[t] = ( + EmbeddingLocation.DEVICE if d < average_D else managed[t] + ) + else: + managed = [ + np.random.choice( + [ + EmbeddingLocation.DEVICE, + EmbeddingLocation.MANAGED, + ] + ) + for _ in range(T) + ] + if do_pooling: + bs = [ + to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) + for (E, D) in zip(Es, Ds) + ] + else: + bs = [ + to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) + for (E, D) in zip(Es, Ds) + ] + if weights_precision == SparseType.INT8: + for t in range(T): + bs[t].weight.data.copy_( + torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat( + torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized( + bs[t].weight.data + ) + ) + ) + + if weights_precision == SparseType.FP16: + bs = [b.half() for b in bs] + + # Generate indices + xs = [ + to_device(torch.randint(low=0, high=e, size=(b, L)), use_cpu) + for e, b in zip(Es, Bs) + ] + # Generate positional weights + xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] + xws_acc_type = copy.deepcopy(xws) + + if weights_precision == SparseType.FP16: + xws = [xw.half() for xw in xws] + + # Run baseline + fs = ( + [ + b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) + for (b, x) in zip(bs, xs) + ] + if not weighted + else [ + b_indices( + b, + x, + per_sample_weights=xw.view(-1), + use_cpu=use_cpu, + do_pooling=do_pooling, + ) + for (b, x, xw) in zip(bs, xs, xws) + ] + ) + + if do_pooling: + if mixed_B: + f = format_ref_tensors_in_mixed_B_layout(fs, Bs_rank_feature) + else: + f = torch.cat([f.view(B, -1) for f in fs], dim=1) + else: + f = torch.cat(fs, dim=0).view(-1, D) + + # Create a TBE op + cc = emb_op( + embedding_specs=[ + ( + E, + D, + EmbeddingLocation(M), + compute_device, + ) + for (E, D, M) in zip(Es, Ds, managed) + ], + weights_precision=weights_precision, + optimizer=OptimType.EXACT_ROWWISE_ADAGRAD + if mixed_B + else 
OptimType.EXACT_SGD, + learning_rate=0.05, + cache_algorithm=cache_algorithm, + pooling_mode=pooling_mode, + output_dtype=output_dtype, + use_experimental_tbe=use_experimental_tbe, + ) + # NOTE: test TorchScript-compatible! + cc = torch.jit.script(cc) + + for t in range(T): + cc.split_embedding_weights()[t].data.copy_( + bs[t].weight + if weights_precision != SparseType.INT8 + else torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(bs[t].weight) + ) + + x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) + xw = torch.cat([xw.contiguous().flatten() for xw in xws_acc_type], dim=0) + + (indices, offsets) = get_table_batched_offsets_from_dense( + x, L, sum(Bs), use_cpu + ) + + batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None + + # Run TBE + fc2 = ( + cc( + indices, + offsets, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + if not weighted + else cc( + indices, + offsets, + to_device(xw.contiguous().view(-1), use_cpu), + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + ) + + # Compare results: f = baseline, fc2 = TBE + tolerance = ( + 1.0e-5 + if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32 + else 8.0e-3 + ) + torch.testing.assert_close( + fc2.float(), f.float(), atol=tolerance, rtol=tolerance + ) + + def test_forward_cpu_int8( + self, + ) -> None: + weights_precision = SparseType.INT8 + use_cpu = True + T = random.randint(1, 10) + D = random.randint(2, min(256, int(2048 / T))) + B = random.randint(1, min(128, int(2048 / T / D))) + L = random.randint(0, min(20, int(2048 / T / D / B))) + log_E = random.randint(3, 5) + + use_cache = False + # cache_algorithm is don't care as we don't use cache. + cache_algorithm = CacheAlgorithm.LRU + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + ] + ) + mixed = False + mixed_B = False + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + SparseType.FP32, + False, # use_experimental_tbe + ) + + def test_forward_cpu_fp32( + self, + ) -> None: + weights_precision = SparseType.FP32 + use_cpu = True + T = random.randint(1, 10) + D = random.randint(2, min(256, int(2048 / T))) + B = random.randint(1, min(128, int(2048 / T / D))) + L = random.randint(0, min(20, int(2048 / T / D / B))) + log_E = random.randint(3, 5) + + use_cache = False + # cache_algorithm is don't care as we don't use cache. + cache_algorithm = CacheAlgorithm.LRU + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + ] + ) + mixed = False + mixed_B = False + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + SparseType.FP32, + False, # use_experimental_tbe + ) + + @unittest.skipIf(*gpu_unavailable) + def test_forward_gpu_no_cache_int8( + self, + ) -> None: + weights_precision = SparseType.INT8 + use_cpu = False + T = random.randint(1, 10) + D = random.randint(2, 256) + B = random.randint(1, 128) + L = random.randint(0, 20) + log_E = random.randint(3, 5) + + use_cache = False + # cache_algorithm is don't care as we don't use cache. 
+ cache_algorithm = CacheAlgorithm.LRU + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ) + if pooling_mode == PoolingMode.NONE: + mixed = False + else: + mixed = random.choice([True, False]) + mixed_B = False + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + SparseType.FP32, + False, # use_experimental_tbe + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_forward_gpu_no_cache_fp16( + self, + use_experimental_tbe: bool, + ) -> None: + weights_precision = SparseType.FP16 + use_cpu = False + T = random.randint(1, 10) + D = random.randint(2, 256) + B = random.randint(1, 128) + L = random.randint(0, 20) + log_E = random.randint(3, 5) + + use_cache = False + # cache_algorithm is don't care as we don't use cache. + cache_algorithm = CacheAlgorithm.LRU + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + ] + + ([PoolingMode.NONE] if not use_experimental_tbe else []) + ) + if pooling_mode == PoolingMode.NONE: + mixed = False + mixed_B = False + else: + mixed = random.choice([True, False]) + mixed_B = ( + random.choice([True, False]) if not use_experimental_tbe else False + ) + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + SparseType.FP32, + use_experimental_tbe, + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_forward_gpu_no_cache_fp32( + self, + use_experimental_tbe: bool, + ) -> None: + weights_precision = SparseType.FP32 + use_cpu = False + T = random.randint(1, 10) + D = random.randint(2, 256) + B = random.randint(1, 128) + L = random.randint(0, 20) + log_E = random.randint(3, 5) + + use_cache = False + # cache_algorithm is don't care as we don't use cache. 
+ cache_algorithm = CacheAlgorithm.LRU + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + ] + + ([PoolingMode.NONE] if not use_experimental_tbe else []) + ) + if pooling_mode == PoolingMode.NONE: + mixed = False + mixed_B = False + else: + mixed = random.choice([True, False]) + mixed_B = ( + random.choice([True, False]) if not use_experimental_tbe else False + ) + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + SparseType.FP32, + use_experimental_tbe, + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + cache_algorithm=st.sampled_from(CacheAlgorithm), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_forward_gpu_uvm_cache_int8( + self, + cache_algorithm: CacheAlgorithm, + ) -> None: + weights_precision = SparseType.INT8 + use_cpu = False + T = random.randint(1, 10) + D = random.randint(2, 256) + B = random.randint(1, 128) + L = random.randint(0, 20) + log_E = random.randint(3, 5) + + use_cache = True + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ) + output_dtype = random.choice( + [ + SparseType.FP32, + SparseType.FP16, + ] + ) + if pooling_mode == PoolingMode.NONE: + mixed = False + else: + mixed = random.choice([True, False]) + mixed_B = False + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + output_dtype, + False, # use_experimental_tbe + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_forward_gpu_uvm_cache_fp16( + self, + cache_algorithm: CacheAlgorithm, + use_experimental_tbe: bool, + ) -> None: + weights_precision = SparseType.FP16 + use_cpu = False + T = random.randint(1, 10) + D = random.randint(2, 256) + B = random.randint(1, 128) + L = random.randint(0, 20) + log_E = random.randint(3, 5) + + use_cache = True + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + ] + + ([PoolingMode.NONE] if not use_experimental_tbe else []) + ) + output_dtype = random.choice( + [ + SparseType.FP32, + SparseType.FP16, + ] + ) + if pooling_mode == PoolingMode.NONE: + mixed = False + else: + mixed = random.choice([True, False]) + mixed_B = False + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + output_dtype, + use_experimental_tbe, + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + cache_algorithm=st.sampled_from(CacheAlgorithm), + use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + ) + @settings( + verbosity=VERBOSITY, + 
max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + ) + def test_forward_gpu_uvm_cache_fp32( + self, + cache_algorithm: CacheAlgorithm, + use_experimental_tbe: bool, + ) -> None: + weights_precision = SparseType.FP32 + use_cpu = False + T = random.randint(1, 10) + D = random.randint(2, 256) + B = random.randint(1, 128) + L = random.randint(0, 20) + log_E = random.randint(3, 5) + + use_cache = True + + pooling_mode = random.choice( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + ] + + ([PoolingMode.NONE] if not use_experimental_tbe else []) + ) + output_dtype = random.choice( + [ + SparseType.FP32, + SparseType.FP16, + ] + ) + if pooling_mode == PoolingMode.NONE: + mixed = False + else: + mixed = random.choice([True, False]) + mixed_B = False + if pooling_mode == PoolingMode.SUM: + weighted = random.choice([True, False]) + else: + weighted = False + self.execute_forward_( + T, + D, + B, + log_E, + L, + weights_precision, + weighted, + mixed, + mixed_B, + use_cache, + cache_algorithm, + pooling_mode, + use_cpu, + output_dtype, + use_experimental_tbe, + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=10), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + output_dtype=st.sampled_from([SparseType.FP16, SparseType.INT8]), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + suppress_health_check=[HealthCheck.filter_too_much], + ) + def test_forward_fused_pooled_emb_quant( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + output_dtype: SparseType, + ) -> None: + Ds = [ + round_up(np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)), 4) + for _ in range(T) + ] + E = int(10**log_E) + Es = [np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)] + + op = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + E, + D, + EmbeddingLocation.DEVICE, + ComputeDevice.CUDA, + ) + for (E, D) in zip(Es, Ds) + ], + output_dtype=output_dtype, + device=torch.cuda.current_device(), + ) + op_ref = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + E, + D, + EmbeddingLocation.DEVICE, + ComputeDevice.CUDA, + ) + for (E, D) in zip(Es, Ds) + ], + output_dtype=SparseType.FP32, + device=torch.cuda.current_device(), + ) + # sync weights between two ops + split_weights = op.split_embedding_weights() + ref_split_weights = op_ref.split_embedding_weights() + for t in range(T): + split_weights[t].data.copy_(ref_split_weights[t]) + + requests = generate_requests(2, B, T, L, min(Es), reuse=0.1) + + for indices, offsets, _ in requests: + lowp_pooled_output = op( + indices=indices, + offsets=offsets, + ) + fp32_pooled_output = op_ref( + indices=indices, + offsets=offsets, + ) + lowp_pooled_emb_split = [ + d + 8 if output_dtype == SparseType.INT8 else d for d in op.dims + ] + lowp_pooled_output_per_table = torch.split( + lowp_pooled_output, lowp_pooled_emb_split, dim=1 + ) + deq_lowp_pooled_output_per_table = [ + torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(t.contiguous()) + if output_dtype == SparseType.INT8 + else t.float() + for t in lowp_pooled_output_per_table + ] + fp32_pooled_output_per_table = torch.split( + fp32_pooled_output, op.dims, dim=1 + ) + dq_fp32_pooled_output_per_table = [ + torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat( + 
torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized( + t.contiguous() + ).contiguous() + ) + if output_dtype == SparseType.INT8 + else t.half().float() + for t in fp32_pooled_output_per_table + ] + cat_deq_lowp_pooled_output = torch.cat( + deq_lowp_pooled_output_per_table, dim=1 + ) + cat_dq_fp32_pooled_output = torch.cat( + dq_fp32_pooled_output_per_table, dim=1 + ) + torch.testing.assert_close( + cat_deq_lowp_pooled_output, cat_dq_fp32_pooled_output + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/tbe/inference_converter_test.py similarity index 100% rename from fbgemm_gpu/test/split_embedding_inference_converter_test.py rename to fbgemm_gpu/test/tbe/inference_converter_test.py diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/tbe/split_table_batched_embeddings_test.py similarity index 78% rename from fbgemm_gpu/test/split_table_batched_embeddings_test.py rename to fbgemm_gpu/test/tbe/split_table_batched_embeddings_test.py index 52a3b555c..434ab28d6 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/split_table_batched_embeddings_test.py @@ -183,6 +183,31 @@ def format_ref_tensors_in_mixed_B_layout( "test_faketensor__test_nbit_direct_mapped_uvm_cache_stats": [ unittest.skip("very slow"), ], + # Implement the operator registrations later + "test_faketensor__test_forward_cpu_int8": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_forward_fused_pooled_emb_quant": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_forward_gpu_no_cache_int8": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_forward_gpu_uvm_cache_int8": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_nbit_forward_cpu": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_nbit_forward_fused_pooled_emb_quant": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_nbit_forward_gpu_no_cache": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], + "test_faketensor__test_nbit_forward_gpu_no_cache_fp8_2048": [ + unittest.skip("Operator not implemented for Meta tensors"), + ], } @@ -190,1446 +215,931 @@ def format_ref_tensors_in_mixed_B_layout( class SplitTableBatchedEmbeddingsTest(unittest.TestCase): _do_cuda_memory_leak_check = True - def execute_forward_( # noqa C901 + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=10), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + weights_ty=st.sampled_from( + [ + SparseType.FP32, + SparseType.FP16, + SparseType.INT8, + SparseType.INT4, + # FIXME: INT2 caused big numerical error for this test + # SparseType.INT2, + ] + ), + output_dtype=st.sampled_from( + [ + SparseType.FP16, + SparseType.BF16, + SparseType.INT8, + # SparseType.INT4, + ] + ) + if not TEST_WITH_ROCM + else st.sampled_from( + [ + SparseType.FP16, + # The counterparts of __nv_bfloat16 and __nv_bfloat162 are not supported on ROCm + SparseType.INT8, + # SparseType.INT4, + ] + ), + ) + @settings( + verbosity=VERBOSITY, + max_examples=MAX_EXAMPLES_LONG_RUNNING, + deadline=None, + 
suppress_health_check=[HealthCheck.filter_too_much], + ) + def test_nbit_forward_fused_pooled_emb_quant( self, T: int, D: int, B: int, log_E: int, L: int, - weights_precision: SparseType, - weighted: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - pooling_mode: PoolingMode, - use_cpu: bool, + weights_ty: SparseType, output_dtype: SparseType, - use_experimental_tbe: bool, ) -> None: - # NOTE: cache is not applicable to CPU version. - assume(not use_cpu or not use_cache) - # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! - assume(not use_cpu or T * B * L * D <= 2048) - # NOTE: CPU does not support FP16. - assume(not (use_cpu and weights_precision == SparseType.FP16)) - - # NOTE: weighted operation can be done only for SUM. - assume(pooling_mode == PoolingMode.SUM or not weighted) - # NOTE: No bag ops only work on GPUs, no mixed - assume(not use_cpu or pooling_mode != PoolingMode.NONE) - assume(not mixed or pooling_mode != PoolingMode.NONE) - # TODO: Support these cases + D_alignment = max(weights_ty.align_size() for t in range(T)) + D_alignment = max(D_alignment, output_dtype.align_size()) + D = round_up(D, D_alignment) + # BF16 output only works for CUDA device sm80+ (e.g., A100) assume( - not mixed_B - or ( - weights_precision != SparseType.INT8 - and output_dtype != SparseType.INT8 - and not use_cpu - and not use_cache - and pooling_mode != PoolingMode.NONE - ) + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) + or not output_dtype == SparseType.BF16 ) - - emb_op = SplitTableBatchedEmbeddingBagsCodegen - if pooling_mode == PoolingMode.SUM: - mode = "sum" - do_pooling = True - elif pooling_mode == PoolingMode.MEAN: - mode = "mean" - do_pooling = True - elif pooling_mode == PoolingMode.NONE: - mode = "sum" - do_pooling = False - else: - # This proves that we have exhaustively checked all PoolingModes - raise RuntimeError("Unknown PoolingMode!") - - E = int(10**log_E) - if use_cpu: - D = (D + 15) // 16 * 4 - else: - D = D * 4 - if not mixed: - Ds = [D] * T - Es = [E] * T - else: - Ds = [ - round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) - for _ in range(T) - ] - Es = [ - np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) - ] - - if not mixed_B: - Bs = [B] * T - Bs_rank_feature = [[0]] - else: - Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, T) - - compute_device = ComputeDevice.CUDA - if use_cpu: - managed = [EmbeddingLocation.HOST] * T - compute_device = ComputeDevice.CPU - elif TEST_WITH_ROCM: - # ROCm managed memory allocation is under development - managed = [EmbeddingLocation.DEVICE] * T - elif use_cache: - managed = [EmbeddingLocation.MANAGED_CACHING] * T - if mixed: - average_D = sum(Ds) // T - for t, d in enumerate(Ds): - managed[t] = ( - EmbeddingLocation.DEVICE if d < average_D else managed[t] - ) - else: - managed = [ - np.random.choice( - [ - EmbeddingLocation.DEVICE, - EmbeddingLocation.MANAGED, - ] - ) - for _ in range(T) - ] - if do_pooling: - bs = [ - to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] - else: - bs = [ - to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] - if weights_precision == SparseType.INT8: - for t in range(T): - bs[t].weight.data.copy_( - torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat( - torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized( - bs[t].weight.data - ) - ) - ) - - if weights_precision == SparseType.FP16: - bs = 
[b.half() for b in bs] - - # Generate indices - xs = [ - to_device(torch.randint(low=0, high=e, size=(b, L)), use_cpu) - for e, b in zip(Es, Bs) + Ds = [ + round_up( + np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)), + D_alignment, + ) + for _ in range(T) ] - # Generate positional weights - xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] - xws_acc_type = copy.deepcopy(xws) - - if weights_precision == SparseType.FP16: - xws = [xw.half() for xw in xws] + Ds = [D] * T + E = int(10**log_E) + Es = [np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)] - # Run baseline - fs = ( - [ - b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) - for (b, x) in zip(bs, xs) - ] - if not weighted - else [ - b_indices( - b, - x, - per_sample_weights=xw.view(-1), - use_cpu=use_cpu, - do_pooling=do_pooling, + weights_ty_list = [weights_ty] * T + managed = [EmbeddingLocation.DEVICE] * T + op = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + "", + E, + D, + W_TY, + EmbeddingLocation(M), ) - for (b, x, xw) in zip(bs, xs, xws) - ] + for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) + ], + output_dtype=output_dtype, + device=torch.cuda.current_device(), ) + # Initialize the random weights for int nbit table split embedding bag + op.fill_random_weights() - if do_pooling: - if mixed_B: - f = format_ref_tensors_in_mixed_B_layout(fs, Bs_rank_feature) - else: - f = torch.cat([f.view(B, -1) for f in fs], dim=1) - else: - f = torch.cat(fs, dim=0).view(-1, D) - - # Create a TBE op - cc = emb_op( + op_ref = IntNBitTableBatchedEmbeddingBagsCodegen( embedding_specs=[ ( + "", E, D, + W_TY, EmbeddingLocation(M), - compute_device, ) - for (E, D, M) in zip(Es, Ds, managed) + for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) ], - weights_precision=weights_precision, - optimizer=OptimType.EXACT_ROWWISE_ADAGRAD - if mixed_B - else OptimType.EXACT_SGD, - learning_rate=0.05, - cache_algorithm=cache_algorithm, - pooling_mode=pooling_mode, - output_dtype=output_dtype, - use_experimental_tbe=use_experimental_tbe, + output_dtype=SparseType.FP32, + device=torch.cuda.current_device(), ) - # NOTE: test TorchScript-compatible! 
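+ # op_ref holds the same quantized tables but emits FP32 output and serves
+ # as the numerical reference for the low-precision op above.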
- cc = torch.jit.script(cc) + # Initialize the random weights for int nbit table split embedding bag + op_ref.fill_random_weights() + # sync weights between two ops + split_weights = op.split_embedding_weights() + ref_split_weights = op_ref.split_embedding_weights() for t in range(T): - cc.split_embedding_weights()[t].data.copy_( - bs[t].weight - if weights_precision != SparseType.INT8 - else torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(bs[t].weight) + (weights, scale_shift) = split_weights[t] + (ref_weights, ref_scale_shift) = ref_split_weights[t] + self.assertEqual(weights.size(), ref_weights.size()) + element_size = weights_ty_list[t].bit_rate() / 8.0 + rand_tensor = torch.rand( + ref_weights.shape[0], int(ref_weights.shape[1] / element_size) + ) + rand_weights, rand_scale_shift = quantize_embs( + rand_tensor, weights_ty_list[t] ) + ref_weights.copy_(rand_weights) + weights.copy_(ref_weights) + if rand_scale_shift is not None: + self.assertIsNotNone(scale_shift) + self.assertIsNotNone(ref_scale_shift) + ref_scale_shift.copy_(rand_scale_shift) + scale_shift.copy_(ref_scale_shift) - x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) - xw = torch.cat([xw.contiguous().flatten() for xw in xws_acc_type], dim=0) - - (indices, offsets) = get_table_batched_offsets_from_dense( - x, L, sum(Bs), use_cpu - ) - - batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None - - # Run TBE - fc2 = ( - cc( - indices, - offsets, - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + requests = generate_requests(1, B, T, L, min(Es), reuse=0.1) + for indices, offsets, _ in requests: + lowp_pooled_output = op( + indices=indices.int(), + offsets=offsets.int(), ) - if not weighted - else cc( - indices, - offsets, - to_device(xw.contiguous().view(-1), use_cpu), - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + fp32_pooled_output = op_ref( + indices=indices.int(), + offsets=offsets.int(), + ) + lowp_pooled_emb_split = [ + d + 8 if output_dtype == SparseType.INT8 else d for d in Ds + ] + lowp_pooled_output_per_table = torch.split( + lowp_pooled_output, lowp_pooled_emb_split, dim=1 + ) + deq_lowp_pooled_output_per_table = [ + torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(t.contiguous()) + if output_dtype == SparseType.INT8 + else t.float() + for t in lowp_pooled_output_per_table + ] + fp32_pooled_output_per_table = torch.split(fp32_pooled_output, Ds, dim=1) + dq_fp32_pooled_output_per_table = [ + torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat( + torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized( + t.contiguous() + ).contiguous() + ).contiguous() + if output_dtype == SparseType.INT8 + else t.half().float() + for t in fp32_pooled_output_per_table + ] + cat_deq_lowp_pooled_output = torch.cat( + deq_lowp_pooled_output_per_table, dim=1 + ) + cat_dq_fp32_pooled_output = torch.cat( + dq_fp32_pooled_output_per_table, dim=1 + ) + torch.testing.assert_close( + cat_deq_lowp_pooled_output, + cat_dq_fp32_pooled_output, + rtol=1e-2, + atol=1e-2, + equal_nan=True, ) - ) - - # Compare results: f = baseline, fc2 = TBE - tolerance = ( - 1.0e-5 - if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32 - else 8.0e-3 - ) - torch.testing.assert_close( - fc2.float(), f.float(), atol=tolerance, rtol=tolerance - ) - - def test_forward_cpu_int8( - self, - ) -> None: - weights_precision = SparseType.INT8 - use_cpu = True - T = random.randint(1, 10) - D = random.randint(2, min(256, int(2048 / T))) - B = random.randint(1, min(128, int(2048 / T / D))) - L = 
random.randint(0, min(20, int(2048 / T / D / B))) - log_E = random.randint(3, 5) - - use_cache = False - # cache_algorithm is don't care as we don't use cache. - cache_algorithm = CacheAlgorithm.LRU - pooling_mode = random.choice( + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=10), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=20), + weights_ty=st.sampled_from( [ - PoolingMode.SUM, - PoolingMode.MEAN, + SparseType.FP32, + SparseType.FP16, + SparseType.INT8, + SparseType.INT4, + SparseType.INT2, ] - ) - mixed = False - mixed_B = False - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - SparseType.FP32, - False, # use_experimental_tbe - ) - - def test_forward_cpu_fp32( - self, - ) -> None: - weights_precision = SparseType.FP32 - use_cpu = True - T = random.randint(1, 10) - D = random.randint(2, min(256, int(2048 / T))) - B = random.randint(1, min(128, int(2048 / T / D))) - L = random.randint(0, min(20, int(2048 / T / D / B))) - log_E = random.randint(3, 5) - - use_cache = False - # cache_algorithm is don't care as we don't use cache. - cache_algorithm = CacheAlgorithm.LRU - - pooling_mode = random.choice( + ), + output_dtype=st.sampled_from( [ - PoolingMode.SUM, - PoolingMode.MEAN, + SparseType.FP16, + SparseType.BF16, + SparseType.INT8, ] ) - mixed = False - mixed_B = False - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - SparseType.FP32, - False, # use_experimental_tbe - ) - - @unittest.skipIf(*gpu_unavailable) - def test_forward_gpu_no_cache_int8( - self, - ) -> None: - weights_precision = SparseType.INT8 - use_cpu = False - T = random.randint(1, 10) - D = random.randint(2, 256) - B = random.randint(1, 128) - L = random.randint(0, 20) - log_E = random.randint(3, 5) - - use_cache = False - # cache_algorithm is don't care as we don't use cache. 
- cache_algorithm = CacheAlgorithm.LRU - - pooling_mode = random.choice( + if not TEST_WITH_ROCM + else st.sampled_from( [ - PoolingMode.SUM, - PoolingMode.MEAN, - PoolingMode.NONE, + SparseType.FP16, + # The counterparts of __nv_bfloat16 and __nv_bfloat162 are not supported on ROCm + SparseType.INT8, ] - ) - if pooling_mode == PoolingMode.NONE: - mixed = False - else: - mixed = random.choice([True, False]) - mixed_B = False - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - SparseType.FP32, - False, # use_experimental_tbe - ) - - @unittest.skipIf(*gpu_unavailable) - @given( - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + ), ) @settings( verbosity=VERBOSITY, max_examples=MAX_EXAMPLES_LONG_RUNNING, deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], + suppress_health_check=[HealthCheck.filter_too_much], ) - def test_forward_gpu_no_cache_fp16( + def test_nbit_split_embedding_weights_with_scale_and_bias( self, - use_experimental_tbe: bool, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_ty: SparseType, + output_dtype: SparseType, ) -> None: - weights_precision = SparseType.FP16 - use_cpu = False - T = random.randint(1, 10) - D = random.randint(2, 256) - B = random.randint(1, 128) - L = random.randint(0, 20) - log_E = random.randint(3, 5) - - use_cache = False - # cache_algorithm is don't care as we don't use cache. - cache_algorithm = CacheAlgorithm.LRU - - pooling_mode = random.choice( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - ] - + ([PoolingMode.NONE] if not use_experimental_tbe else []) + D_alignment = max(weights_ty.align_size() for t in range(T)) + D_alignment = max(D_alignment, output_dtype.align_size()) + D = round_up(D, D_alignment) + # BF16 output only works for CUDA device sm80+ (e.g., A100) + assume( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) + or not output_dtype == SparseType.BF16 ) - if pooling_mode == PoolingMode.NONE: - mixed = False - mixed_B = False - else: - mixed = random.choice([True, False]) - mixed_B = ( - random.choice([True, False]) if not use_experimental_tbe else False + Ds = [ + round_up( + np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)), + D_alignment, ) - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - SparseType.FP32, - use_experimental_tbe, - ) - - @unittest.skipIf(*gpu_unavailable) - @given( - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_forward_gpu_no_cache_fp32( - self, - use_experimental_tbe: bool, - ) -> None: - weights_precision = SparseType.FP32 - use_cpu = False - T = random.randint(1, 10) - D = random.randint(2, 256) - B = random.randint(1, 128) - L = random.randint(0, 20) - log_E = random.randint(3, 5) - - use_cache = False - # cache_algorithm is don't care as we don't use cache. 
- cache_algorithm = CacheAlgorithm.LRU + for _ in range(T) + ] + Ds = [D] * T + E = int(10**log_E) + Es = [np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)] - pooling_mode = random.choice( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - ] - + ([PoolingMode.NONE] if not use_experimental_tbe else []) + weights_ty_list = [weights_ty] * T + managed = [EmbeddingLocation.DEVICE] * T + op = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_specs=[ + ( + "", + E, + D, + W_TY, + EmbeddingLocation(M), + ) + for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) + ], + output_dtype=output_dtype, + device=torch.cuda.current_device(), ) - if pooling_mode == PoolingMode.NONE: - mixed = False - mixed_B = False - else: - mixed = random.choice([True, False]) - mixed_B = ( - random.choice([True, False]) if not use_experimental_tbe else False - ) - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - SparseType.FP32, - use_experimental_tbe, + # Initialize the random weights for int nbit table split embedding bag + op.fill_random_weights() + + # sync weights between two ops + split_weights = op.split_embedding_weights() + split_weights_with_scale_bias = op.split_embedding_weights_with_scale_bias( + split_scale_bias_mode=2 ) + for t in range(T): + (weights, scale_bias) = split_weights[t] + (weights2, scale, bias) = split_weights_with_scale_bias[t] + torch.testing.assert_close(weights2, weights) + if scale is None: + self.assertIsNone(scale_bias) + self.assertIsNone(bias) + else: + torch.testing.assert_close( + scale.cpu(), + torch.tensor( + scale_bias[:, : scale_bias.size(1) // 2] + .contiguous() + .cpu() + .numpy() + .view(np.float16) + ), + ) + torch.testing.assert_close( + bias.cpu(), + torch.tensor( + scale_bias[:, scale_bias.size(1) // 2 :] + .contiguous() + .cpu() + .numpy() + .view(np.float16) + ), + ) - @unittest.skipIf(*gpu_unavailable) @given( - cache_algorithm=st.sampled_from(CacheAlgorithm), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_forward_gpu_uvm_cache_int8( - self, - cache_algorithm: CacheAlgorithm, - ) -> None: - weights_precision = SparseType.INT8 - use_cpu = False - T = random.randint(1, 10) - D = random.randint(2, 256) - B = random.randint(1, 128) - L = random.randint(0, 20) - log_E = random.randint(3, 5) - - use_cache = True - - pooling_mode = random.choice( + T=st.integers(min_value=1, max_value=3), + D=st.integers(min_value=2, max_value=128), + B=st.integers(min_value=1, max_value=32), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=0, max_value=10), + weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), + weighted=st.booleans(), + mixed=st.booleans(), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( [ PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE, ] - ) - output_dtype = random.choice( - [ - SparseType.FP32, - SparseType.FP16, - ] - ) - if pooling_mode == PoolingMode.NONE: - mixed = False - else: - mixed = random.choice([True, False]) - mixed_B = False - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) - else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - 
weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - output_dtype, - False, # use_experimental_tbe - ) - - @unittest.skipIf(*gpu_unavailable) - @given( - cache_algorithm=st.sampled_from(CacheAlgorithm), - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), + ), + use_cpu=st.booleans() + if (gpu_available and not TEST_WITH_ROCM) + else st.just(False) + if (gpu_available and TEST_WITH_ROCM) + else st.just(True), + output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, + max_examples=10, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], ) - def test_forward_gpu_uvm_cache_fp16( + def test_backward_dense( # noqa C901 self, - cache_algorithm: CacheAlgorithm, - use_experimental_tbe: bool, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + mixed: bool, + long_segments: bool, + pooling_mode: PoolingMode, + use_cpu: bool, + output_dtype: SparseType, ) -> None: - weights_precision = SparseType.FP16 - use_cpu = False - T = random.randint(1, 10) - D = random.randint(2, 256) - B = random.randint(1, 128) - L = random.randint(0, 20) - log_E = random.randint(3, 5) + # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version + # so we have to limit (T * B * L * D)! + assume(not use_cpu or T * B * L * D <= 2048) + assume(pooling_mode == PoolingMode.SUM or not weighted) + assume(not (use_cpu and weights_precision == SparseType.FP16)) + # No bag ops only work on GPUs, no mixed, no weighted + assume(not use_cpu or pooling_mode != PoolingMode.NONE) + assume(not mixed or pooling_mode != PoolingMode.NONE) + assume(not weighted or pooling_mode != PoolingMode.NONE) - use_cache = True + emb_op = DenseTableBatchedEmbeddingBagsCodegen + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") - pooling_mode = random.choice( - [ - PoolingMode.SUM, - PoolingMode.MEAN, + E = int(10**log_E) + if use_cpu: + D = (D + 15) // 16 * 4 + else: + D = D * 4 + if not mixed: + Ds = [D] * T + Es = [E] * T + else: + Ds = [ + round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) + for _ in range(T) ] - + ([PoolingMode.NONE] if not use_experimental_tbe else []) - ) - output_dtype = random.choice( - [ - SparseType.FP32, - SparseType.FP16, + Es = [ + np.random.randint(low=int(0.5 * E), high=int(2 * E)) for _ in range(T) + ] + if do_pooling: + bs = [ + to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=False), use_cpu) + for (E, D) in zip(Es, Ds) ] - ) - if pooling_mode == PoolingMode.NONE: - mixed = False - else: - mixed = random.choice([True, False]) - mixed_B = False - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, - use_cpu, - output_dtype, - use_experimental_tbe, - ) + bs = [ + to_device(torch.nn.Embedding(E, D, sparse=False), use_cpu) + for (E, D) in zip(Es, Ds) + ] - @unittest.skipIf(*gpu_unavailable) - @given( - 
cache_algorithm=st.sampled_from(CacheAlgorithm), - use_experimental_tbe=st.booleans() if not TEST_WITH_ROCM else st.just(False), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_forward_gpu_uvm_cache_fp32( - self, - cache_algorithm: CacheAlgorithm, - use_experimental_tbe: bool, - ) -> None: - weights_precision = SparseType.FP32 - use_cpu = False - T = random.randint(1, 10) - D = random.randint(2, 256) - B = random.randint(1, 128) - L = random.randint(0, 20) - log_E = random.randint(3, 5) + if weights_precision == SparseType.FP16: + bs = [b.half() for b in bs] + + xs = [ + to_device( + torch.from_numpy( + np.random.choice(range(e), size=(B, L), replace=True).astype( + np.int64 + ) + ), + use_cpu, + ) + for e in Es + ] + if long_segments and L > 0 and weights_precision != SparseType.FP16: + for x in xs: + x[:, 0] = 0 - use_cache = True + xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)] + xws_acc_type = copy.deepcopy(xws) - pooling_mode = random.choice( + if weights_precision == SparseType.FP16: + xws = [xw.half() for xw in xws] + + fs = ( [ - PoolingMode.SUM, - PoolingMode.MEAN, + b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) + for (b, x) in zip(bs, xs) ] - + ([PoolingMode.NONE] if not use_experimental_tbe else []) - ) - output_dtype = random.choice( - [ - SparseType.FP32, - SparseType.FP16, + if not weighted + else [ + b_indices( + b, + x, + per_sample_weights=xw.view(-1), + use_cpu=use_cpu, + do_pooling=do_pooling, + ) + for (b, x, xw) in zip(bs, xs, xws) ] ) - if pooling_mode == PoolingMode.NONE: - mixed = False + gos = [torch.randn_like(f) for f in fs] + [f.backward(go) for (f, go) in zip(fs, gos)] + + # pyre-fixme[16]: `Optional` has no attribute `view`. + grad_weights = torch.cat([b.weight.grad.view(-1) for b in bs]) + if weights_precision == SparseType.FP16 and not use_cpu: + grad_weights = grad_weights.half() + + cc = emb_op( + embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)], + pooling_mode=pooling_mode, + use_cpu=use_cpu, + weights_precision=weights_precision, + output_dtype=output_dtype, + ) + if do_pooling: + # NOTE: test TorchScript-compatible! 
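+ # Scripting the module verifies that the pooled dense TBE op stays
+ # TorchScript-compatible; torch.jit.script raises if it is not.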
+ cc = torch.jit.script(cc) + + for t in range(T): + cc.split_embedding_weights()[t].data.copy_(bs[t].weight) + + x = torch.cat([x.view(1, B, L) for x in xs], dim=0) + xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0) + + (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=use_cpu) + fc2 = ( + cc(indices, offsets) + if not weighted + else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)) + ) + + if do_pooling: + f = torch.cat([f.view(B, -1) for f in fs], dim=1) else: - mixed = random.choice([True, False]) - mixed_B = False - if pooling_mode == PoolingMode.SUM: - weighted = random.choice([True, False]) + f = torch.cat(fs, dim=0).view(-1, D) + + torch.testing.assert_close( + fc2.float(), + f.float(), + atol=5.0e-3 + if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 + else 1.0e-5, + rtol=5.0e-3 + if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 + else 1.0e-5, + ) + if do_pooling: + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) else: - weighted = False - self.execute_forward_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - pooling_mode, + goc = torch.cat(gos, dim=0) + fc2.backward(goc) + torch.testing.assert_close( + cc.weights.grad, + grad_weights, + atol=5.0e-3 + if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 + else 1.0e-4, + rtol=5.0e-3 + if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 + else 1.0e-4, + ) + + cc = DenseTableBatchedEmbeddingBagsCodegen( + [(E, D) for (E, D) in zip(Es, Ds)], + # NOTE: only SUM pooling can work with per_sample_weights! + pooling_mode=PoolingMode.SUM, + use_cpu=use_cpu, + ) + + per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) + per_sample_weights.requires_grad = True + indices.requires_grad = False + offsets.requires_grad = False + for param in cc.parameters(): + param.requires_grad = False + y = cc(indices, offsets, per_sample_weights) + y.sum().backward() + # pyre-fixme[16]: `Optional` has no attribute `clone`. 
+ indice_weight_grad_all = per_sample_weights.grad.clone().cpu() + T_ = len(xws) + feature_requires_grad = to_device( + torch.tensor(np.random.choice([0, 1], replace=True, size=(T_,))).int(), use_cpu, - output_dtype, - use_experimental_tbe, + ) + per_sample_weights = per_sample_weights.detach().clone() + per_sample_weights.requires_grad = True + y = cc( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ) + y.sum().backward() + indice_weight_grad_mask = per_sample_weights.grad.clone().cpu() + for t in range(T_): + if feature_requires_grad[t]: + torch.testing.assert_close( + indice_weight_grad_mask.view(T_, B, L)[t], + indice_weight_grad_all.view(T_, B, L)[t], + ) + else: + torch.testing.assert_close( + indice_weight_grad_mask.view(T_, B, L)[t], + torch.zeros_like(indice_weight_grad_mask.view(T_, B, L)[t]), + ) + + per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) + cc = cc.float() + per_sample_weights = per_sample_weights.float() + per_sample_weights.requires_grad = True + indices.requires_grad = False + offsets.requires_grad = False + for param in cc.parameters(): + param.requires_grad = False + gradcheck( + cc, (indices, offsets, per_sample_weights), eps=1e-2, atol=1e-3, rtol=1e-3 ) - @unittest.skipIf(*gpu_unavailable) @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), B=st.integers(min_value=1, max_value=128), log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), - output_dtype=st.sampled_from([SparseType.FP16, SparseType.INT8]), + weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), + weighted=st.booleans(), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ), + output_dtype=st.sampled_from([SparseType.FP16, SparseType.FP32]), ) @settings( verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, + max_examples=MAX_EXAMPLES, deadline=None, - suppress_health_check=[HealthCheck.filter_too_much], + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], ) - def test_forward_fused_pooled_emb_quant( - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - output_dtype: SparseType, - ) -> None: - Ds = [ - round_up(np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)), 4) - for _ in range(T) - ] - E = int(10**log_E) - Es = [np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)] + def test_backward_none(self, **kwargs: Any) -> None: + self.execute_backward_none_(**kwargs) - op = SplitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - E, - D, - EmbeddingLocation.DEVICE, - ComputeDevice.CUDA, - ) - for (E, D) in zip(Es, Ds) - ], - output_dtype=output_dtype, - device=torch.cuda.current_device(), - ) - op_ref = SplitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - E, - D, - EmbeddingLocation.DEVICE, - ComputeDevice.CUDA, - ) - for (E, D) in zip(Es, Ds) - ], - output_dtype=SparseType.FP32, - device=torch.cuda.current_device(), - ) - # sync weights between two ops - split_weights = op.split_embedding_weights() - ref_split_weights = op_ref.split_embedding_weights() - for t in range(T): - split_weights[t].data.copy_(ref_split_weights[t]) - - requests = generate_requests(2, B, T, L, min(Es), reuse=0.1) - - for indices, offsets, _ in requests: - lowp_pooled_output = op( - indices=indices, - 
offsets=offsets, - ) - fp32_pooled_output = op_ref( - indices=indices, - offsets=offsets, - ) - lowp_pooled_emb_split = [ - d + 8 if output_dtype == SparseType.INT8 else d for d in op.dims - ] - lowp_pooled_output_per_table = torch.split( - lowp_pooled_output, lowp_pooled_emb_split, dim=1 - ) - deq_lowp_pooled_output_per_table = [ - torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(t.contiguous()) - if output_dtype == SparseType.INT8 - else t.float() - for t in lowp_pooled_output_per_table - ] - fp32_pooled_output_per_table = torch.split( - fp32_pooled_output, op.dims, dim=1 - ) - dq_fp32_pooled_output_per_table = [ - torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat( - torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized( - t.contiguous() - ).contiguous() - ) - if output_dtype == SparseType.INT8 - else t.half().float() - for t in fp32_pooled_output_per_table - ] - cat_deq_lowp_pooled_output = torch.cat( - deq_lowp_pooled_output_per_table, dim=1 - ) - cat_dq_fp32_pooled_output = torch.cat( - dq_fp32_pooled_output_per_table, dim=1 - ) - torch.testing.assert_close( - cat_deq_lowp_pooled_output, cat_dq_fp32_pooled_output - ) - - @unittest.skipIf(*gpu_unavailable) @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), B=st.integers(min_value=1, max_value=128), log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), - weights_ty=st.sampled_from( - [ - SparseType.FP32, - SparseType.FP16, - SparseType.INT8, - SparseType.INT4, - # FIXME: INT2 caused big numerical error for this test - # SparseType.INT2, - ] - ), - output_dtype=st.sampled_from( - [ - SparseType.FP16, - SparseType.BF16, - SparseType.INT8, - # SparseType.INT4, - ] - ) - if not TEST_WITH_ROCM - else st.sampled_from( + weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), + weighted=st.booleans(), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( [ - SparseType.FP16, - # The counterparts of __nv_bfloat16 and __nv_bfloat162 are not supported on ROCm - SparseType.INT8, - # SparseType.INT4, + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, ] ), + output_dtype=st.sampled_from([SparseType.FP16, SparseType.FP32]), ) @settings( verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, + max_examples=MAX_EXAMPLES, deadline=None, - suppress_health_check=[HealthCheck.filter_too_much], + suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], ) - def test_nbit_forward_fused_pooled_emb_quant( + def test_backward_none_with_rowwise_adagrad(self, **kwargs: Any) -> None: + self.execute_backward_none_(optimizer=OptimType.EXACT_ROWWISE_ADAGRAD, **kwargs) + + def execute_backward_none_( # noqa C901 self, T: int, D: int, B: int, log_E: int, L: int, - weights_ty: SparseType, + weights_precision: SparseType, + weighted: bool, + long_segments: bool, + pooling_mode: PoolingMode, output_dtype: SparseType, + optimizer: Optional[OptimType] = None, ) -> None: - D_alignment = max(weights_ty.align_size() for t in range(T)) - D_alignment = max(D_alignment, output_dtype.align_size()) - D = round_up(D, D_alignment) - # BF16 output only works for CUDA device sm80+ (e.g., A100) - assume( - torch.cuda.is_available() - and torch.cuda.get_device_capability() >= (8, 0) - or not output_dtype == SparseType.BF16 - ) - Ds = [ - round_up( - np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)), - D_alignment, - ) - for _ in range(T) - ] - Ds = 
[D] * T - E = int(10**log_E) - Es = [np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)] + use_cpu = False + mixed = False + use_cache = False - weights_ty_list = [weights_ty] * T - managed = [EmbeddingLocation.DEVICE] * T - op = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - "", - E, - D, - W_TY, - EmbeddingLocation(M), - ) - for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) - ], - output_dtype=output_dtype, - device=torch.cuda.current_device(), - ) - # Initialize the random weights for int nbit table split embedding bag - op.fill_random_weights() + # NOTE: cache is not applicable to CPU version. + assume(not use_cpu or not use_cache) + # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! + assume(not use_cpu or T * B * L * D <= 2048) + assume(not (use_cpu and weights_precision == SparseType.FP16)) + # No bag ops only work on GPUs, no mixed, no weighted + assume(not use_cpu or pooling_mode != PoolingMode.NONE) + assume(not mixed or pooling_mode != PoolingMode.NONE) + assume(not weighted or pooling_mode != PoolingMode.NONE) - op_ref = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - "", - E, - D, - W_TY, - EmbeddingLocation(M), - ) - for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) - ], - output_dtype=SparseType.FP32, - device=torch.cuda.current_device(), - ) - # Initialize the random weights for int nbit table split embedding bag - op_ref.fill_random_weights() + assume(pooling_mode == PoolingMode.SUM or not weighted) - # sync weights between two ops - split_weights = op.split_embedding_weights() - ref_split_weights = op_ref.split_embedding_weights() - for t in range(T): - (weights, scale_shift) = split_weights[t] - (ref_weights, ref_scale_shift) = ref_split_weights[t] - self.assertEqual(weights.size(), ref_weights.size()) - element_size = weights_ty_list[t].bit_rate() / 8.0 - rand_tensor = torch.rand( - ref_weights.shape[0], int(ref_weights.shape[1] / element_size) - ) - rand_weights, rand_scale_shift = quantize_embs( - rand_tensor, weights_ty_list[t] - ) - ref_weights.copy_(rand_weights) - weights.copy_(ref_weights) - if rand_scale_shift is not None: - self.assertIsNotNone(scale_shift) - self.assertIsNotNone(ref_scale_shift) - ref_scale_shift.copy_(rand_scale_shift) - scale_shift.copy_(ref_scale_shift) + if pooling_mode == PoolingMode.SUM: + mode = "sum" + do_pooling = True + elif pooling_mode == PoolingMode.MEAN: + mode = "mean" + do_pooling = True + elif pooling_mode == PoolingMode.NONE: + mode = "sum" + do_pooling = False + else: + # This proves that we have exhaustively checked all PoolingModes + raise RuntimeError("Unknown PoolingMode!") - requests = generate_requests(1, B, T, L, min(Es), reuse=0.1) - for indices, offsets, _ in requests: - lowp_pooled_output = op( - indices=indices.int(), - offsets=offsets.int(), - ) - fp32_pooled_output = op_ref( - indices=indices.int(), - offsets=offsets.int(), - ) - lowp_pooled_emb_split = [ - d + 8 if output_dtype == SparseType.INT8 else d for d in Ds - ] - lowp_pooled_output_per_table = torch.split( - lowp_pooled_output, lowp_pooled_emb_split, dim=1 - ) - deq_lowp_pooled_output_per_table = [ - torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(t.contiguous()) - if output_dtype == SparseType.INT8 - else t.float() - for t in lowp_pooled_output_per_table + E = int(10**log_E) + if use_cpu: + D = (D + 15) // 16 * 4 + else: + D = D * 4 + if not mixed: + Ds = [D] * T + Es = [E] * T + else: + Ds = [ + round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * 
D)), 4) + for _ in range(T) ] - fp32_pooled_output_per_table = torch.split(fp32_pooled_output, Ds, dim=1) - dq_fp32_pooled_output_per_table = [ - torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat( - torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized( - t.contiguous() - ).contiguous() - ).contiguous() - if output_dtype == SparseType.INT8 - else t.half().float() - for t in fp32_pooled_output_per_table + Es = [ + np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) ] - cat_deq_lowp_pooled_output = torch.cat( - deq_lowp_pooled_output_per_table, dim=1 - ) - cat_dq_fp32_pooled_output = torch.cat( - dq_fp32_pooled_output_per_table, dim=1 - ) - torch.testing.assert_close( - cat_deq_lowp_pooled_output, - cat_dq_fp32_pooled_output, - rtol=1e-2, - atol=1e-2, - equal_nan=True, - ) - - @unittest.skipIf(*gpu_unavailable) - @given( - T=st.integers(min_value=1, max_value=10), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weights_ty=st.sampled_from( - [ - SparseType.FP32, - SparseType.FP16, - SparseType.INT8, - SparseType.INT4, - SparseType.INT2, - ] - ), - output_dtype=st.sampled_from( - [ - SparseType.FP16, - SparseType.BF16, - SparseType.INT8, - ] - ) - if not TEST_WITH_ROCM - else st.sampled_from( - [ - SparseType.FP16, - # The counterparts of __nv_bfloat16 and __nv_bfloat162 are not supported on ROCm - SparseType.INT8, - ] - ), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much], - ) - def test_nbit_split_embedding_weights_with_scale_and_bias( - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - weights_ty: SparseType, - output_dtype: SparseType, - ) -> None: - D_alignment = max(weights_ty.align_size() for t in range(T)) - D_alignment = max(D_alignment, output_dtype.align_size()) - D = round_up(D, D_alignment) - # BF16 output only works for CUDA device sm80+ (e.g., A100) - assume( - torch.cuda.is_available() - and torch.cuda.get_device_capability() >= (8, 0) - or not output_dtype == SparseType.BF16 - ) - Ds = [ - round_up( - np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)), - D_alignment, - ) - for _ in range(T) - ] - Ds = [D] * T - E = int(10**log_E) - Es = [np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T)] - - weights_ty_list = [weights_ty] * T - managed = [EmbeddingLocation.DEVICE] * T - op = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - "", - E, - D, - W_TY, - EmbeddingLocation(M), - ) - for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) - ], - output_dtype=output_dtype, - device=torch.cuda.current_device(), - ) - # Initialize the random weights for int nbit table split embedding bag - op.fill_random_weights() - - # sync weights between two ops - split_weights = op.split_embedding_weights() - split_weights_with_scale_bias = op.split_embedding_weights_with_scale_bias( - split_scale_bias_mode=2 - ) - for t in range(T): - (weights, scale_bias) = split_weights[t] - (weights2, scale, bias) = split_weights_with_scale_bias[t] - torch.testing.assert_close(weights2, weights) - if scale is None: - self.assertIsNone(scale_bias) - self.assertIsNone(bias) - else: - torch.testing.assert_close( - scale.cpu(), - torch.tensor( - scale_bias[:, : scale_bias.size(1) // 2] - .contiguous() - .cpu() - .numpy() - .view(np.float16) - ), - ) - torch.testing.assert_close( - bias.cpu(), - 
torch.tensor( - scale_bias[:, scale_bias.size(1) // 2 :] - .contiguous() - .cpu() - .numpy() - .view(np.float16) - ), - ) - - @given( - T=st.integers(min_value=1, max_value=3), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=32), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=10), - weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), - weighted=st.booleans(), - mixed=st.booleans(), - long_segments=st.booleans(), - pooling_mode=st.sampled_from( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - PoolingMode.NONE, - ] - ), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - ) - @settings( - verbosity=VERBOSITY, - max_examples=10, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_dense( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - weights_precision: SparseType, - weighted: bool, - mixed: bool, - long_segments: bool, - pooling_mode: PoolingMode, - use_cpu: bool, - output_dtype: SparseType, - ) -> None: - # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version - # so we have to limit (T * B * L * D)! - assume(not use_cpu or T * B * L * D <= 2048) - assume(pooling_mode == PoolingMode.SUM or not weighted) - assume(not (use_cpu and weights_precision == SparseType.FP16)) - # No bag ops only work on GPUs, no mixed, no weighted - assume(not use_cpu or pooling_mode != PoolingMode.NONE) - assume(not mixed or pooling_mode != PoolingMode.NONE) - assume(not weighted or pooling_mode != PoolingMode.NONE) - - emb_op = DenseTableBatchedEmbeddingBagsCodegen - if pooling_mode == PoolingMode.SUM: - mode = "sum" - do_pooling = True - elif pooling_mode == PoolingMode.MEAN: - mode = "mean" - do_pooling = True - elif pooling_mode == PoolingMode.NONE: - mode = "sum" - do_pooling = False - else: - # This proves that we have exhaustively checked all PoolingModes - raise RuntimeError("Unknown PoolingMode!") - - E = int(10**log_E) + compute_device = ComputeDevice.CUDA if use_cpu: - D = (D + 15) // 16 * 4 - else: - D = D * 4 - if not mixed: - Ds = [D] * T - Es = [E] * T + managed = [EmbeddingLocation.HOST] * T + compute_device = ComputeDevice.CPU + elif TEST_WITH_ROCM: + # ROCm managed memory allocation is under development + managed = [EmbeddingLocation.DEVICE] * T + elif use_cache: + managed = [EmbeddingLocation.MANAGED_CACHING] * T + if mixed: + average_D = sum(Ds) // T + for t, d in enumerate(Ds): + managed[t] = ( + EmbeddingLocation.DEVICE if d < average_D else managed[t] + ) else: - Ds = [ - round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) + managed = [ + np.random.choice( + [ + EmbeddingLocation.DEVICE, + ] + ) for _ in range(T) ] - Es = [ - np.random.randint(low=int(0.5 * E), high=int(2 * E)) for _ in range(T) - ] if do_pooling: bs = [ - to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=False), use_cpu) + to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) for (E, D) in zip(Es, Ds) ] else: bs = [ - to_device(torch.nn.Embedding(E, D, sparse=False), use_cpu) + to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) for (E, D) in zip(Es, Ds) ] if weights_precision == SparseType.FP16: bs = [b.half() for b in bs] + feature_table_map = list(range(T)) xs = [ to_device( torch.from_numpy( - 
np.random.choice(range(e), size=(B, L), replace=True).astype( - np.int64 - ) + np.random.choice(range(Es[t]), size=(B, L)).astype(np.int64) ), use_cpu, ) - for e in Es + for t in feature_table_map ] - if long_segments and L > 0 and weights_precision != SparseType.FP16: + + if long_segments and L > 0: for x in xs: x[:, 0] = 0 - xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(T)] + xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(len(xs))] xws_acc_type = copy.deepcopy(xws) if weights_precision == SparseType.FP16: xws = [xw.half() for xw in xws] - fs = ( - [ - b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) - for (b, x) in zip(bs, xs) - ] - if not weighted - else [ - b_indices( - b, - x, - per_sample_weights=xw.view(-1), - use_cpu=use_cpu, - do_pooling=do_pooling, - ) - for (b, x, xw) in zip(bs, xs, xws) - ] - ) - gos = [torch.randn_like(f) for f in fs] - [f.backward(go) for (f, go) in zip(fs, gos)] + x = torch.cat([x.view(1, B, L) for x in xs], dim=0) + xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0) - # pyre-fixme[16]: `Optional` has no attribute `view`. - grad_weights = torch.cat([b.weight.grad.view(-1) for b in bs]) - if weights_precision == SparseType.FP16 and not use_cpu: - grad_weights = grad_weights.half() + (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=use_cpu) + embedding_specs = [ + (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) + ] - cc = emb_op( - embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)], - pooling_mode=pooling_mode, - use_cpu=use_cpu, + # Hyperparameters in case optimizer is not None + lr = 0.5 + eps = 0.2 + stochastic_rounding = random.choice([True, False]) + + if optimizer is None: + fs = ( + [ + b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) + for (b, x) in zip(bs, xs) + ] + if not weighted + else [ + b_indices( + b, + x, + per_sample_weights=xw.view(-1), + use_cpu=use_cpu, + do_pooling=do_pooling, + ) + for (b, x, xw) in zip(bs, xs, xws) + ] + ) + gos: Union[List[Tensor], Tensor] = [torch.randn_like(f) for f in fs] + [f.backward(go) for (f, go) in zip(fs, gos)] + else: + bs_ = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=embedding_specs, + optimizer=optimizer, + feature_table_map=feature_table_map, + weights_precision=weights_precision, + pooling_mode=pooling_mode, + output_dtype=output_dtype, + learning_rate=lr, + eps=eps, + stochastic_rounding=stochastic_rounding, + ) + + for t in range(T): + bs_.split_embedding_weights()[t].data.copy_(bs[t].weight) + + fs = ( + bs_(indices, offsets) + if not weighted + else bs_( + indices, + offsets, + to_device(xw.contiguous().view(-1), use_cpu), + ) + ) + gos: Union[List[Tensor], Tensor] = torch.rand_like(fs) + fs.backward(gos) + + cc = SplitTableBatchedEmbeddingBagsCodegen( + embedding_specs=embedding_specs, + optimizer=OptimType.NONE, + feature_table_map=feature_table_map, weights_precision=weights_precision, + pooling_mode=pooling_mode, output_dtype=output_dtype, ) - if do_pooling: - # NOTE: test TorchScript-compatible! 
- cc = torch.jit.script(cc) for t in range(T): cc.split_embedding_weights()[t].data.copy_(bs[t].weight) - x = torch.cat([x.view(1, B, L) for x in xs], dim=0) - xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0) + total_unique_indices = 0 + # Compute number of unique indices + for t in range(len(feature_table_map)): + start = offsets[t * B] + end = offsets[(t + 1) * B] + uniq_indices = indices[start:end].unique() + total_unique_indices += uniq_indices.numel() - (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=use_cpu) fc2 = ( - cc(indices, offsets) + cc(indices, offsets, total_unique_indices=total_unique_indices) if not weighted - else cc(indices, offsets, to_device(xw.contiguous().view(-1), use_cpu)) + else cc( + indices, + offsets, + to_device(xw.contiguous().view(-1), use_cpu), + total_unique_indices=total_unique_indices, + ) ) + if optimizer is None: + assert type(gos) is list + if do_pooling: + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) + else: + goc = torch.cat(gos, dim=0) + else: + assert type(gos) is Tensor + goc = gos.clone() + fc2.backward(goc) - if do_pooling: - f = torch.cat([f.view(B, -1) for f in fs], dim=1) + if optimizer is not None: + params = SplitEmbeddingOptimizerParams(weights_dev=cc.weights_dev) + embedding_args = SplitEmbeddingArgs( + weights_placements=cc.weights_placements, + weights_offsets=cc.weights_offsets, + max_D=cc.max_D, + ) + optim = SplitEmbeddingRowwiseAdagrad( + params, + embedding_args, + embedding_specs, + feature_table_map, + learning_rate=lr, + eps=eps, + stochastic_rounding=stochastic_rounding, + ) + optim.step() + + if use_cache: + cc.flush() + + if optimizer is None: + test_tensor = cc.weights_dev.grad + weight_grads = [] + for t in range(T): + grad = bs[t].weight.grad + # Check grad to suppress pyre error + assert grad is not None + weight_grads.append(grad) + ref_grad = torch.concat(weight_grads, dim=0).to_sparse().coalesce() + ref_tensor = ( + ref_grad.half() if weights_precision == SparseType.FP16 else ref_grad + ) else: - f = torch.cat(fs, dim=0).view(-1, D) + indices = cc.weights_dev.grad._indices().flatten() + # Select only the part in the table that is updated + test_tensor = torch.index_select(cc.weights_dev.view(-1, D), 0, indices) + ref_tensor = torch.index_select(bs_.weights_dev.view(-1, D), 0, indices) - torch.testing.assert_close( - fc2.float(), - f.float(), - atol=5.0e-3 - if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 - else 1.0e-5, - rtol=5.0e-3 - if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 - else 1.0e-5, + tolerance = ( + 1.0e-2 + if long_segments + else ( + 1.0e-4 + if weights_precision == SparseType.FP32 + and output_dtype == SparseType.FP32 + else 1.0e-2 + ) ) - if do_pooling: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1) - else: - goc = torch.cat(gos, dim=0) - fc2.backward(goc) torch.testing.assert_close( - cc.weights.grad, - grad_weights, - atol=5.0e-3 - if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 - else 1.0e-4, - rtol=5.0e-3 - if weights_precision == SparseType.FP16 or output_dtype == SparseType.FP16 - else 1.0e-4, + test_tensor, + ref_tensor, + atol=tolerance, + rtol=tolerance, ) - cc = DenseTableBatchedEmbeddingBagsCodegen( - [(E, D) for (E, D) in zip(Es, Ds)], - # NOTE: only SUM pooling can work with per_sample_weights! 
- pooling_mode=PoolingMode.SUM, - use_cpu=use_cpu, - ) + def execute_backward_sgd_( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + mixed: bool, + mixed_B: bool, + use_cache: bool, + cache_algorithm: CacheAlgorithm, + long_segments: bool, + pooling_mode: PoolingMode, + use_cpu: bool, + output_dtype: SparseType, + ) -> None: + # NOTE: cache is not applicable to CPU version. + assume(not use_cpu or not use_cache) + # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! + assume(not use_cpu or T * B * L * D <= 2048) + assume(not (use_cpu and weights_precision == SparseType.FP16)) + # No bag ops only work on GPUs, no mixed, no weighted + assume(not use_cpu or pooling_mode != PoolingMode.NONE) + assume(not mixed or pooling_mode != PoolingMode.NONE) + assume(not weighted or pooling_mode != PoolingMode.NONE) - per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) - per_sample_weights.requires_grad = True - indices.requires_grad = False - offsets.requires_grad = False - for param in cc.parameters(): - param.requires_grad = False - y = cc(indices, offsets, per_sample_weights) - y.sum().backward() - # pyre-fixme[16]: `Optional` has no attribute `clone`. - indice_weight_grad_all = per_sample_weights.grad.clone().cpu() - T_ = len(xws) - feature_requires_grad = to_device( - torch.tensor(np.random.choice([0, 1], replace=True, size=(T_,))).int(), - use_cpu, + assume(pooling_mode == PoolingMode.SUM or not weighted) + # TODO: Support these cases + assume( + not mixed_B + or ( + weights_precision != SparseType.INT8 + and output_dtype != SparseType.INT8 + and not use_cpu + and not use_cache + and pooling_mode != PoolingMode.NONE + ) ) - per_sample_weights = per_sample_weights.detach().clone() - per_sample_weights.requires_grad = True - y = cc( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ) - y.sum().backward() - indice_weight_grad_mask = per_sample_weights.grad.clone().cpu() - for t in range(T_): - if feature_requires_grad[t]: - torch.testing.assert_close( - indice_weight_grad_mask.view(T_, B, L)[t], - indice_weight_grad_all.view(T_, B, L)[t], - ) - else: - torch.testing.assert_close( - indice_weight_grad_mask.view(T_, B, L)[t], - torch.zeros_like(indice_weight_grad_mask.view(T_, B, L)[t]), - ) - - per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) - cc = cc.float() - per_sample_weights = per_sample_weights.float() - per_sample_weights.requires_grad = True - indices.requires_grad = False - offsets.requires_grad = False - for param in cc.parameters(): - param.requires_grad = False - gradcheck( - cc, (indices, offsets, per_sample_weights), eps=1e-2, atol=1e-3, rtol=1e-3 - ) - - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=256), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), - weighted=st.booleans(), - long_segments=st.booleans(), - pooling_mode=st.sampled_from( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - PoolingMode.NONE, - ] - ), - output_dtype=st.sampled_from([SparseType.FP16, SparseType.FP32]), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_none(self, **kwargs: Any) -> None: - 
self.execute_backward_none_(**kwargs) - - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=256), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), - weighted=st.booleans(), - long_segments=st.booleans(), - pooling_mode=st.sampled_from( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - PoolingMode.NONE, - ] - ), - output_dtype=st.sampled_from([SparseType.FP16, SparseType.FP32]), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_none_with_rowwise_adagrad(self, **kwargs: Any) -> None: - self.execute_backward_none_(optimizer=OptimType.EXACT_ROWWISE_ADAGRAD, **kwargs) - - def execute_backward_none_( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - weights_precision: SparseType, - weighted: bool, - long_segments: bool, - pooling_mode: PoolingMode, - output_dtype: SparseType, - optimizer: Optional[OptimType] = None, - ) -> None: - use_cpu = False - mixed = False - use_cache = False - - # NOTE: cache is not applicable to CPU version. - assume(not use_cpu or not use_cache) - # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! - assume(not use_cpu or T * B * L * D <= 2048) - assume(not (use_cpu and weights_precision == SparseType.FP16)) - # No bag ops only work on GPUs, no mixed, no weighted - assume(not use_cpu or pooling_mode != PoolingMode.NONE) - assume(not mixed or pooling_mode != PoolingMode.NONE) - assume(not weighted or pooling_mode != PoolingMode.NONE) - - assume(pooling_mode == PoolingMode.SUM or not weighted) + emb_op = SplitTableBatchedEmbeddingBagsCodegen if pooling_mode == PoolingMode.SUM: mode = "sum" do_pooling = True @@ -1659,6 +1169,17 @@ def execute_backward_none_( # noqa C901 Es = [ np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) ] + + if not mixed_B: + Bs = [B] * T + else: + low = max(int(0.25 * B), 1) + high = int(B) + if low == high: + Bs = [B] * T + else: + Bs = [np.random.randint(low=low, high=high) for _ in range(T)] + compute_device = ComputeDevice.CUDA if use_cpu: managed = [EmbeddingLocation.HOST] * T @@ -1679,6 +1200,7 @@ def execute_backward_none_( # noqa C901 np.random.choice( [ EmbeddingLocation.DEVICE, + EmbeddingLocation.MANAGED, ] ) for _ in range(T) @@ -1698,338 +1220,38 @@ def execute_backward_none_( # noqa C901 bs = [b.half() for b in bs] feature_table_map = list(range(T)) + table_to_replicate = T // 2 + # pyre-fixme[6]: For 2nd param expected `Embedding` but got + # `Union[Embedding, EmbeddingBag]`. 
+ bs.insert(table_to_replicate, bs[table_to_replicate]) + feature_table_map.insert(table_to_replicate, table_to_replicate) + + num_features = len(feature_table_map) + if not mixed_B: + Bs = [B] * num_features + Bs_rank_feature = [[0]] + else: + Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, num_features) + + # Generate indices xs = [ to_device( torch.from_numpy( - np.random.choice(range(Es[t]), size=(B, L)).astype(np.int64) + np.random.choice(range(Es[t]), size=(b, L), replace=True).astype( + np.int64 + ) ), use_cpu, ) - for t in feature_table_map + for t, b in zip(feature_table_map, Bs) ] if long_segments and L > 0: for x in xs: x[:, 0] = 0 - xws = [to_device(torch.randn(size=(B, L)), use_cpu) for _ in range(len(xs))] - xws_acc_type = copy.deepcopy(xws) - - if weights_precision == SparseType.FP16: - xws = [xw.half() for xw in xws] - - x = torch.cat([x.view(1, B, L) for x in xs], dim=0) - xw = torch.cat([xw.view(1, B, L) for xw in xws_acc_type], dim=0) - - (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=use_cpu) - embedding_specs = [ - (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) - ] - - # Hyperparameters in case optimizer is not None - lr = 0.5 - eps = 0.2 - stochastic_rounding = random.choice([True, False]) - - if optimizer is None: - fs = ( - [ - b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) - for (b, x) in zip(bs, xs) - ] - if not weighted - else [ - b_indices( - b, - x, - per_sample_weights=xw.view(-1), - use_cpu=use_cpu, - do_pooling=do_pooling, - ) - for (b, x, xw) in zip(bs, xs, xws) - ] - ) - gos: Union[List[Tensor], Tensor] = [torch.randn_like(f) for f in fs] - [f.backward(go) for (f, go) in zip(fs, gos)] - else: - bs_ = SplitTableBatchedEmbeddingBagsCodegen( - embedding_specs=embedding_specs, - optimizer=optimizer, - feature_table_map=feature_table_map, - weights_precision=weights_precision, - pooling_mode=pooling_mode, - output_dtype=output_dtype, - learning_rate=lr, - eps=eps, - stochastic_rounding=stochastic_rounding, - ) - - for t in range(T): - bs_.split_embedding_weights()[t].data.copy_(bs[t].weight) - - fs = ( - bs_(indices, offsets) - if not weighted - else bs_( - indices, - offsets, - to_device(xw.contiguous().view(-1), use_cpu), - ) - ) - gos: Union[List[Tensor], Tensor] = torch.rand_like(fs) - fs.backward(gos) - - cc = SplitTableBatchedEmbeddingBagsCodegen( - embedding_specs=embedding_specs, - optimizer=OptimType.NONE, - feature_table_map=feature_table_map, - weights_precision=weights_precision, - pooling_mode=pooling_mode, - output_dtype=output_dtype, - ) - - for t in range(T): - cc.split_embedding_weights()[t].data.copy_(bs[t].weight) - - total_unique_indices = 0 - # Compute number of unique indices - for t in range(len(feature_table_map)): - start = offsets[t * B] - end = offsets[(t + 1) * B] - uniq_indices = indices[start:end].unique() - total_unique_indices += uniq_indices.numel() - - fc2 = ( - cc(indices, offsets, total_unique_indices=total_unique_indices) - if not weighted - else cc( - indices, - offsets, - to_device(xw.contiguous().view(-1), use_cpu), - total_unique_indices=total_unique_indices, - ) - ) - if optimizer is None: - assert type(gos) is list - if do_pooling: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1) - else: - goc = torch.cat(gos, dim=0) - else: - assert type(gos) is Tensor - goc = gos.clone() - fc2.backward(goc) - - if optimizer is not None: - params = SplitEmbeddingOptimizerParams(weights_dev=cc.weights_dev) - embedding_args = SplitEmbeddingArgs( - 
weights_placements=cc.weights_placements, - weights_offsets=cc.weights_offsets, - max_D=cc.max_D, - ) - optim = SplitEmbeddingRowwiseAdagrad( - params, - embedding_args, - embedding_specs, - feature_table_map, - learning_rate=lr, - eps=eps, - stochastic_rounding=stochastic_rounding, - ) - optim.step() - - if use_cache: - cc.flush() - - if optimizer is None: - test_tensor = cc.weights_dev.grad - weight_grads = [] - for t in range(T): - grad = bs[t].weight.grad - # Check grad to suppress pyre error - assert grad is not None - weight_grads.append(grad) - ref_grad = torch.concat(weight_grads, dim=0).to_sparse().coalesce() - ref_tensor = ( - ref_grad.half() if weights_precision == SparseType.FP16 else ref_grad - ) - else: - indices = cc.weights_dev.grad._indices().flatten() - # Select only the part in the table that is updated - test_tensor = torch.index_select(cc.weights_dev.view(-1, D), 0, indices) - ref_tensor = torch.index_select(bs_.weights_dev.view(-1, D), 0, indices) - - tolerance = ( - 1.0e-2 - if long_segments - else ( - 1.0e-4 - if weights_precision == SparseType.FP32 - and output_dtype == SparseType.FP32 - else 1.0e-2 - ) - ) - torch.testing.assert_close( - test_tensor, - ref_tensor, - atol=tolerance, - rtol=tolerance, - ) - - def execute_backward_sgd_( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - weights_precision: SparseType, - weighted: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - long_segments: bool, - pooling_mode: PoolingMode, - use_cpu: bool, - output_dtype: SparseType, - ) -> None: - # NOTE: cache is not applicable to CPU version. - assume(not use_cpu or not use_cache) - # NOTE: limit (T * B * L * D) to avoid timeout for CPU version! - assume(not use_cpu or T * B * L * D <= 2048) - assume(not (use_cpu and weights_precision == SparseType.FP16)) - # No bag ops only work on GPUs, no mixed, no weighted - assume(not use_cpu or pooling_mode != PoolingMode.NONE) - assume(not mixed or pooling_mode != PoolingMode.NONE) - assume(not weighted or pooling_mode != PoolingMode.NONE) - - assume(pooling_mode == PoolingMode.SUM or not weighted) - # TODO: Support these cases - assume( - not mixed_B - or ( - weights_precision != SparseType.INT8 - and output_dtype != SparseType.INT8 - and not use_cpu - and not use_cache - and pooling_mode != PoolingMode.NONE - ) - ) - - emb_op = SplitTableBatchedEmbeddingBagsCodegen - if pooling_mode == PoolingMode.SUM: - mode = "sum" - do_pooling = True - elif pooling_mode == PoolingMode.MEAN: - mode = "mean" - do_pooling = True - elif pooling_mode == PoolingMode.NONE: - mode = "sum" - do_pooling = False - else: - # This proves that we have exhaustively checked all PoolingModes - raise RuntimeError("Unknown PoolingMode!") - - E = int(10**log_E) - if use_cpu: - D = (D + 15) // 16 * 4 - else: - D = D * 4 - if not mixed: - Ds = [D] * T - Es = [E] * T - else: - Ds = [ - round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) - for _ in range(T) - ] - Es = [ - np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) - ] - - if not mixed_B: - Bs = [B] * T - else: - low = max(int(0.25 * B), 1) - high = int(B) - if low == high: - Bs = [B] * T - else: - Bs = [np.random.randint(low=low, high=high) for _ in range(T)] - - compute_device = ComputeDevice.CUDA - if use_cpu: - managed = [EmbeddingLocation.HOST] * T - compute_device = ComputeDevice.CPU - elif TEST_WITH_ROCM: - # ROCm managed memory allocation is under development - managed = [EmbeddingLocation.DEVICE] * T 
- elif use_cache: - managed = [EmbeddingLocation.MANAGED_CACHING] * T - if mixed: - average_D = sum(Ds) // T - for t, d in enumerate(Ds): - managed[t] = ( - EmbeddingLocation.DEVICE if d < average_D else managed[t] - ) - else: - managed = [ - np.random.choice( - [ - EmbeddingLocation.DEVICE, - EmbeddingLocation.MANAGED, - ] - ) - for _ in range(T) - ] - if do_pooling: - bs = [ - to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] - else: - bs = [ - to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] - - if weights_precision == SparseType.FP16: - bs = [b.half() for b in bs] - - feature_table_map = list(range(T)) - table_to_replicate = T // 2 - # pyre-fixme[6]: For 2nd param expected `Embedding` but got - # `Union[Embedding, EmbeddingBag]`. - bs.insert(table_to_replicate, bs[table_to_replicate]) - feature_table_map.insert(table_to_replicate, table_to_replicate) - - num_features = len(feature_table_map) - if not mixed_B: - Bs = [B] * num_features - Bs_rank_feature = [[0]] - else: - Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, num_features) - - # Generate indices - xs = [ - to_device( - torch.from_numpy( - np.random.choice(range(Es[t]), size=(b, L), replace=True).astype( - np.int64 - ) - ), - use_cpu, - ) - for t, b in zip(feature_table_map, Bs) - ] - - if long_segments and L > 0: - for x in xs: - x[:, 0] = 0 - - # Generate positional weights - xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] + # Generate positional weights + xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] xws_acc_type = copy.deepcopy(xws) if weights_precision == SparseType.FP16: @@ -2039,898 +1261,179 @@ def execute_backward_sgd_( # noqa C901 fs = ( [ b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) - for (b, x) in zip(bs, xs) - ] - if not weighted - else [ - b_indices( - b, - x, - per_sample_weights=xw.view(-1), - use_cpu=use_cpu, - do_pooling=do_pooling, - ) - for (b, x, xw) in zip(bs, xs, xws) - ] - ) - # Generate gradients - gos = [torch.randn_like(f) for f in fs] - # Run baseline's backward - [f.backward(go) for (f, go) in zip(fs, gos)] - # do SGD update - lr = 0.05 - del bs[table_to_replicate] - # pyre-fixme[58]: `*` is not supported for operand types - # `Optional[torch._tensor.Tensor]` and `float`. 
- new_weights = [(b.weight - b.weight.grad * lr) for b in bs] - - # Create a TBE op - cc = emb_op( - embedding_specs=[ - (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) - ], - optimizer=OptimType.EXACT_SGD, - feature_table_map=feature_table_map, - learning_rate=lr, - weights_precision=weights_precision, - cache_algorithm=cache_algorithm, - pooling_mode=pooling_mode, - output_dtype=output_dtype, - ) - - for t in range(T): - cc.split_embedding_weights()[t].data.copy_(bs[t].weight) - - x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) - xw = torch.cat([xw.contiguous().flatten() for xw in xws_acc_type], dim=0) - - (indices, offsets) = get_table_batched_offsets_from_dense( - x, L, sum(Bs), use_cpu=use_cpu - ) - - batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None - - # Run TBE's forward - fc2 = ( - cc( - indices, - offsets, - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, - ) - if not weighted - else cc( - indices, - offsets, - to_device(xw.contiguous().view(-1), use_cpu), - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, - ) - ) - # Generate gradients - if do_pooling: - if mixed_B: - goc = format_ref_tensors_in_mixed_B_layout(gos, Bs_rank_feature) - else: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1) - else: - goc = torch.cat(gos, dim=0) - - # Run TBE's backward - fc2.backward(goc) - - if use_cache: - cc.flush() - for t in range(T): - torch.testing.assert_close( - cc.split_embedding_weights()[t], - new_weights[t].half() - if weights_precision == SparseType.FP16 and not use_cpu - else new_weights[t], - atol=1.0e-2 - if long_segments - else (5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5), - rtol=1.0e-1 - if long_segments - else (2.0e-2 if weights_precision == SparseType.FP16 else 1.0e-5), - ) - - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=256), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), - weighted=st.booleans(), - mixed=st.booleans(), - mixed_B=st.booleans(), - use_cache=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - long_segments=st.booleans(), - pooling_mode=st.sampled_from( - [ - PoolingMode.SUM, - PoolingMode.MEAN, - PoolingMode.NONE, - ] - ), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_sgd( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - weights_precision: SparseType, - weighted: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - long_segments: bool, - pooling_mode: PoolingMode, - use_cpu: bool, - ) -> None: - self.execute_backward_sgd_( - T, - D, - B, - log_E, - L, - weights_precision, - weighted, - mixed, - mixed_B if not use_cpu else False, - use_cache, - cache_algorithm, - long_segments, - pooling_mode, - use_cpu, - SparseType.FP32, # output_dtype - ) - - @given( - D=st.integers(min_value=2, max_value=10), - # 128 * 1024 is to exercise a case num_ctas_for_run needs to be capped - # at the number of SMs (H100 SXM5 has 132 SMs and the default seglen - # per CTA is 1024) - 
B=st.sampled_from([1152, 256 * 1024]), - L=st.integers(min_value=1, max_value=4), - weighted=st.booleans(), - mixed=st.booleans(), - mixed_B=st.booleans(), - use_cache=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - @unittest.skipIf(*gpu_unavailable) - def test_backward_sgd_really_long_segments( # noqa C901 - self, - D: int, - B: int, - L: int, - weighted: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - ) -> None: - self.execute_backward_sgd_( - 2, # T - D, - B, - 1, # log_E, - L, - SparseType.FP32, # weights_precision - weighted, - mixed, - mixed_B, - use_cache, - cache_algorithm, - True, # long_segments - PoolingMode.SUM, # pooling_mode - False, # use_cpu - SparseType.FP32, # output_dtype - ) - - def execute_backward_adagrad_( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - D_gradcheck: int, - weights_precision: SparseType, - stochastic_rounding: bool, - weighted: bool, - row_wise: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - pooling_mode: PoolingMode, - use_cpu: bool, - output_dtype: SparseType, - weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, - ) -> None: - # NOTE: cache is not applicable to CPU version. - assume(not use_cpu or not use_cache) - - # NOTE: torch.autograd.gradcheck() is too time-consuming for CPU version - # so we have to limit (T * B * L * D)! - assume(not use_cpu or T * B * L * D <= 1024) - assume(not (use_cpu and weights_precision == SparseType.FP16)) - - assume( - pooling_mode == PoolingMode.SUM or not weighted - ) # No bag ops only work on GPUs, no mixed, no weighted - assume(not use_cpu or pooling_mode != PoolingMode.NONE) - assume(not mixed or pooling_mode != PoolingMode.NONE) - assume(not weighted or pooling_mode != PoolingMode.NONE) - # TODO: Support these cases - assume( - not mixed_B - or ( - weights_precision != SparseType.INT8 - and output_dtype != SparseType.INT8 - and not use_cpu - and not use_cache - and pooling_mode != PoolingMode.NONE - ) - ) - - emb_op = SplitTableBatchedEmbeddingBagsCodegen - if pooling_mode == PoolingMode.SUM: - mode = "sum" - do_pooling = True - elif pooling_mode == PoolingMode.MEAN: - mode = "mean" - do_pooling = True - elif pooling_mode == PoolingMode.NONE: - mode = "sum" - do_pooling = False - else: - # This proves that we have exhaustively checked all PoolingModes - raise RuntimeError("Unknown PoolingMode!") - - # stochastic rounding only implemented for rowwise - assume(not stochastic_rounding or row_wise) - # only row-wise supports caching - assume(row_wise or not use_cache) - - E = int(10**log_E) - if use_cpu: - D = (D + 15) // 16 * 4 - else: - D = D * 4 - if not mixed: - Ds = [D] * T - Es = [E] * T - else: - Ds = [ - round_up(np.random.randint(low=int(0.25 * D), high=int(1.0 * D)), 4) - for _ in range(T) - ] - Es = [ - np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) - ] - - if not mixed_B: - Bs = [B] * T - else: - low = max(int(0.25 * B), 1) - high = int(B) - if low == high: - Bs = [B] * T - else: - Bs = [np.random.randint(low=low, high=high) for _ in range(T)] - - compute_device = ComputeDevice.CUDA - if use_cpu: - managed = [EmbeddingLocation.HOST] * T - compute_device = ComputeDevice.CPU - elif TEST_WITH_ROCM: - # ROCm managed memory allocation is under 
development - managed = [EmbeddingLocation.DEVICE] * T - elif use_cache: - managed = [EmbeddingLocation.MANAGED_CACHING] * T - if mixed: - average_D = sum(Ds) // T - for t, d in enumerate(Ds): - managed[t] = ( - EmbeddingLocation.DEVICE if d < average_D else managed[t] - ) - else: - managed = [ - np.random.choice( - [ - EmbeddingLocation.DEVICE, - EmbeddingLocation.MANAGED, - ] - ) - for _ in range(T) - ] - if do_pooling: - bs = [ - to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] - else: - bs = [ - to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] - - if weights_precision == SparseType.FP16: - bs = [b.half() for b in bs] - - feature_table_map = list(range(T)) - # autograd with shared embedding only works for exact - table_to_replicate = T // 2 - # pyre-fixme[6]: For 2nd param expected `Embedding` but got - # `Union[Embedding, EmbeddingBag]`. - bs.insert(table_to_replicate, bs[table_to_replicate]) - feature_table_map.insert(table_to_replicate, table_to_replicate) - - num_features = len(feature_table_map) - if not mixed_B: - Bs = [B] * num_features - Bs_rank_feature = [[0]] - else: - Bs_rank_feature, Bs = gen_mixed_B_batch_sizes(B, num_features) - - xs = [ - to_device( - torch.from_numpy( - np.random.choice(range(Es[t]), size=(b, L), replace=True).astype( - np.int64 - ) - ), - use_cpu, - ) - for t, b in zip(feature_table_map, Bs) - ] - xws = [to_device(torch.randn(size=(b, L)), use_cpu) for b in Bs] - xws_acc_type = copy.deepcopy(xws) - - if weights_precision == SparseType.FP16 and not use_cpu: - xws = [xw.half() for xw in xws] - - fs = ( - [ - b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) - for (b, x) in zip(bs, xs) - ] - if not weighted - else [ - b_indices( - b, - x, - per_sample_weights=xw.view(-1), - use_cpu=use_cpu, - do_pooling=do_pooling, - ) - for (b, x, xw) in zip(bs, xs, xws) - ] - ) - gos = [torch.randn_like(f) for f in fs] - [f.backward(go) for (f, go) in zip(fs, gos)] - # do SGD update - lr = 0.5 - eps = 0.2 - - optimizer = ( - OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD - ) - cc = emb_op( - embedding_specs=[ - (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) - ], - feature_table_map=feature_table_map, - optimizer=optimizer, - learning_rate=lr, - eps=eps, - weights_precision=weights_precision, - stochastic_rounding=stochastic_rounding, - pooling_mode=pooling_mode, - output_dtype=output_dtype, - ) - - del bs[table_to_replicate] - for t in range(T): - cc.split_embedding_weights()[t].data.copy_(bs[t].weight) - - x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) - xw = torch.cat([xw.contiguous().flatten() for xw in xws_acc_type], dim=0) - - (indices, offsets) = get_table_batched_offsets_from_dense( - x, L, sum(Bs), use_cpu=use_cpu - ) - - batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None - - fc2 = ( - cc( - indices, - offsets, - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, - ) - if not weighted - else cc( - indices, - offsets, - to_device(xw.contiguous().view(-1), use_cpu), - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, - ) - ) - if do_pooling: - if mixed_B: - goc = format_ref_tensors_in_mixed_B_layout(gos, Bs_rank_feature) - else: - goc = torch.cat([go.view(B, -1) for go in gos], dim=1) - else: - goc = torch.cat(gos, dim=0) - fc2.backward(goc) - cc.flush() - split_optimizer_states = cc.split_optimizer_states() - assert len(split_optimizer_states) == T - - 
get_optimizer_states = None - if row_wise: - # get_optimizer_state should/must be implemented for rowwise - get_optimizer_states = cc.get_optimizer_state() - assert len(get_optimizer_states) == T - - tolerance = ( - 1.0e-4 - if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32 - else 1.0e-2 - ) - - for t in range(T): - expected_keys = {"sum"} - if row_wise and weight_decay_mode == WeightDecayMode.COUNTER: - (m1, c1, c2) = split_optimizer_states[t] - expected_keys.update( - [ - "prev_iter", - "row_counter", - ] - ) - else: - (m1,) = split_optimizer_states[t] - if get_optimizer_states is not None: - optimizer_states_dict = get_optimizer_states[t] - assert set(optimizer_states_dict.keys()) == expected_keys - # pyre-fixme[16]: `Optional` has no attribute `float`. - ref_optimizer_state = bs[t].weight.grad.float().cpu().to_dense().pow(2) - torch.testing.assert_close( - m1.float().cpu(), - ref_optimizer_state.mean(dim=1) if row_wise else ref_optimizer_state, - atol=tolerance, - rtol=tolerance, - ) - for t in range(T): - # optimizer_state = squares (no row-wise) or sum squares (row-wise) - if row_wise and weight_decay_mode == WeightDecayMode.COUNTER: - (m1, c1, c2) = split_optimizer_states[t] - else: - (m1,) = split_optimizer_states[t] - torch.testing.assert_close( - cc.split_embedding_weights()[t].float().cpu(), - torch.addcdiv( - bs[t].weight.float().cpu(), - value=-lr, - tensor1=bs[t].weight.grad.float().cpu().to_dense(), - tensor2=m1.float() - .sqrt_() - .add_(eps) - .view(Es[t], 1 if row_wise else Ds[t]) - .cpu(), - ), - atol=tolerance, - rtol=tolerance, - ) - if use_cpu: - D_gradcheck = (D_gradcheck + 15) // 16 * 4 - else: - D_gradcheck = D_gradcheck * 4 - cc = emb_op( - embedding_specs=[ - (E, D_gradcheck, M, compute_device) for (E, M) in zip(Es, managed) - ], - feature_table_map=feature_table_map, - optimizer=optimizer, - learning_rate=0.0, - eps=eps, - weights_precision=weights_precision, - stochastic_rounding=stochastic_rounding, - # NOTE: only SUM pooling can work with per_sample_weights! - pooling_mode=PoolingMode.SUM, - output_dtype=output_dtype, - ) - per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) - per_sample_weights.requires_grad = True - indices.requires_grad = False - offsets.requires_grad = False - for param in cc.parameters(): - param.requires_grad = False - gradcheck( - cc, - ( - indices, - offsets, - per_sample_weights, - None, - batch_size_per_feature_per_rank, - ), - ) - - per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) - per_sample_weights.requires_grad = True - indices.requires_grad = False - offsets.requires_grad = False - for param in cc.parameters(): - param.requires_grad = False - y = cc( - indices, - offsets, - per_sample_weights, - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, - ) - y.sum().backward() - # pyre-fixme[16]: `Optional` has no attribute `clone`. 
- indice_weight_grad_all = per_sample_weights.grad.clone().cpu() - T_ = len(xws) - feature_requires_grad = to_device( - torch.tensor(np.random.choice([0, 1], replace=True, size=(T_,))).int(), - use_cpu, - ) - per_sample_weights = per_sample_weights.detach().clone() - per_sample_weights.requires_grad = True - y = cc( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, - ) - y.sum().backward() - indice_weight_grad_mask = per_sample_weights.grad.clone().cpu() - torch.cuda.synchronize() - - acc_B = 0 - for t in range(T_): - B = Bs[t] - table_indice_weight_grad_mask = indice_weight_grad_mask[ - acc_B : acc_B + B * L - ] - table_indice_weight_grad_all = indice_weight_grad_all[acc_B : acc_B + B * L] - acc_B += B * L - if feature_requires_grad[t]: - torch.testing.assert_close( - table_indice_weight_grad_mask, - table_indice_weight_grad_all, - ) - else: - torch.testing.assert_close( - table_indice_weight_grad_mask, - torch.zeros_like(table_indice_weight_grad_mask), - ) - - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - D_gradcheck=st.integers(min_value=1, max_value=2), - weights_precision=st.just(SparseType.FP16), - stochastic_rounding=st.booleans(), - weighted=st.booleans(), - row_wise=st.booleans(), - mixed=st.booleans(), - mixed_B=st.booleans(), - use_cache=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_adagrad_fp16_pmSUM( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - D_gradcheck: int, - weights_precision: SparseType, - stochastic_rounding: bool, - weighted: bool, - row_wise: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - use_cpu: bool, - output_dtype: SparseType, - ) -> None: - # VBE is supported in rowwise_adagrad only - if not row_wise: - mixed_B = False - self.execute_backward_adagrad_( - T, - D, - B, - log_E, - L, - D_gradcheck, - weights_precision, - stochastic_rounding, - weighted, - row_wise, - mixed, - mixed_B, - use_cache, - cache_algorithm, - PoolingMode.SUM, - use_cpu, - output_dtype, - ) - - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - D_gradcheck=st.integers(min_value=1, max_value=2), - weights_precision=st.just(SparseType.FP16), - stochastic_rounding=st.booleans(), - weighted=st.booleans(), - row_wise=st.booleans(), - mixed=st.booleans(), - mixed_B=st.booleans(), - use_cache=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - ) - @settings( - verbosity=VERBOSITY, - 
max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_adagrad_fp16_pmMEAN( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - D_gradcheck: int, - weights_precision: SparseType, - stochastic_rounding: bool, - weighted: bool, - row_wise: bool, - mixed: bool, - mixed_B: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - use_cpu: bool, - output_dtype: SparseType, - ) -> None: - # VBE is supported in rowwise_adagrad only - if not row_wise: - mixed_B = False - self.execute_backward_adagrad_( - T, - D, - B, - log_E, - L, - D_gradcheck, - weights_precision, - stochastic_rounding, - weighted, - row_wise, - mixed, - mixed_B, - use_cache, - cache_algorithm, - PoolingMode.MEAN, - use_cpu, - output_dtype, + for (b, x) in zip(bs, xs) + ] + if not weighted + else [ + b_indices( + b, + x, + per_sample_weights=xw.view(-1), + use_cpu=use_cpu, + do_pooling=do_pooling, + ) + for (b, x, xw) in zip(bs, xs, xws) + ] ) + # Generate gradients + gos = [torch.randn_like(f) for f in fs] + # Run baseline's backward + [f.backward(go) for (f, go) in zip(fs, gos)] + # do SGD update + lr = 0.05 + del bs[table_to_replicate] + # pyre-fixme[58]: `*` is not supported for operand types + # `Optional[torch._tensor.Tensor]` and `float`. + new_weights = [(b.weight - b.weight.grad * lr) for b in bs] - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - D_gradcheck=st.integers(min_value=1, max_value=2), - weights_precision=st.just(SparseType.FP16), - stochastic_rounding=st.booleans(), - weighted=st.booleans(), - row_wise=st.booleans(), - mixed=st.booleans(), - use_cache=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_adagrad_fp16_pmNONE( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - D_gradcheck: int, - weights_precision: SparseType, - stochastic_rounding: bool, - weighted: bool, - row_wise: bool, - mixed: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - use_cpu: bool, - output_dtype: SparseType, - ) -> None: - self.execute_backward_adagrad_( - T, - D, - B, - log_E, - L, - D_gradcheck, - weights_precision, - stochastic_rounding, - weighted, - row_wise, - mixed, - False, # mixed_B - use_cache, - cache_algorithm, - PoolingMode.NONE, - use_cpu, - output_dtype, + # Create a TBE op + cc = emb_op( + embedding_specs=[ + (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed) + ], + optimizer=OptimType.EXACT_SGD, + feature_table_map=feature_table_map, + learning_rate=lr, + weights_precision=weights_precision, + cache_algorithm=cache_algorithm, + pooling_mode=pooling_mode, + output_dtype=output_dtype, + ) + + for t in range(T): + cc.split_embedding_weights()[t].data.copy_(bs[t].weight) + + x = torch.cat([x.contiguous().flatten() for x in xs], dim=0) + xw = torch.cat([xw.contiguous().flatten() for xw in xws_acc_type], dim=0) + + 
(indices, offsets) = get_table_batched_offsets_from_dense( + x, L, sum(Bs), use_cpu=use_cpu + ) + + batch_size_per_feature_per_rank = Bs_rank_feature if mixed_B else None + + # Run TBE's forward + fc2 = ( + cc( + indices, + offsets, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + if not weighted + else cc( + indices, + offsets, + to_device(xw.contiguous().view(-1), use_cpu), + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) ) + # Generate gradients + if do_pooling: + if mixed_B: + goc = format_ref_tensors_in_mixed_B_layout(gos, Bs_rank_feature) + else: + goc = torch.cat([go.view(B, -1) for go in gos], dim=1) + else: + goc = torch.cat(gos, dim=0) + + # Run TBE's backward + fc2.backward(goc) + + if use_cache: + cc.flush() + for t in range(T): + torch.testing.assert_close( + cc.split_embedding_weights()[t], + new_weights[t].half() + if weights_precision == SparseType.FP16 and not use_cpu + else new_weights[t], + atol=1.0e-2 + if long_segments + else (5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-5), + rtol=1.0e-1 + if long_segments + else (2.0e-2 if weights_precision == SparseType.FP16 else 1.0e-5), + ) @given( T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), + D=st.integers(min_value=2, max_value=256), B=st.integers(min_value=1, max_value=128), log_E=st.integers(min_value=3, max_value=5), L=st.integers(min_value=0, max_value=20), - D_gradcheck=st.integers(min_value=1, max_value=2), - weights_precision=st.just(SparseType.FP32), - stochastic_rounding=st.booleans(), + weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), weighted=st.booleans(), - row_wise=st.booleans(), mixed=st.booleans(), mixed_B=st.booleans(), use_cache=st.booleans(), cache_algorithm=st.sampled_from(CacheAlgorithm), + long_segments=st.booleans(), + pooling_mode=st.sampled_from( + [ + PoolingMode.SUM, + PoolingMode.MEAN, + PoolingMode.NONE, + ] + ), use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, + max_examples=MAX_EXAMPLES, deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], ) - def test_backward_adagrad_fp32_pmSUM( # noqa C901 + def test_backward_sgd( # noqa C901 self, T: int, D: int, B: int, log_E: int, L: int, - D_gradcheck: int, weights_precision: SparseType, - stochastic_rounding: bool, weighted: bool, - row_wise: bool, mixed: bool, mixed_B: bool, use_cache: bool, cache_algorithm: CacheAlgorithm, + long_segments: bool, + pooling_mode: PoolingMode, use_cpu: bool, - output_dtype: SparseType, ) -> None: - # VBE is supported in rowwise_adagrad only - if not row_wise: - mixed_B = False - self.execute_backward_adagrad_( + self.execute_backward_sgd_( T, D, B, log_E, L, - D_gradcheck, weights_precision, - stochastic_rounding, weighted, - row_wise, mixed, - mixed_B, + mixed_B if not use_cpu else False, use_cache, cache_algorithm, - PoolingMode.SUM, + long_segments, + pooling_mode, use_cpu, - output_dtype, + SparseType.FP32, # output_dtype ) @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - D_gradcheck=st.integers(min_value=1, max_value=2), - 
weights_precision=st.just(SparseType.FP32), - stochastic_rounding=st.booleans(), + D=st.integers(min_value=2, max_value=10), + # 128 * 1024 is to exercise a case num_ctas_for_run needs to be capped + # at the number of SMs (H100 SXM5 has 132 SMs and the default seglen + # per CTA is 1024) + B=st.sampled_from([1152, 256 * 1024]), + L=st.integers(min_value=1, max_value=4), weighted=st.booleans(), - row_wise=st.booleans(), mixed=st.booleans(), mixed_B=st.booleans(), use_cache=st.booleans(), cache_algorithm=st.sampled_from(CacheAlgorithm), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), ) @settings( verbosity=VERBOSITY, @@ -2938,111 +1441,34 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 deadline=None, suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], ) - def test_backward_adagrad_fp32_pmMEAN( # noqa C901 + @unittest.skipIf(*gpu_unavailable) + def test_backward_sgd_really_long_segments( # noqa C901 self, - T: int, D: int, B: int, - log_E: int, L: int, - D_gradcheck: int, - weights_precision: SparseType, - stochastic_rounding: bool, weighted: bool, - row_wise: bool, mixed: bool, mixed_B: bool, use_cache: bool, cache_algorithm: CacheAlgorithm, - use_cpu: bool, - output_dtype: SparseType, ) -> None: - # VBE is supported in rowwise_adagrad only - if not row_wise: - mixed_B = False - self.execute_backward_adagrad_( - T, + self.execute_backward_sgd_( + 2, # T D, B, - log_E, + 1, # log_E, L, - D_gradcheck, - weights_precision, - stochastic_rounding, + SparseType.FP32, # weights_precision weighted, - row_wise, mixed, mixed_B, use_cache, cache_algorithm, - PoolingMode.MEAN, - use_cpu, - output_dtype, - ) - - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=128), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=0, max_value=20), - D_gradcheck=st.integers(min_value=1, max_value=2), - weights_precision=st.just(SparseType.FP32), - stochastic_rounding=st.booleans(), - weighted=st.booleans(), - row_wise=st.booleans(), - mixed=st.booleans(), - use_cache=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - use_cpu=st.booleans() - if (gpu_available and not TEST_WITH_ROCM) - else st.just(False) - if (gpu_available and TEST_WITH_ROCM) - else st.just(True), - output_dtype=st.sampled_from([SparseType.FP32, SparseType.FP16]), - ) - @settings( - verbosity=VERBOSITY, - max_examples=MAX_EXAMPLES_LONG_RUNNING, - deadline=None, - suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large], - ) - def test_backward_adagrad_fp32_pmNONE( # noqa C901 - self, - T: int, - D: int, - B: int, - log_E: int, - L: int, - D_gradcheck: int, - weights_precision: SparseType, - stochastic_rounding: bool, - weighted: bool, - row_wise: bool, - mixed: bool, - use_cache: bool, - cache_algorithm: CacheAlgorithm, - use_cpu: bool, - output_dtype: SparseType, - ) -> None: - self.execute_backward_adagrad_( - T, - D, - B, - log_E, - L, - D_gradcheck, - weights_precision, - stochastic_rounding, - weighted, - row_wise, - mixed, - False, # mixed_B - use_cache, - cache_algorithm, - PoolingMode.NONE, - use_cpu, - output_dtype, + True, # long_segments + PoolingMode.SUM, # pooling_mode + False, # use_cpu + SparseType.FP32, # output_dtype ) def _generate_cache_tbes( diff --git 
a/fbgemm_gpu/test/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils_test.py similarity index 100% rename from fbgemm_gpu/test/split_embeddings_utils_test.py rename to fbgemm_gpu/test/tbe/utils_test.py
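For reference, a minimal sketch of the exact Adagrad update rule that the backward_adagrad tests above compare against, written with plain torch ops on toy tensors. The names below (`weight`, `grad`, `lr`, `eps`) are hypothetical stand-ins; the real tests read the optimizer state via split_optimizer_states() and the updated rows via split_embedding_weights().

import torch

E, D = 4, 8                       # rows and embedding dim of one toy table
lr, eps = 0.5, 0.2                # hyperparameters used in the tests above
weight = torch.randn(E, D)
grad = torch.randn(E, D)          # dense view of the sparse gradient

# Optimizer state: squared gradients, reduced per row in the row-wise case.
state_rowwise = grad.pow(2).mean(dim=1)            # shape (E,)
state_elementwise = grad.pow(2)                    # shape (E, D)

# Row-wise exact Adagrad: every element of a row is scaled by that row's state.
new_weight_rowwise = torch.addcdiv(
    weight,
    value=-lr,
    tensor1=grad,
    tensor2=state_rowwise.sqrt().add(eps).view(E, 1),
)

# Exact (element-wise) Adagrad.
new_weight_elementwise = torch.addcdiv(
    weight,
    value=-lr,
    tensor1=grad,
    tensor2=state_elementwise.sqrt().add(eps),
)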