From 78c60ce1ead6f08d90eb2241d145561619042fa3 Mon Sep 17 00:00:00 2001 From: Yu Guo Date: Wed, 26 Jul 2023 18:07:15 -0700 Subject: [PATCH] TBE UVM cache prefetch pipeline (#1893) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1893 This diff is to enable cache prefetch pipeline of TBE, so that prefetch of batch_{i+1} can overlap with forward/backward of batch_i. As the cache can be evicted by prefetch and the weights can be updated by the backward, we need to carefully protect a few scenarios that result in cache invalidation. ## 1. prevent immature cache eviction: cache gets evicted while it is being used by forward pass Since prefetch can overlap with forward/backward pass, it is possible that prefetch tries to evict cache but the cache is being used by forward/backward pass. The fix is to use the `lxu_cache_locking_counter` in D46172802/https://github.com/pytorch/FBGEMM/pull/1883 to check whether a cache slot is in use or not when an eviction is attempted. ## 2. prevent dirty cache: weight is being updated while it is loading to cache If the prefetch overlaps with TBE backward pass, the backward may write to uvm (idx not in cache) and at the same time prefetch (idx is inserted to cache) loads the weight from uvm to cache. We sync the streams to avoid TBE backward pass overlapping with prefetch. The backward of the rest of the module can still overlap with prefetch of TBE. The stream sync looks like: ``` # backward(batch_i) waits for prefetch(batch_{i+1}) backward pre_hook: cur_stream.wait_stream(prefetch_stream) # backward(batch_i) TBE.backward() # prefetch(batch_{i+2}) waits for backward(batch_i) backward hook: prefetch_stream.wait_stream(cur_stream) ``` ## 3. prevent cache inconsistency: weight get updated after it is loaded to cache With pipeline, in the case that the same index is not inserted into cache in batch_i, but it is inserted in batch_{i+1}, the cache can be invalid in the sense that the cached weight for this index does not have the backward update of batch_i. Example of the issue is as follows: idx is in batch_i, batch_{i+1} prefetch(batch_i) - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss) forward(batch_i) prefetch(batch_{i+1}) - insert idx into cache, cache is loaded from host memory backward(batch_i) - cache_locations_batch_i of idx is -1, the host memory is updated forward(batch_{i+1}) - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated. The fix to this cache invalidation is to update the cache_locations_batch_i before backward of batch_i,so that the cache gets updated correctly by the backward pass of TBE. Reviewed By: sryap Differential Revision: D47418650 fbshipit-source-id: 144855513814ca9eb4a181c46c318d5cb70efb4d --- ...t_table_batched_embeddings_ops_training.py | 185 ++++++++++++- .../split_table_batched_embeddings_test.py | 254 ++++++++++++++++-- 2 files changed, 410 insertions(+), 29 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index ed40eed21..b5a007ccb 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -279,6 +279,7 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module): record_cache_metrics: RecordCacheMetrics uvm_cache_stats: torch.Tensor local_uvm_cache_stats: torch.Tensor + linear_cache_indices_list: List[Tensor] def __init__( # noqa C901 self, @@ -323,6 +324,10 @@ def __init__( # noqa C901 bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, uvm_non_rowwise_momentum: bool = False, # place non-rowwise momentum on UVM use_experimental_tbe: bool = False, # set to True to use TBE v2 (only support NVIDIA GPUs) + # set to True to enable prefetch pipeline, currently only supports LRU cache policy. + # If a separate stream is used for prefetch, the optional forward_stream arg of prefetch function + # should be set. + prefetch_pipeline: bool = False, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() @@ -330,6 +335,11 @@ def __init__( # noqa C901 self.bounds_check_mode_int: int = bounds_check_mode.value self.weights_precision = weights_precision self.output_dtype: int = output_dtype.as_int() + assert ( + not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU + ), "Only LRU cache policy supports prefetch_pipeline." + self.prefetch_pipeline: bool = prefetch_pipeline + self.lock_cache_line: bool = self.prefetch_pipeline if record_cache_metrics is not None: self.record_cache_metrics = record_cache_metrics @@ -919,10 +929,10 @@ def forward( # noqa: C901 ) self.step += 1 if len(self.timesteps_prefetched) == 0: - self.prefetch(indices, offsets) + self._prefetch(indices, offsets) self.timesteps_prefetched.pop(0) - lxu_cache_locations = ( + self.lxu_cache_locations = ( self.lxu_cache_locations_empty if len(self.lxu_cache_locations_list) == 0 else self.lxu_cache_locations_list.pop(0) @@ -945,7 +955,7 @@ def forward( # noqa: C901 pooling_mode=self.pooling_mode, indice_weights=per_sample_weights, feature_requires_grad=feature_requires_grad, - lxu_cache_locations=lxu_cache_locations, + lxu_cache_locations=self.lxu_cache_locations, output_dtype=self.output_dtype, vbe_metadata=vbe_metadata, is_experimental=self.is_experimental, @@ -1114,7 +1124,23 @@ def print_uvm_cache_stats(self) -> None: f"unique misses / requested indices: {uvm_cache_stats[3]/uvm_cache_stats[1]}\n" ) - def prefetch(self, indices: Tensor, offsets: Tensor) -> None: + def prefetch( + self, + indices: Tensor, + offsets: Tensor, + forward_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + if self.prefetch_stream is None and forward_stream is not None: + self.prefetch_stream = torch.cuda.current_stream() + assert ( + self.prefetch_stream != forward_stream + ), "prefetch_stream and forward_stream should not be the same stream" + + self._prefetch(indices, offsets) + if forward_stream is not None: + self._prefetch_tensors_record_stream(forward_stream) + + def _prefetch(self, indices: Tensor, offsets: Tensor) -> None: self.timestep += 1 self.timesteps_prefetched.append(self.timestep) if not self.lxu_cache_weights.numel(): @@ -1163,6 +1189,8 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None: self.stochastic_rounding, self.gather_uvm_cache_stats, self.local_uvm_cache_stats, + self.lock_cache_line, + self.lxu_cache_locking_counter, ) elif self.cache_algorithm == CacheAlgorithm.LFU: torch.ops.fbgemm.lfu_cache_populate( @@ -1182,15 +1210,19 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None: assert ( len(self.lxu_cache_locations_list) < self.max_prefetch_depth ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()" - self.lxu_cache_locations_list.append( - torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_cache_hash_size, - self.gather_uvm_cache_stats, - self.local_uvm_cache_stats, - ) + + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_cache_hash_size, + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, ) + + self.lxu_cache_locations_list.append(lxu_cache_locations) + if self.prefetch_pipeline: + self.linear_cache_indices_list.append(linear_cache_indices) + if self.gather_uvm_cache_stats: # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64). # We may wanna do this accumulation atomically, but as it's only for monitoring, @@ -1200,6 +1232,20 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None: ) self.local_uvm_cache_stats.zero_() + def _prefetch_tensors_record_stream( + self, forward_stream: torch.cuda.Stream + ) -> None: + # Record the tensors created by prefetch stream and consumed by forward/backward + # to the forward stream. In PyTorch, each backward CUDA op runs on the same + # stream that was used for its corresponding forward op. + + for t in self.lxu_cache_locations_list: + # pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream` + t.record_stream(forward_stream) + for t in self.linear_cache_indices_list: + # pyre-fixme[6]: For 1st param expected `_C.Stream` but got `streams.Stream` + t.record_stream(forward_stream) + def _update_cache_miss_counter( self, lxu_cache_locations: Tensor, @@ -1521,6 +1567,9 @@ def _apply_cache_state( self.lxu_cache_locations_empty = torch.empty( 0, device=self.current_device, dtype=torch.int32 ).fill_(-1) + self.lxu_cache_locations = self.lxu_cache_locations_empty + self.prefetch_stream: Optional[torch.cuda.Stream] = None + self.linear_cache_indices_list = [] self._init_uvm_cache_stats() @@ -1561,6 +1610,7 @@ def _apply_cache_state( torch.tensor([0, 0], dtype=torch.int64), persistent=False, ) + self._init_uvm_cache_counter(cache_sets, persistent=False) return assert cache_load_factor > 0 @@ -1648,6 +1698,18 @@ def _apply_cache_state( "cache_miss_counter", torch.tensor([0, 0], device=self.current_device, dtype=torch.int64), ) + self._init_uvm_cache_counter(cache_sets, persistent=True) + if self.prefetch_pipeline: + # using the placeholder_autograd_tensor to make sure + # the hook is executed after the backward pass + # not using register_module_full_backward_hook + # due to https://github.com/pytorch/pytorch/issues/100528 + self.placeholder_autograd_tensor.register_hook( + self._sync_stream_post_backward + ) + self.register_full_backward_pre_hook( + self._update_cache_counter_and_locations + ) if cache_algorithm not in (CacheAlgorithm.LFU, CacheAlgorithm.LRU): raise ValueError( @@ -1655,6 +1717,105 @@ def _apply_cache_state( f"or {CacheAlgorithm.LFU}" ) + def _sync_stream_post_backward( + self, + grad: Tensor, + ) -> None: + """ + backward hook function when prefetch_pipeline is enabled. + + With the pipeline, prefetch(batch_{i+2}) may overlap with backward(batch_{i}). + There is race condition that backward(batch_i) writes to UVM memory and + at the same time prefetch(batch_{i+2}) loads UVM memory to cache. This stream sync forces + backward(batch_i) to finish before prefetch(batch_{i+2}). + """ + if self.prefetch_stream is not None: + self.prefetch_stream.wait_stream(torch.cuda.current_stream()) + + def _update_cache_counter_and_locations( + self, + module: nn.Module, + grad_input: Union[Tuple[Tensor, ...], Tensor], + ) -> None: + """ + Backward prehook function when prefetch_pipeline is enabled. + + This function does 3 things: + 1. backward stream waits for prefetch stream to finish. + Otherwise the prefetch(batch_{i+1}) might overlap with backward(batch_i). + If an idx is not in cache in batch_i, but it is being inserted in batch_{i+1}, + there is race condition that backward(batch_i) writes to UVM memory and + at the same time prefetch(batch_{i+1}) loads UVM memory to cache. + + 2. decrement the lxu_cache_locking_counter to indicate the current batch is finished. + The lxu_cache_locking_counter is updated in both prefetch and TBE backward. + As there is no overlap between prefetch and backward, we can decrement either before or + after backward. It's better to decrement before lxu_cache_locations gets updated. + + 3. update lxu_cache_locations to address the cache inconsistency issue. + In the case that the same index is not inserted into cache in batch_i, + but it is inserted in batch_{i+1}, the cache can be invalid in + the sense that the cached weight for this index does not have the + backward update of batch_i. + + Example of the issue is as follows: + idx is in batch_i, batch_{i+1} + prefetch(batch_i) + - failed to insert idx into cache, cache_locations_batch_i of idx is -1 (cache miss) + forward(batch_i) + prefetch(batch_{i+1}) + - insert idx into cache, cache is loaded from host memory + backward(batch_i) + - cache_locations_batch_i of idx is -1, the host memory is updated + forward(batch_{i+1}) + - OUTPUT IS WRONG. the weight for idx is fetched from cache, but the cache is outdated. + + The fix to this cache inconsistency is to update the cache_locations_batch_i before backward of batch_i, + so that the cache gets updated correctly by the backward pass of TBE. + """ + + if self.prefetch_stream is not None: + # need to wait for the prefetch of next batch, + # so that cache states are valid + torch.cuda.current_stream().wait_stream(self.prefetch_stream) + + torch.ops.fbgemm.lxu_cache_locking_counter_decrement( + self.lxu_cache_locking_counter, + self.lxu_cache_locations, + ) + + linear_cache_indices = self.linear_cache_indices_list.pop(0) + lxu_cache_locations_new = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_cache_hash_size, + False, # not collecting cache stats + self.local_uvm_cache_stats, + ) + # self.lxu_cache_locations is updated inplace + torch.ops.fbgemm.lxu_cache_locations_update( + self.lxu_cache_locations, + lxu_cache_locations_new, + ) + + def _init_uvm_cache_counter(self, cache_sets: int, persistent: bool) -> None: + if self.prefetch_pipeline and persistent: + self.register_buffer( + "lxu_cache_locking_counter", + torch.zeros( + cache_sets, + DEFAULT_ASSOC, + device=self.current_device, + dtype=torch.int32, + ), + ) + else: + self.register_buffer( + "lxu_cache_locking_counter", + torch.zeros([0, 0], dtype=torch.int32, device=self.current_device), + persistent=persistent, + ) + def _init_uvm_cache_stats(self) -> None: if not self.gather_uvm_cache_stats: # If uvm_cache_stats is not enabled, register stub entries via buffer to state_dict for TorchScript to JIT properly. diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index fe0f72cb5..25aa0ad78 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -3011,18 +3011,7 @@ def test_backward_adagrad_fp32_pmNONE( # noqa C901 output_dtype, ) - @unittest.skipIf(*gpu_unavailable) - @given( - T=st.integers(min_value=1, max_value=5), - D=st.integers(min_value=2, max_value=256), - B=st.integers(min_value=1, max_value=128), - log_E=st.integers(min_value=3, max_value=5), - L=st.integers(min_value=1, max_value=20), - mixed=st.booleans(), - cache_algorithm=st.sampled_from(CacheAlgorithm), - ) - @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_cache_pipeline( + def _generate_cache_tbes( self, T: int, D: int, @@ -3030,9 +3019,16 @@ def test_cache_pipeline( log_E: int, L: int, mixed: bool, - cache_algorithm: CacheAlgorithm, - ) -> None: - iters = 3 + cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, + prefetch_pipeline: bool = False, + use_int_weight: bool = False, + ) -> Tuple[ + SplitTableBatchedEmbeddingBagsCodegen, + SplitTableBatchedEmbeddingBagsCodegen, + int, + int, + ]: + lr = 1.0 if use_int_weight else 0.02 E = int(10**log_E) D = D * 4 if not mixed: @@ -3061,11 +3057,30 @@ def test_cache_pipeline( ) for (E, D) in zip(Es, Ds) ], + stochastic_rounding=False, + prefetch_pipeline=False, + learning_rate=lr, ) cc = SplitTableBatchedEmbeddingBagsCodegen( [(E, D, M, ComputeDevice.CUDA) for (E, D, M) in zip(Es, Ds, managed)], cache_algorithm=cache_algorithm, + stochastic_rounding=False, + prefetch_pipeline=prefetch_pipeline, + learning_rate=lr, ) + + if use_int_weight: + min_val = -20 + max_val = +20 + for param in cc_ref.split_embedding_weights(): + p = torch.randint( + int(min_val), + int(max_val) + 1, + size=param.shape, + device=param.device, + ) + param.data.copy_(p) + for t in range(T): self.assertEqual( cc.split_embedding_weights()[t].size(), @@ -3075,8 +3090,35 @@ def test_cache_pipeline( cc_ref.split_embedding_weights()[t] ) - requests = generate_requests(iters, B, T, L, min(Es), reuse=0.1) - grad_output = torch.randn(B, sum(Ds)).cuda() + return (cc, cc_ref, min(Es), sum(Ds)) + + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=1, max_value=20), + mixed=st.booleans(), + cache_algorithm=st.sampled_from(CacheAlgorithm), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_cache_pipeline( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + mixed: bool, + cache_algorithm: CacheAlgorithm, + ) -> None: + cc, cc_ref, min_Es, sum_Ds = self._generate_cache_tbes( + T, D, B, log_E, L, mixed, cache_algorithm + ) + iters = 3 + requests = generate_requests(iters, B, T, L, min_Es, reuse=0.1) + grad_output = torch.randn(B, sum_Ds).cuda() for indices, offsets, _ in requests: output = cc(indices, offsets) @@ -3090,6 +3132,184 @@ def test_cache_pipeline( cc.split_embedding_weights()[t], cc_ref.split_embedding_weights()[t] ) + def _test_cache_prefetch_pipeline( # noqa C901 + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + mixed: bool, + prefetch_location: str, + prefetch_stream: Optional[torch.cuda.Stream], + ) -> None: + """ + test cache prefetch pipeline with prefetch_pipeline=True. + prefetch_location can be "before_fwd" or "between_fwd_bwd", + where the TBE prefetch(batch_{i+1}) is called before forward(batch_i) + or in between of forward(batch_i) and backward(batch_i), respectively. + If prefetch_stream is not None, the TBE prefetch function will use this stream. + In addition, we make the TBE weights initialized as integer values, learning_rate + as integer value, and gradients as integer values so that the test is more stable. + """ + + assert prefetch_location in ["before_fwd", "between_fwd_bwd"] + cc, cc_ref, min_Es, sum_Ds = self._generate_cache_tbes( + T, D, B, log_E, L, mixed, CacheAlgorithm.LRU, True, True + ) + iters = 5 + requests = generate_requests(iters, B, T, L, min_Es, reuse=0.1) + grad_output = ( + torch.randint( + low=-10, + high=10, + size=(B, sum_Ds), + ) + .float() + .cuda() + ) + torch.cuda.synchronize() # make sure TBEs and inputs are ready + self.assertTrue(torch.all(cc.lxu_cache_locking_counter == 0)) + + cur_stream: torch.cuda.Stream = torch.cuda.current_stream() + + req_iter = iter(requests) + batch_i = next(req_iter) + batch_ip1 = None + output, output_ref = None, None + + def _prefetch( + cc: SplitTableBatchedEmbeddingBagsCodegen, + batch: Optional[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + ) -> None: + if not batch: + return + context_stream = prefetch_stream if prefetch_stream else cur_stream + stream = cur_stream if prefetch_stream else None + indices, offsets, _ = batch + with torch.cuda.stream(context_stream): + cc.prefetch(indices, offsets, stream) + + _prefetch(cc, batch_i) + while batch_i: + indices, offsets, _ = batch_i + batch_ip1 = next(req_iter, None) + if prefetch_stream: + cur_stream.wait_stream(prefetch_stream) + if prefetch_location == "before_fwd": + _prefetch(cc, batch_ip1) + output = cc(indices, offsets) + if prefetch_location == "between_fwd_bwd": + _prefetch(cc, batch_ip1) + output.backward(grad_output) + batch_i = batch_ip1 + batch_ip1 = None + cc.flush() + + for indices, offsets, _ in requests: + output_ref = cc_ref(indices, offsets) + output_ref.backward(grad_output) + + for t in range(T): + torch.testing.assert_close( + cc.split_embedding_weights()[t], cc_ref.split_embedding_weights()[t] + ) + + torch.testing.assert_close(output, output_ref) + self.assertTrue(torch.all(cc.lxu_cache_locking_counter == 0)) + + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=1, max_value=20), + mixed=st.booleans(), + prefetch_location=st.sampled_from(["before_fwd", "between_fwd_bwd"]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_cache_prefetch_pipeline( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + mixed: bool, + prefetch_location: str, + ) -> None: + self._test_cache_prefetch_pipeline( + T, + D, + B, + log_E, + L, + mixed, + prefetch_location, + prefetch_stream=None, + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=1, max_value=20), + mixed=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_cache_prefetch_pipeline_stream_1( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + mixed: bool, + ) -> None: + self._test_cache_prefetch_pipeline( + T, + D, + B, + log_E, + L, + mixed, + prefetch_location="before_fwd", + prefetch_stream=torch.cuda.Stream(), + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.integers(min_value=2, max_value=256), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=1, max_value=20), + mixed=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_cache_prefetch_pipeline_stream_2( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + mixed: bool, + ) -> None: + self._test_cache_prefetch_pipeline( + T, + D, + B, + log_E, + L, + mixed, + prefetch_location="between_fwd_bwd", + prefetch_stream=torch.cuda.Stream(), + ) + def execute_backward_optimizers_( # noqa C901 self, T: int,