From 578ab67ba5b1afe66ea9b8b5b854a17b1eb332cb Mon Sep 17 00:00:00 2001 From: Levy Zhao Date: Fri, 17 May 2024 07:31:40 -0700 Subject: [PATCH] Implement multi-pass prefetch for memory efficiency (#2566) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2566 ## Context Memory snapshot shows significant memory usage during prefetch kernels (specifically, `linearize_cache_index` and `lru_cache_populate`), which is estimated to be 6x of input size And unfortunately, due to they using dedicated stream, the memory cannot be reused by any other stream without performance penalty. So we need to lower down the peak prefetch memory usage as much as possible. ## MultiPass Prefetch (MPP) Multipass prefetch is basically a technique to sacrifice a bit of more running time for less peak memory during prefetch: We observed that intermediate memory usage for all functions during prefetch is `O(N)`, so we reduce the total prefetched index (`N`) for each pass to reduce the peak temporary usage. The following passes will recycle the memory used in the first pass so they won't further increase the memory footprint. **Benefit** With this being turned on, the peak memory usage will be dropped from `6 * input_size` to `(6 / M) * input_size`, where `M` is the total # of passes being configured. **Overhead** Overall, the bigger `M` we configured, the slower we'll be. But the overall overhead is acceptable. - **Efficiency regression**: Prefetch is taking longer because the process of cache lookup is being repeated for every duplicate index. In the past, they're deduped before being looked up, but now they might be look up multiple times if duplicate index are across different passes. - The regression is overall insignificant, as the major cost is the data movement between DDR and HBM. We'll always copy the data only once, even if they're duplicated across different passes. - The regression is likely hidden from the actual training performance, since prefetch happen in a separate stream. As long as it's not long enough to block sparse backward it's invisible. - **Spamming CUDA Launch Queue**: CUDA is allowing max # of 1024 pending kernels. CPU will go blocking if more are submitted. If a kernel is really small, we'll easily spam launch queue and greatly hurt QPS. We mitigate this via limit the minimal # of elements for a pass. ## What's in the patch? 1. Add multipass prefetch config to the interface of TBE. By default it's None for full backward compatibility 2. Modify the `lru_find_uncached` to make it idempotent -- if we tried to lock the same id multiple times in one single timestep (but multiple passes), we'll increase lock counter by only one. Reviewed By: sryap Differential Revision: D56908989 fbshipit-source-id: 94413acf18f4652687f594a3f69365fe9d68b54b --- ...lit_table_batched_embeddings_ops_common.py | 13 ++ ...t_table_batched_embeddings_ops_training.py | 206 ++++++++++++------ .../split_embeddings_cache/lru_cache_find.cu | 5 +- fbgemm_gpu/test/tbe/cache/cache_common.py | 11 +- fbgemm_gpu/test/tbe/cache/cache_test.py | 76 +++++++ 5 files changed, 245 insertions(+), 66 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 52d3d9005..b43d0bd89 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -36,6 +36,19 @@ class CacheAlgorithm(enum.Enum): LFU = 1 +class MultiPassPrefetchConfig(NamedTuple): + # Number of passes to split indices tensor into. Actual number of passes may + # be less if indices tensor is too small to split. + num_passes: int = 12 + + # The minimal number of element in indices tensor to be able to split into + # two passes. This is useful to prevent too many prefetch kernels spamming + # the CUDA launch queue. + # The default 6M indices means 6M * 8 * 6 = approx. 300MB of memory overhead + # per pass. + min_splitable_pass_size: int = 6 * 1024 * 1024 + + class PoolingMode(enum.IntEnum): SUM = 0 MEAN = 1 diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index ab2577b7c..a9651e94a 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -37,6 +37,7 @@ construct_cache_state, EmbeddingLocation, MAX_PREFETCH_DEPTH, + MultiPassPrefetchConfig, PoolingMode, RecordCacheMetrics, SplitState, @@ -389,6 +390,7 @@ def __init__( # noqa C901 # Embedding table names that are contained in this TBE. table_names: Optional[List[str]] = None, optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None, + multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, ) -> None: super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__() self.uuid = str(uuid.uuid4()) @@ -402,12 +404,33 @@ def __init__( # noqa C901 self.prefetch_pipeline: bool = prefetch_pipeline self.lock_cache_line: bool = self.prefetch_pipeline self.use_uniq_cache_locations_bwd: bool = self.prefetch_pipeline + self.multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = ( + multipass_prefetch_config + ) if record_cache_metrics is not None: self.record_cache_metrics = record_cache_metrics else: self.record_cache_metrics = RecordCacheMetrics(False, False) + if multipass_prefetch_config: + assert ( + prefetch_pipeline + ), "Multipass prefetch makes no sense in non-prefetch mode." + assert ( + cache_algorithm == CacheAlgorithm.LRU + ), "Multipass prefetch is only supported in LRU cache." + assert ( + multipass_prefetch_config.num_passes > 0 + ), f"num_passes must be positive, get {multipass_prefetch_config.num_passes}" + assert ( + multipass_prefetch_config.min_splitable_pass_size > 0 + ), f"min_splitable_pass_size must be positive, get {multipass_prefetch_config.min_splitable_pass_size}" + assert ( + not self.record_cache_metrics.record_cache_miss_counter + and not self.record_cache_metrics.record_tablewise_cache_miss + ), "Unique cache miss counters are not accurate in multipass prefetch and therefore not supported" + self.embedding_specs = embedding_specs (rows, dims, locations, compute_devices) = zip(*embedding_specs) T_ = len(self.embedding_specs) @@ -925,6 +948,43 @@ def _register_nonpersistent_buffers(self, prefix: str) -> None: persistent=False, ) + @staticmethod + def get_prefetch_passes( + multipass_prefetch_config: Optional[MultiPassPrefetchConfig], + input_tensor: Tensor, + output_tensor: Tensor, + ) -> List[Tuple[Tensor, Tensor, int]]: + """ + Given input (the indices to forward), return the segmentation for each pass + in the format of (input[start_idx:end_idx], output[start_idx:end_idx], start_idx). + + Caller should guarantee input and output are having the size on dimension 0 + The returned segments are guaranteed to completely and non-overlappingly cover the input tensor. + + In non-multipass-prefetch mode, it returns the input/output tensor itself. + """ + if multipass_prefetch_config is None: + return [(input_tensor, output_tensor, 0)] + mpp_config: MultiPassPrefetchConfig = multipass_prefetch_config + + N = input_tensor.size(0) + if N <= mpp_config.num_passes or mpp_config.num_passes == 1: + # One row per pass, just don't split + return [(input_tensor, output_tensor, 0)] + + pass_size: int = max( + (N + mpp_config.num_passes - 1) // mpp_config.num_passes, + mpp_config.min_splitable_pass_size, + ) + + return list( + zip( + torch.split(input_tensor, pass_size), + torch.split(output_tensor, pass_size), + range(0, N, pass_size), + ) + ) + def get_states(self, prefix: str) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: if not hasattr(self, f"{prefix}_physical_placements"): raise DoesNotHavePrefix() @@ -1194,7 +1254,12 @@ def forward( # noqa: C901 self._report_tbe_mem_usage() if len(self.timesteps_prefetched) == 0: - self._prefetch(indices, offsets, vbe_metadata) + # In forward, we don't enable multi-pass prefetch as we want the process + # to be as fast as possible and memory usage doesn't matter (will be recycled + # by dense fwd/bwd) + self._prefetch( + indices, offsets, vbe_metadata, multipass_prefetch_config=None + ) if len(self.timesteps_prefetched) > 0: self.timesteps_prefetched.pop(0) @@ -1509,7 +1574,12 @@ def prefetch( offsets, batch_size_per_feature_per_rank, ) - self._prefetch(indices, offsets, vbe_metadata) + self._prefetch( + indices, + offsets, + vbe_metadata, + multipass_prefetch_config=self.multipass_prefetch_config, + ) if forward_stream is not None: self._prefetch_tensors_record_stream(forward_stream) @@ -1518,6 +1588,7 @@ def _prefetch( indices: Tensor, offsets: Tensor, vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None, + multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, ) -> None: if not is_torchdynamo_compiling(): # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time @@ -1534,81 +1605,90 @@ def _prefetch( self.local_uvm_cache_stats.zero_() self._report_io_size_count("prefetch_input", indices) - linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( - self.cache_hash_size_cumsum, - indices, - offsets, - vbe_metadata.B_offsets if vbe_metadata is not None else None, - vbe_metadata.max_B if vbe_metadata is not None else -1, - ) - - if ( - self.record_cache_metrics.record_cache_miss_counter - or self.record_cache_metrics.record_tablewise_cache_miss + final_lxu_cache_locations = torch.empty_like(indices, dtype=torch.int32) + for ( + partial_indices, + partial_lxu_cache_locations, + base_offset, + ) in self.get_prefetch_passes( + multipass_prefetch_config, indices, final_lxu_cache_locations ): - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_cache_hash_size, - self.gather_uvm_cache_stats, - self.local_uvm_cache_stats, + linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices( + self.cache_hash_size_cumsum, + partial_indices, + offsets, + vbe_metadata.B_offsets if vbe_metadata is not None else None, + vbe_metadata.max_B if vbe_metadata is not None else -1, + base_offset, ) - if self.record_cache_metrics.record_cache_miss_counter: - self._update_cache_miss_counter( - lxu_cache_locations, linear_cache_indices + + if ( + self.record_cache_metrics.record_cache_miss_counter + or self.record_cache_metrics.record_tablewise_cache_miss + ): + lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( + linear_cache_indices, + self.lxu_cache_state, + self.total_cache_hash_size, + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, + ) + if self.record_cache_metrics.record_cache_miss_counter: + self._update_cache_miss_counter( + lxu_cache_locations, linear_cache_indices + ) + if self.record_cache_metrics.record_tablewise_cache_miss: + self._update_tablewise_cache_miss( + lxu_cache_locations, linear_cache_indices, offsets + ) + + if self.cache_algorithm == CacheAlgorithm.LRU: + torch.ops.fbgemm.lru_cache_populate( + self.weights_uvm, + self.cache_hash_size_cumsum, + self.total_cache_hash_size, + self.cache_index_table_map, + self.weights_offsets, + self.D_offsets, + linear_cache_indices, + self.lxu_cache_state, + self.lxu_cache_weights, + self.timestep, + self.lxu_state, + self.stochastic_rounding, + self.gather_uvm_cache_stats, + self.local_uvm_cache_stats, + self.lock_cache_line, + self.lxu_cache_locking_counter, ) - if self.record_cache_metrics.record_tablewise_cache_miss: - self._update_tablewise_cache_miss( - lxu_cache_locations, linear_cache_indices, offsets + elif self.cache_algorithm == CacheAlgorithm.LFU: + torch.ops.fbgemm.lfu_cache_populate( + self.weights_uvm, + self.cache_hash_size_cumsum, + self.total_cache_hash_size, + self.cache_index_table_map, + self.weights_offsets, + self.D_offsets, + linear_cache_indices, + self.lxu_cache_state, + self.lxu_cache_weights, + self.lxu_state, + self.stochastic_rounding, ) - if self.cache_algorithm == CacheAlgorithm.LRU: - torch.ops.fbgemm.lru_cache_populate( - self.weights_uvm, - self.cache_hash_size_cumsum, - self.total_cache_hash_size, - self.cache_index_table_map, - self.weights_offsets, - self.D_offsets, + torch.ops.fbgemm.lxu_cache_lookup( linear_cache_indices, self.lxu_cache_state, - self.lxu_cache_weights, - self.timestep, - self.lxu_state, - self.stochastic_rounding, + self.total_cache_hash_size, self.gather_uvm_cache_stats, self.local_uvm_cache_stats, - self.lock_cache_line, - self.lxu_cache_locking_counter, - ) - elif self.cache_algorithm == CacheAlgorithm.LFU: - torch.ops.fbgemm.lfu_cache_populate( - self.weights_uvm, - self.cache_hash_size_cumsum, - self.total_cache_hash_size, - self.cache_index_table_map, - self.weights_offsets, - self.D_offsets, - linear_cache_indices, - self.lxu_cache_state, - self.lxu_cache_weights, - self.lxu_state, - self.stochastic_rounding, + lxu_cache_locations_output=partial_lxu_cache_locations, ) assert ( len(self.lxu_cache_locations_list) < self.max_prefetch_depth ), f"self.lxu_cache_locations_list has grown to size: {len(self.lxu_cache_locations_list)}, this exceeds the maximum: {self.max_prefetch_depth}. This probably indicates an error in logic where prefetch() is being called more frequently than forward()" - - lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup( - linear_cache_indices, - self.lxu_cache_state, - self.total_cache_hash_size, - self.gather_uvm_cache_stats, - self.local_uvm_cache_stats, - ) - - self.lxu_cache_locations_list.append(lxu_cache_locations) + self.lxu_cache_locations_list.append(final_lxu_cache_locations) if self.gather_uvm_cache_stats: # Accumulate local_uvm_cache_stats (int32) into uvm_cache_stats (int64). diff --git a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu index a0d08d0ee..f4448e81c 100644 --- a/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu +++ b/fbgemm_gpu/src/split_embeddings_cache/lru_cache_find.cu @@ -124,8 +124,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( const bool found = ::__ldg((&lxu_cache_state[cache_set][0]) + slot) == idx; if (found) { // mark it as recently accessed so we don't evict. + const bool already_locked = lru_state[cache_set][slot] == time_stamp; lru_state[cache_set][slot] = time_stamp; - if (lock_cache_line) { + // Don't lock the line one more time if we have locked it in the same + // batch (timestamp) + if (lock_cache_line && !already_locked) { lxu_cache_locking_counter[cache_set][slot] += 1; } } diff --git a/fbgemm_gpu/test/tbe/cache/cache_common.py b/fbgemm_gpu/test/tbe/cache/cache_common.py index e3d34649d..6920c2d23 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_common.py +++ b/fbgemm_gpu/test/tbe/cache/cache_common.py @@ -23,6 +23,7 @@ ) from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, + MultiPassPrefetchConfig, SplitTableBatchedEmbeddingBagsCodegen, ) @@ -32,9 +33,13 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_unavailable, optests + from test_utils import gpu_unavailable, optests, skipIfRocm else: - from fbgemm_gpu.test.test_utils import gpu_unavailable, optests # noqa: F401 + from fbgemm_gpu.test.test_utils import ( # noqa: F401 + gpu_unavailable, # noqa: F401 + optests, # noqa: F401 + skipIfRocm, # noqa: F401 + ) VERBOSITY: Verbosity = Verbosity.verbose @@ -95,6 +100,7 @@ def generate_cache_tbes( stochastic_rounding: bool = False, gather_uvm_cache_stats: bool = False, reporter_config: Optional[TestingStatsReporterConfig] = None, + multipass_prefetch_config: Optional[MultiPassPrefetchConfig] = None, ) -> Tuple[ SplitTableBatchedEmbeddingBagsCodegen, SplitTableBatchedEmbeddingBagsCodegen, @@ -152,6 +158,7 @@ def generate_cache_tbes( cache_precision=weights_cache_precision, gather_uvm_cache_stats=gather_uvm_cache_stats, stats_reporter_config=reporter_config, + multipass_prefetch_config=multipass_prefetch_config, ) if use_int_weight: diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py index 283453b1d..d6e7a8b3a 100644 --- a/fbgemm_gpu/test/tbe/cache/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache/cache_test.py @@ -31,6 +31,7 @@ ) from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, + MultiPassPrefetchConfig, SplitTableBatchedEmbeddingBagsCodegen, ) from hypothesis import assume, given, settings @@ -42,6 +43,7 @@ generate_cache_tbes, gpu_unavailable, optests, + skipIfRocm, TestingStatsReporter, TestingStatsReporterConfig, VERBOSITY, @@ -75,6 +77,7 @@ def _compute_grad_output_shape( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -172,6 +175,8 @@ def _test_cache_prefetch_pipeline( # noqa C901 stochastic_rounding: bool, gather_uvm_cache_stats: bool, mixed_B: bool = False, + mpp_n_passes: Optional[int] = None, + mpp_min_size: Optional[int] = None, ) -> None: """ test cache prefetch pipeline with prefetch_pipeline=True. @@ -186,6 +191,14 @@ def _test_cache_prefetch_pipeline( # noqa C901 assume(not mixed_B or T > 1) assert prefetch_location in ["before_fwd", "between_fwd_bwd"] reporter = TestingStatsReporterConfig(interval=2) + + mpp_conf: Optional[MultiPassPrefetchConfig] = None + if mpp_n_passes or mpp_min_size: + mpp_conf = MultiPassPrefetchConfig() + if mpp_n_passes: + mpp_conf = mpp_conf._replace(num_passes=mpp_n_passes) + if mpp_min_size: + mpp_conf = mpp_conf._replace(min_splitable_pass_size=mpp_min_size) cc, cc_ref, min_Es, sum_Ds = generate_cache_tbes( T, D, @@ -198,6 +211,7 @@ def _test_cache_prefetch_pipeline( # noqa C901 stochastic_rounding=stochastic_rounding, gather_uvm_cache_stats=gather_uvm_cache_stats, reporter_config=reporter, + multipass_prefetch_config=mpp_conf, ) iters = 5 vbe_num_ranks = random.randint(2, 5) @@ -410,6 +424,7 @@ def assert_event_not_exist(event_name: str) -> None: @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -421,6 +436,8 @@ def assert_event_not_exist(event_name: str) -> None: weights_cache_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]), stochastic_rounding=st.booleans(), gather_uvm_cache_stats=st.booleans(), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 1024]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline( @@ -434,6 +451,7 @@ def test_cache_prefetch_pipeline( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -445,6 +463,8 @@ def test_cache_prefetch_pipeline( stochastic_rounding=st.booleans(), gather_uvm_cache_stats=st.booleans(), mixed_B=st.booleans(), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 1024]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline_stream_1( @@ -459,6 +479,7 @@ def test_cache_prefetch_pipeline_stream_1( @optests.dontGenerateOpCheckTests("Serial OOM") @unittest.skipIf(*gpu_unavailable) + @skipIfRocm @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -470,6 +491,8 @@ def test_cache_prefetch_pipeline_stream_1( stochastic_rounding=st.booleans(), gather_uvm_cache_stats=st.booleans(), mixed_B=st.booleans(), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 1024]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_prefetch_pipeline_stream_2( @@ -482,6 +505,59 @@ def test_cache_prefetch_pipeline_stream_2( prefetch_stream=torch.cuda.Stream(), ) + @given( + S=st.sampled_from([0, 7, 100, 1024]), + mpp_n_passes=st.sampled_from([None, 1, 6, 12]), + mpp_min_size=st.sampled_from([None, 1, 5, 10, 128]), + ) + @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_get_prefetch_passes( + self, S: int, mpp_n_passes: Optional[int], mpp_min_size: Optional[int] + ) -> None: + mpp_conf: Optional[MultiPassPrefetchConfig] = None + if mpp_n_passes or mpp_min_size: + mpp_conf = MultiPassPrefetchConfig() + if mpp_n_passes: + mpp_conf = mpp_conf._replace(num_passes=mpp_n_passes) + if mpp_min_size: + mpp_conf = mpp_conf._replace(min_splitable_pass_size=mpp_min_size) + input_tensor = torch.randn(S) + output_tensor = torch.randn(S) + + ret = SplitTableBatchedEmbeddingBagsCodegen.get_prefetch_passes( + mpp_conf, input_tensor, output_tensor + ) + + if not mpp_conf: + self.assertEqual(len(ret), 1) + self.assertTrue(torch.equal(ret[0][0], input_tensor)) + self.assertTrue(torch.equal(ret[0][1], output_tensor)) + self.assertEqual(ret[0][2], 0) + return + + # Make sure the max passes is not exceeding the configured value + self.assertGreaterEqual(mpp_conf.num_passes, len(ret)) + + # Make sure the passes are having the right start offset. Also make sure + # every pass would not go below the configured min size (except for the + # last pass) + for idx, t in enumerate(ret): + i, o, s = t + if idx < len(ret) - 1: + self.assertGreaterEqual(i.numel(), mpp_conf.min_splitable_pass_size) + self.assertTrue(torch.equal(i, input_tensor[s : s + i.numel()])) + self.assertTrue(torch.equal(o, output_tensor[s : s + i.numel()])) + + # Make sure the returned passes are both non-overlapping and complete. We do + # this by settong the tensor to all zero, and increment them when visited + input_tensor.zero_() + output_tensor.zero_() + for i, o, _ in ret: + i.add_(1) + o.add_(1) + self.assertTrue(torch.equal(torch.full_like(input_tensor, 1), input_tensor)) + self.assertTrue(torch.equal(torch.full_like(output_tensor, 1), output_tensor)) + @unittest.skipIf(*gpu_unavailable) @given( L=st.integers(min_value=0, max_value=16),