diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
index 52d3d9005..b43d0bd89 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -36,6 +36,19 @@ class CacheAlgorithm(enum.Enum):
     LFU = 1
 
 
+class MultiPassPrefetchConfig(NamedTuple):
+    # Number of passes to split the indices tensor into. The actual number of
+    # passes may be fewer if the indices tensor is too small to split.
+    num_passes: int = 12
+
+    # The minimal number of elements in the indices tensor required to split it
+    # into two passes. This is useful to prevent too many prefetch kernels from
+    # spamming the CUDA launch queue.
+    # The default 6M indices means 6M * 8 * 6 = approx. 300MB of memory overhead
+    # per pass.
+    min_splitable_pass_size: int = 6 * 1024 * 1024
+
+
 class PoolingMode(enum.IntEnum):
     SUM = 0
     MEAN = 1
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index ab2577b7c..a9651e94a 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -37,6 +37,7 @@
     construct_cache_state,
     EmbeddingLocation,
     MAX_PREFETCH_DEPTH,
+    MultiPassPrefetchConfig,
     PoolingMode,
     RecordCacheMetrics,
     SplitState,
@@ -389,6 +390,7 @@ def __init__(  # noqa C901
         # Embedding table names that are contained in this TBE.
         table_names: Optional[List[str]] = None,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
+        multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
         self.uuid = str(uuid.uuid4())
@@ -402,12 +404,33 @@ def __init__(  # noqa C901
         self.prefetch_pipeline: bool = prefetch_pipeline
         self.lock_cache_line: bool = self.prefetch_pipeline
         self.use_uniq_cache_locations_bwd: bool = self.prefetch_pipeline
+        self.multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = (
+            multipass_prefetch_config
+        )
 
         if record_cache_metrics is not None:
             self.record_cache_metrics = record_cache_metrics
         else:
             self.record_cache_metrics = RecordCacheMetrics(False, False)
 
+        if multipass_prefetch_config:
+            assert (
+                prefetch_pipeline
+            ), "Multipass prefetch makes no sense in non-prefetch mode."
+            assert (
+                cache_algorithm == CacheAlgorithm.LRU
+            ), "Multipass prefetch is only supported in LRU cache."
+            assert (
+                multipass_prefetch_config.num_passes > 0
+            ), f"num_passes must be positive, got {multipass_prefetch_config.num_passes}"
+            assert (
+                multipass_prefetch_config.min_splitable_pass_size > 0
+            ), f"min_splitable_pass_size must be positive, got {multipass_prefetch_config.min_splitable_pass_size}"
+            assert (
+                not self.record_cache_metrics.record_cache_miss_counter
+                and not self.record_cache_metrics.record_tablewise_cache_miss
+            ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported"
+
         self.embedding_specs = embedding_specs
         (rows, dims, locations, compute_devices) = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
@@ -925,6 +948,43 @@ def _register_nonpersistent_buffers(self, prefix: str) -> None:
             persistent=False,
         )
 
+    @staticmethod
+    def get_prefetch_passes(
+        multipass_prefetch_config: Optional[MultiPassPrefetchConfig],
+        input_tensor: Tensor,
+        output_tensor: Tensor,
+    ) -> List[Tuple[Tensor, Tensor, int]]:
+        """
+        Given input (the indices to forward), return the segmentation for each pass
+        in the format of (input[start_idx:end_idx], output[start_idx:end_idx], start_idx).
+
+        Caller should guarantee input and output have the same size on dimension 0.
+        The returned segments are guaranteed to cover the input tensor completely and without overlap.
+
+        In non-multipass-prefetch mode, the input/output tensors are returned as a single pass.
+        """
+        if multipass_prefetch_config is None:
+            return [(input_tensor, output_tensor, 0)]
+        mpp_config: MultiPassPrefetchConfig = multipass_prefetch_config
+
+        N = input_tensor.size(0)
+        if N <= mpp_config.num_passes or mpp_config.num_passes == 1:
+            # One row per pass, just don't split
+            return [(input_tensor, output_tensor, 0)]
+
+        pass_size: int = max(
+            (N + mpp_config.num_passes - 1) // mpp_config.num_passes,
+            mpp_config.min_splitable_pass_size,
+        )
+
+        return list(
+            zip(
+                torch.split(input_tensor, pass_size),
+                torch.split(output_tensor, pass_size),
+                range(0, N, pass_size),
+            )
+        )
+
     def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
         if not hasattr(self, f"{prefix}_physical_placements"):
             raise DoesNotHavePrefix()
@@ -1194,7 +1254,12 @@ def forward(  # noqa: C901
             self._report_tbe_mem_usage()
 
         if len(self.timesteps_prefetched) == 0:
-            self._prefetch(indices, offsets, vbe_metadata)
+            # In forward, we don't enable multi-pass prefetch as we want the process
+            # to be as fast as possible and memory usage doesn't matter (will be recycled
+            # by dense fwd/bwd)
+            self._prefetch(
+                indices, offsets, vbe_metadata, multipass_prefetch_config=None
+            )
 
         if len(self.timesteps_prefetched) > 0:
             self.timesteps_prefetched.pop(0)
@@ -1509,7 +1574,12 @@ def prefetch(
             offsets,
             batch_size_per_feature_per_rank,
         )
-        self._prefetch(indices, offsets, vbe_metadata)
+        self._prefetch(
+            indices,
+            offsets,
+            vbe_metadata,
+            multipass_prefetch_config=self.multipass_prefetch_config,
+        )
         if forward_stream is not None:
             self._prefetch_tensors_record_stream(forward_stream)
 
@@ -1518,6 +1588,7 @@ def _prefetch(
         indices: Tensor,
         offsets: Tensor,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
+        multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None,
     ) -> None:
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
@@ -1534,81 +1605,90 @@ def _prefetch(
             self.local_uvm_cache_stats.zero_()
         self._report_io_size_count("prefetch_input", indices)
 
-        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
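Note (illustrative, not part of the patch): a minimal sketch of how get_prefetch_passes above segments the prefetch work, assuming fbgemm_gpu with this change installed. With 10 indices, num_passes=4, and min_splitable_pass_size=3, the pass size is max(ceil(10 / 4), 3) = 3, so the passes start at offsets 0, 3, 6, 9 and have sizes 3, 3, 3, 1.

import torch

from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    MultiPassPrefetchConfig,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Small, hypothetical values for illustration; the defaults are num_passes=12
# and min_splitable_pass_size=6M, as defined in MultiPassPrefetchConfig above.
config = MultiPassPrefetchConfig(num_passes=4, min_splitable_pass_size=3)
indices = torch.arange(10)
lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32)

passes = SplitTableBatchedEmbeddingBagsCodegen.get_prefetch_passes(
    config, indices, lxu_cache_locations
)
for partial_indices, partial_locations, start_idx in passes:
    # Prints (0, 3), (3, 3), (6, 3), (9, 1): complete, non-overlapping coverage.
    print((start_idx, partial_indices.numel()))

The same config object is what callers pass to the TBE constructor as multipass_prefetch_config; per the asserts above it is only valid together with prefetch_pipeline=True and CacheAlgorithm.LRU.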
- self.cache_hash_size_cumsum, - indices, - offsets, - vbe_metadata.B_offsets if vbe_metadata is not None else None, - vbe_metadata.max_B if vbe_metadata is not None else -1, - ) - - if ( - self.record_cache_metrics.record_cache_miss_counter - or self.record_cache_metrics.record_tablewise_cache_miss + final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32) + for ( + partial_indices, + partial_lxu_cache_locations, + base_offset, + ) in self.get_prefetch_passes( + multipass_prefetch_config, indices, final_lxu_cache_locations ): - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_cache_hash_size, - self.gather_uvm_cache_stats, - self.local_uvm_cache_stats, + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + self.cache_hash_size_cumsum, + partial_indices, + offsets, + vbe_metadata.B_offsets if vbe_metadata is not None else None, + vbe_metadata.max_B if vbe_metadata is not None else -1, + base_offset, ) - if self.record_cache_metrics.record_cache_miss_counter: - self._update_cache_miss_counter( - lxu_cache_locations, linear_cache_indices + + if ( + self.record_cache_metrics.record_cache_miss_counter + or self.record_cache_metrics.record_tablewise_cache_miss + ): + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_cache_hash_size, + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, + ) + if self.record_cache_metrics.record_cache_miss_counter: + self._update_cache_miss_counter( + lxu_cache_locations, linear_cache_indices + ) + if self.record_cache_metrics.record_tablewise_cache_miss: + self._update_tablewise_cache_miss( + lxu_cache_locations, linear_cache_indices, offsets + ) + + if self.cache_algorithm == CacheAlgorithm.LRU: + torch.ops.fbgemm.lru_cache_populate( + self.weights_uvm, + self.cache_hash_size_cumsum, + self.total_cache_hash_size, + self.cache_index_table_map, + self.weights_offsets, + self.D_offsets, + linear_cache_indices, + self.lxu_cache_state, + self.lxu_cache_weights, + self.timestep, + self.lxu_state, + self.stochastic_rounding, + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, + self.lock_cache_line, + self.lxu_cache_locking_counter, ) - if self.record_cache_metrics.record_tablewise_cache_miss: - self._update_tablewise_cache_miss( - lxu_cache_locations, linear_cache_indices, offsets + elif self.cache_algorithm == CacheAlgorithm.LFU: + torch.ops.fbgemm.lfu_cache_populate( + self.weights_uvm, + self.cache_hash_size_cumsum, + self.total_cache_hash_size, + self.cache_index_table_map, + self.weights_offsets, + self.D_offsets, + linear_cache_indices, + self.lxu_cache_state, + self.lxu_cache_weights, + self.lxu_state, + self.stochastic_rounding, ) - if self.cache_algorithm == CacheAlgorithm.LRU: - torch.ops.fbgemm.lru_cache_populate( - self.weights_uvm, - self.cache_hash_size_cumsum, - self.total_cache_hash_size, - self.cache_index_table_map, - self.weights_offsets, - self.D_offsets, + torch.ops.fbgemm.lxu_cache_lookup( linear_cache_indices, self.lxu_cache_state, - self.lxu_cache_weights, - self.timestep, - self.lxu_state, - self.stochastic_rounding, + self.total_cache_hash_size, self.gather_uvm_cache_stats, self.local_uvm_cache_stats, - self.lock_cache_line, - self.lxu_cache_locking_counter, - ) - elif self.cache_algorithm == CacheAlgorithm.LFU: - torch.ops.fbgemm.lfu_cache_populate( - self.weights_uvm, - self.cache_hash_size_cumsum, - self.total_cache_hash_size, - self.cache_index_table_map, - 
self.weights_offsets, - self.D_offsets, - linear_cache_indices, - self.lxu_cache_state, - self.lxu_cache_weights, - self.lxu_state, - self.stochastic_rounding, + lxu_cache_locations_output=partial_lxu_cache_locations, ) assert ( len(self.lxu_cache_locations_list) < self.max_prefetch_depth ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()" - - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_cache_hash_size, - self.gather_uvm_cache_stats, - self.local_uvm_cache_stats, - ) - - self.lxu_cache_locations_list.append(lxu_cache_locations) + self.lxu_cache_locations_list.append(final_lxu_cache_locations) if self.gather_uvm_cache_stats: # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64). diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu index a0d08d0ee..f4448e81c 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu @@ -124,8 +124,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( const bool found = ::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx; if (found) { // mark it as recently accessed so we don't evict. + const bool already_locked = lru_state[cache_set][slot] == time_stamp; lru_state[cache_set][slot] = time_stamp; - if (lock_cache_line) { + // Don't lock the line one more time if we have locked it in the same + // batch (timestamp) + if (lock_cache_line && !already_locked) { lxu_cache_locking_counter[cache_set][slot] += 1; } } diff --git a/fbgemm_gpu/test/tbe/cache/cache_common.py b/fbgemm_gpu/test/tbe/cache/cache_common.py index e3d34649d..6920c2d23 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_common.py +++ b/fbgemm_gpu/test/tbe/cache/cache_common.py @@ -23,6 +23,7 @@ ) from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, + MultiPassPrefetchConfig, SplitTableBatchedEmbeddingBagsCodegen, ) @@ -32,9 +33,13 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_unavailable, optests + from test_utils import gpu_unavailable, optests, skipIfRocm else: - from fbgemm_gpu.test.test_utils import gpu_unavailable, optests # noqa: F401 + from fbgemm_gpu.test.test_utils import ( # noqa: F401 + gpu_unavailable, # noqa: F401 + optests, # noqa: F401 + skipIfRocm, # noqa: F401 + ) VERBOSITY: Verbosity = Verbosity.verbose @@ -95,6 +100,7 @@ def generate_cache_tbes( stochastic_rounding: bool = False, gather_uvm_cache_stats: bool = False, reporter_config: Optional[TestingStatsReporterConfig] = None, + multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, ) -> Tuple[ SplitTableBatchedEmbeddingBagsCodegen, SplitTableBatchedEmbeddingBagsCodegen, @@ -152,6 +158,7 @@ def generate_cache_tbes( cache_precision=weights_cache_precision, gather_uvm_cache_stats=gather_uvm_cache_stats, stats_reporter_config=reporter_config, + multipass_prefetch_config=multipass_prefetch_config, ) if use_int_weight: diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py index 283453b1d..d6e7a8b3a 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_test.py @@ -31,6 +31,7 @@ ) from 
fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, + MultiPassPrefetchConfig, SplitTableBatchedEmbeddingBagsCodegen, ) from hypothesis import assume, given, settings @@ -42,6 +43,7 @@ generate_cache_tbes, gpu_unavailable, optests, + skipIfRocm, TestingStatsReporter, TestingStatsReporterConfig, VERBOSITY, @@ -75,6 +77,7 @@ def _compute_grad_output_shape( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -172,6 +175,8 @@ def _test_cache_prefetch_pipeline( # noqa C901 stochastic_rounding: bool, gather_uvm_cache_stats: bool, mixed_B: bool = False, + mpp_n_passes: Optional[int] = None, + mpp_min_size: Optional[int] = None, ) -> None: """ test cache prefetch pipeline with prefetch_pipeline=True. @@ -186,6 +191,14 @@ def _test_cache_prefetch_pipeline( # noqa C901 assume(not mixed_B or T > 1) assert prefetch_location in ["before_fwd", "between_fwd_bwd"] reporter = TestingStatsReporterConfig(interval=2) + + mpp_conf: Optional[MultiPassPrefetchConfig] = None + if mpp_n_passes or mpp_min_size: + mpp_conf = MultiPassPrefetchConfig() + if mpp_n_passes: + mpp_conf = mpp_conf._replace(num_passes=mpp_n_passes) + if mpp_min_size: + mpp_conf = mpp_conf._replace(min_splitable_pass_size=mpp_min_size) cc, cc_ref, min_Es, sum_Ds = generate_cache_tbes( T, D, @@ -198,6 +211,7 @@ def _test_cache_prefetch_pipeline( # noqa C901 stochastic_rounding=stochastic_rounding, gather_uvm_cache_stats=gather_uvm_cache_stats, reporter_config=reporter, + multipass_prefetch_config=mpp_conf, ) iters = 5 vbe_num_ranks = random.randint(2, 5) @@ -410,6 +424,7 @@ def assert_event_not_exist(event_name: str) -> None: @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -421,6 +436,8 @@ def assert_event_not_exist(event_name: str) -> None: weights_cache_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), stochastic_rounding=st.booleans(), gather_uvm_cache_stats=st.booleans(), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 1024]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline( @@ -434,6 +451,7 @@ def test_cache_prefetch_pipeline( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -445,6 +463,8 @@ def test_cache_prefetch_pipeline( stochastic_rounding=st.booleans(), gather_uvm_cache_stats=st.booleans(), mixed_B=st.booleans(), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 1024]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline_stream_1( @@ -459,6 +479,7 @@ def test_cache_prefetch_pipeline_stream_1( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -470,6 +491,8 @@ def test_cache_prefetch_pipeline_stream_1( stochastic_rounding=st.booleans(), gather_uvm_cache_stats=st.booleans(), mixed_B=st.booleans(), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 1024]), ) 
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
     def test_cache_prefetch_pipeline_stream_2(
@@ -482,6 +505,59 @@ def test_cache_prefetch_pipeline_stream_2(
             prefetch_stream=torch.cuda.Stream(),
         )
 
+    @given(
+        S=st.sampled_from([0, 7, 100, 1024]),
+        mpp_n_passes=st.sampled_from([None, 1, 6, 12]),
+        mpp_min_size=st.sampled_from([None, 1, 5, 10, 128]),
+    )
+    @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_get_prefetch_passes(
+        self, S: int, mpp_n_passes: Optional[int], mpp_min_size: Optional[int]
+    ) -> None:
+        mpp_conf: Optional[MultiPassPrefetchConfig] = None
+        if mpp_n_passes or mpp_min_size:
+            mpp_conf = MultiPassPrefetchConfig()
+            if mpp_n_passes:
+                mpp_conf = mpp_conf._replace(num_passes=mpp_n_passes)
+            if mpp_min_size:
+                mpp_conf = mpp_conf._replace(min_splitable_pass_size=mpp_min_size)
+        input_tensor = torch.randn(S)
+        output_tensor = torch.randn(S)
+
+        ret = SplitTableBatchedEmbeddingBagsCodegen.get_prefetch_passes(
+            mpp_conf, input_tensor, output_tensor
+        )
+
+        if not mpp_conf:
+            self.assertEqual(len(ret), 1)
+            self.assertTrue(torch.equal(ret[0][0], input_tensor))
+            self.assertTrue(torch.equal(ret[0][1], output_tensor))
+            self.assertEqual(ret[0][2], 0)
+            return
+
+        # Make sure the number of passes does not exceed the configured value
+        self.assertGreaterEqual(mpp_conf.num_passes, len(ret))
+
+        # Make sure the passes have the right start offsets. Also make sure
+        # no pass goes below the configured min size (except for the
+        # last pass)
+        for idx, t in enumerate(ret):
+            i, o, s = t
+            if idx < len(ret) - 1:
+                self.assertGreaterEqual(i.numel(), mpp_conf.min_splitable_pass_size)
+            self.assertTrue(torch.equal(i, input_tensor[s : s + i.numel()]))
+            self.assertTrue(torch.equal(o, output_tensor[s : s + i.numel()]))
+
+        # Make sure the returned passes are both non-overlapping and complete. We do
+        # this by setting the tensors to all zeros and incrementing them when visited
+        input_tensor.zero_()
+        output_tensor.zero_()
+        for i, o, _ in ret:
+            i.add_(1)
+            o.add_(1)
+        self.assertTrue(torch.equal(torch.full_like(input_tensor, 1), input_tensor))
+        self.assertTrue(torch.equal(torch.full_like(output_tensor, 1), output_tensor))
+
     @unittest.skipIf(*gpu_unavailable)
     @given(
         L=st.integers(min_value=0, max_value=16),