From 07c3f6bd9a5a5aee9ce96f624400f290e2dafec3 Mon Sep 17 00:00:00 2001
From: generatedunixname89002005307016
Date: Wed, 12 Jul 2023 13:47:46 -0700
Subject: [PATCH] suppress errors in `deeplearning/fbgemm/fbgemm_gpu`

Differential Revision: D47398494

fbshipit-source-id: bc4ee6710be8309f1235f49383f62f1b85aef573
---
 fbgemm_gpu/bench/sparse_ops_benchmark.py      |   2 -
 .../bench/split_embeddings_cache_benchmark.py |   4 -
 ...plit_table_batched_embeddings_benchmark.py |  18 +--
 ..._table_batched_embeddings_ops_inference.py |  12 --
 ...t_table_batched_embeddings_ops_training.py | 123 ------------------
 fbgemm_gpu/test/layout_transform_ops_test.py  |   4 -
 .../split_table_batched_embeddings_test.py    |  18 ---
 7 files changed, 2 insertions(+), 179 deletions(-)

diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py
index 3e9e9ca5e..f97adc0ba 100644
--- a/fbgemm_gpu/bench/sparse_ops_benchmark.py
+++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -99,7 +99,6 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
     else:
         raise RuntimeError(f"Does not support data type {input_precision}")

-    # pyre-fixme[16]: Module `cuda` has no attribute `IntTensor`.
     indices = torch.cuda.IntTensor(gen_inverse_index(unique_batch_size, batch_size))
     input = torch.rand(unique_batch_size, row_size, dtype=dtype, device="cuda")
@@ -260,7 +259,6 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
     offset_indices_group = []
     indices_group = []
     for i in range(num_groups):
-        # pyre-fixme[16]: Module `cuda` has no attribute `IntTensor`.
         indices = torch.cuda.IntTensor(gen_inverse_index(unique_batch_size, batch_size))
         if sort_indices:
             indices, _ = indices.sort()
diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
index 61d9009d9..e7b836aa4 100644
--- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
+++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
@@ -416,10 +414,8 @@ def replay_populate(linear_indices: Tensor) -> None:
     total_rows = 0
     for request in requests:
-        # pyre-ignore
         prev = replay_cc.lxu_cache_state.clone().detach()
         replay_populate(request)
-        # pyre-ignore
         after = replay_cc.lxu_cache_state.clone().detach()

         diff = after - prev
@@ -535,10 +533,8 @@ def replay_populate(linear_indices: Tensor) -> None:
     total_rows = 0
     for request in requests:
-        # pyre-ignore
         prev = replay_cc.lxu_cache_state.clone().detach()
         replay_populate(request)
-        # pyre-ignore
         after = replay_cc.lxu_cache_state.clone().detach()

         diff = after - prev
diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index 66ca1520e..a1a0852a5 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -693,17 +693,10 @@ def cache(  # noqa C901
     exchanged_cache_lines = []
     NOT_FOUND = -1
     for indices, offsets, _ in requests:
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self,
-        #  Variable[torch._TTensor (bound to Tensor)])], Variable[torch._TTensor (bound
-        #  to Tensor)]], Tensor], Tensor, torch.nn.Module]` is not a function.
         old_lxu_cache_state = emb.lxu_cache_state.clone()
         emb.prefetch(indices.long(), offsets.long())
         exchanged_cache_lines.append(
-            # pyre-fixme[16]: `bool` has no attribute `sum`.
-            (emb.lxu_cache_state != old_lxu_cache_state)
-            .sum()
-            .item()
+            (emb.lxu_cache_state != old_lxu_cache_state).sum().item()
         )
         cache_misses.append((emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item())
         emb.forward(indices.long(), offsets.long())
@@ -2064,17 +2057,10 @@ def nbit_cache(  # noqa C901
     emb.reset_uvm_cache_stats()
     for indices, offsets, _ in requests:
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self,
-        #  Variable[torch._TTensor (bound to Tensor)])], Variable[torch._TTensor (bound
-        #  to Tensor)]], Tensor], Tensor, torch.nn.Module]` is not a function.
         old_lxu_cache_state = emb.lxu_cache_state.clone()
         emb.prefetch(indices, offsets)
         exchanged_cache_lines.append(
-            # pyre-fixme[16]: `bool` has no attribute `sum`.
-            (emb.lxu_cache_state != old_lxu_cache_state)
-            .sum()
-            .item()
+            (emb.lxu_cache_state != old_lxu_cache_state).sum().item()
         )
         cache_misses.append(
             (emb.lxu_cache_locations_list.top() == NOT_FOUND).sum().item()
         )
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 2cf630115..cc3b0aab7 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -490,9 +490,6 @@ def print_uvm_cache_stats(self) -> None:
     def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         self.timestep_counter.increment()
         self.timestep_prefetch_size.increment()
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
-        #  int], Tensor], Tensor, nn.Module]` is not a function.
         if not self.lxu_cache_weights.numel():
             return
@@ -669,9 +666,6 @@ def _update_tablewise_cache_miss(
         CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
         CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)

-        # pyre-ignore[6]:
-        #  Incompatible parameter type [6]: Expected `typing.Sized` for 1st
-        #  positional only parameter to call `len` but got `typing.Union[Tensor, nn.Module]`.
         num_tables = len(self.cache_hash_size_cumsum) - 1
         num_offsets_per_table = (len(offsets) - 1) // num_tables
         cache_missed_locations = torch.where(
@@ -1128,9 +1122,6 @@ def _apply_cache_state(
             self.reset_uvm_cache_stats()

     def reset_cache_states(self) -> None:
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
-        #  int], Tensor], Tensor, nn.Module]` is not a function.
         if not self.lxu_cache_weights.numel():
             return
         self.lxu_cache_state.fill_(-1)
@@ -1500,9 +1491,6 @@ def embedding_inplace_update_internal(
         )

         lxu_cache_locations = None
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
-        #  int], Tensor], Tensor, nn.Module]` is not a function.
         if self.lxu_cache_weights.numel() > 0:
             linear_cache_indices = (
                 torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 17db584bb..1fb830a52 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -803,8 +803,6 @@ def get_cache_miss_counter(self) -> Tensor:
         # The first one is cache_miss_forward_count which records the total number of forwards which has at least one cache miss
         # The second one is the unique_cache_miss_count which records to total number of unique (dedup) cache misses

-        # pyre-fixme[7]: Expected `Tensor` but got `typing.Union[Tensor,
-        #  nn.Module]`.
         return self.cache_miss_counter

     @torch.jit.export
@@ -930,23 +928,11 @@ def forward(  # noqa: C901
         )
         common_args = invokers.lookup_args.CommonArgs(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
-            # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `Union[Tensor,
-            #  nn.Module]`.
             dev_weights=self.weights_dev,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `Union[Tensor,
-            #  nn.Module]`.
             host_weights=self.weights_host,
-            # pyre-fixme[6]: Expected `Tensor` for 4th param but got `Union[Tensor,
-            #  nn.Module]`.
             uvm_weights=self.weights_uvm,
-            # pyre-fixme[6]: Expected `Tensor` for 5th param but got `Union[Tensor,
-            #  nn.Module]`.
             lxu_cache_weights=self.lxu_cache_weights,
-            # pyre-fixme[6]: Expected `Tensor` for 6th param but got `Union[Tensor,
-            #  nn.Module]`.
             weights_placements=self.weights_placements,
-            # pyre-fixme[6]: Expected `Tensor` for 7th param but got `Union[Tensor,
-            #  nn.Module]`.
             weights_offsets=self.weights_offsets,
             D_offsets=self.D_offsets,
             total_D=self.total_D,
@@ -976,20 +962,10 @@ def forward(  # noqa: C901
             return invokers.lookup_sgd.invoke(common_args, self.optimizer_args)

         momentum1 = invokers.lookup_args.Momentum(
-            # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[Tensor,
-            #  nn.Module]`.
             dev=self.momentum1_dev,
-            # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `Union[Tensor,
-            #  nn.Module]`.
             host=self.momentum1_host,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `Union[Tensor,
-            #  nn.Module]`.
             uvm=self.momentum1_uvm,
-            # pyre-fixme[6]: Expected `Tensor` for 4th param but got `Union[Tensor,
-            #  nn.Module]`.
             offsets=self.momentum1_offsets,
-            # pyre-fixme[6]: Expected `Tensor` for 5th param but got `Union[Tensor,
-            #  nn.Module]`.
             placements=self.momentum1_placements,
         )
@@ -1003,20 +979,10 @@ def forward(  # noqa: C901
             )

         momentum2 = invokers.lookup_args.Momentum(
-            # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[Tensor,
-            #  nn.Module]`.
             dev=self.momentum2_dev,
-            # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `Union[Tensor,
-            #  nn.Module]`.
             host=self.momentum2_host,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `Union[Tensor,
-            #  nn.Module]`.
             uvm=self.momentum2_uvm,
-            # pyre-fixme[6]: Expected `Tensor` for 4th param but got `Union[Tensor,
-            #  nn.Module]`.
             offsets=self.momentum2_offsets,
-            # pyre-fixme[6]: Expected `Tensor` for 5th param but got `Union[Tensor,
-            #  nn.Module]`.
             placements=self.momentum2_placements,
         )
         # Ensure iter is always on CPU so the increment doesn't synchronize.
@@ -1075,37 +1041,17 @@ def forward(  # noqa: C901
             )

         prev_iter = invokers.lookup_args.Momentum(
-            # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[Tensor,
-            #  nn.Module]`.
             dev=self.prev_iter_dev,
-            # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `Union[Tensor,
-            #  nn.Module]`.
             host=self.prev_iter_host,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `Union[Tensor,
-            #  nn.Module]`.
             uvm=self.prev_iter_uvm,
-            # pyre-fixme[6]: Expected `Tensor` for 4th param but got `Union[Tensor,
-            #  nn.Module]`.
             offsets=self.prev_iter_offsets,
-            # pyre-fixme[6]: Expected `Tensor` for 5th param but got `Union[Tensor,
-            #  nn.Module]`.
             placements=self.prev_iter_placements,
         )
         row_counter = invokers.lookup_args.Momentum(
-            # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[Tensor,
-            #  nn.Module]`.
             dev=self.row_counter_dev,
-            # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `Union[Tensor,
-            #  nn.Module]`.
             host=self.row_counter_host,
-            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `Union[Tensor,
-            #  nn.Module]`.
             uvm=self.row_counter_uvm,
-            # pyre-fixme[6]: Expected `Tensor` for 4th param but got `Union[Tensor,
-            #  nn.Module]`.
             offsets=self.row_counter_offsets,
-            # pyre-fixme[6]: Expected `Tensor` for 5th param but got `Union[Tensor,
-            #  nn.Module]`.
             placements=self.row_counter_placements,
         )
         if self._used_rowwise_adagrad_with_counter:
@@ -1170,9 +1116,6 @@ def print_uvm_cache_stats(self) -> None:
     def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         self.timestep += 1
         self.timesteps_prefetched.append(self.timestep)
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
-        #  int], Tensor], Tensor, nn.Module]` is not a function.
         if not self.lxu_cache_weights.numel():
             return
@@ -1272,16 +1215,8 @@ def _update_cache_miss_counter(
         miss_count = torch.sum(unique_ids_count_list)

-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.__getitem__)[[Named(self,
-        #  Tensor), Named(item, typing.Any)], typing.Any], Tensor], Tensor,
-        #  nn.Module]` is not a function.
         self.cache_miss_counter[0] += (miss_count > 0).to(torch.int64)

-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.__getitem__)[[Named(self,
-        #  Tensor), Named(item, typing.Any)], typing.Any], Tensor], Tensor,
-        #  nn.Module]` is not a function.
         self.cache_miss_counter[1] += miss_count

     def _update_tablewise_cache_miss(
@@ -1293,9 +1228,6 @@ def _update_tablewise_cache_miss(
         CACHE_MISS = -1
         CACHE_HIT = -2

-        # pyre-ignore[6]:
-        #  Incompatible parameter type [6]: Expected `typing.Sized` for 1st
-        #  positional only parameter to call `len` but got `typing.Union[Tensor, nn.Module]`.
         num_tables = len(self.cache_hash_size_cumsum) - 1
         num_offsets_per_table = (len(offsets) - 1) // num_tables
         cache_missed_locations = torch.where(
@@ -1340,15 +1272,7 @@ def split_embedding_weights(self) -> List[Tensor]:
         for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
             if self.weights_precision == SparseType.INT8:
                 dim += self.int8_emb_row_dim_offset
-            # pyre-fixme[29]:
-            #  `Union[BoundMethod[typing.Callable(Tensor.__getitem__)[[Named(self,
-            #  Tensor), Named(item, typing.Any)], typing.Any], Tensor], Tensor,
-            #  nn.Module]` is not a function.
             placement = self.weights_physical_placements[t]
-            # pyre-fixme[29]:
-            #  `Union[BoundMethod[typing.Callable(Tensor.__getitem__)[[Named(self,
-            #  Tensor), Named(item, typing.Any)], typing.Any], Tensor], Tensor,
-            #  nn.Module]` is not a function.
             offset = self.weights_physical_offsets[t]
             if placement == EmbeddingLocation.DEVICE.value:
                 weights = self.weights_dev
@@ -1356,7 +1280,6 @@ def split_embedding_weights(self) -> List[Tensor]:
                 weights = self.weights_host
             else:
                 weights = self.weights_uvm
-            # pyre-ignore[29]
             if weights.dim() == 2:
                 weights = weights.flatten()
             splits.append(
@@ -1456,20 +1379,10 @@ def get_optimizer_states(
         if self.optimizer not in (OptimType.EXACT_SGD,):
             states.append(
                 get_optimizer_states(
-                    # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum1_dev,
-                    # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum1_host,
-                    # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum1_uvm,
-                    # pyre-fixme[6]: Expected `Tensor` for 4th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum1_physical_offsets,
-                    # pyre-fixme[6]: Expected `Tensor` for 5th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum1_physical_placements,
                     rowwise=self.optimizer
                     in [
@@ -1486,20 +1399,10 @@ def get_optimizer_states(
         ):
             states.append(
                 get_optimizer_states(
-                    # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum2_dev,
-                    # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum2_host,
-                    # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum2_uvm,
-                    # pyre-fixme[6]: Expected `Tensor` for 4th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum2_physical_offsets,
-                    # pyre-fixme[6]: Expected `Tensor` for 5th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.momentum2_physical_placements,
                     rowwise=self.optimizer
                     in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB),
                 )
@@ -1508,40 +1411,20 @@ def get_optimizer_states(
         if self._used_rowwise_adagrad_with_counter:
             states.append(
                 get_optimizer_states(
-                    # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.prev_iter_dev,
-                    # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.prev_iter_host,
-                    # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.prev_iter_uvm,
-                    # pyre-fixme[6]: Expected `Tensor` for 4th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.prev_iter_physical_offsets,
-                    # pyre-fixme[6]: Expected `Tensor` for 5th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.prev_iter_physical_placements,
                     rowwise=True,
                 )
             )
             states.append(
                 get_optimizer_states(
-                    # pyre-fixme[6]: Expected `Tensor` for 1st param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.row_counter_dev,
-                    # pyre-fixme[6]: Expected `Tensor` for 2nd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.row_counter_host,
-                    # pyre-fixme[6]: Expected `Tensor` for 3rd param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.row_counter_uvm,
-                    # pyre-fixme[6]: Expected `Tensor` for 4th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.row_counter_physical_offsets,
-                    # pyre-fixme[6]: Expected `Tensor` for 5th param but got
-                    #  `Union[Tensor, nn.Module]`.
                     self.row_counter_physical_placements,
                     rowwise=True,
                 )
@@ -1582,9 +1465,6 @@ def set_optimizer_step(self, step: int) -> None:

     @torch.jit.export
     def flush(self) -> None:
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
-        #  int], Tensor], Tensor, nn.Module]` is not a function.
         if not self.lxu_cache_weights.numel():
             return
         torch.ops.fbgemm.lxu_cache_flush(
@@ -1816,9 +1696,6 @@ def _init_uvm_cache_stats(self) -> None:
             self.reset_uvm_cache_stats()

     def reset_cache_states(self) -> None:
-        # pyre-fixme[29]:
-        #  `Union[BoundMethod[typing.Callable(Tensor.numel)[[Named(self, Tensor)],
-        #  int], Tensor], Tensor, nn.Module]` is not a function.
         if not self.lxu_cache_weights.numel():
             return
         self.lxu_cache_state.fill_(-1)
diff --git a/fbgemm_gpu/test/layout_transform_ops_test.py b/fbgemm_gpu/test/layout_transform_ops_test.py
index 4bfc603d6..37bd53377 100644
--- a/fbgemm_gpu/test/layout_transform_ops_test.py
+++ b/fbgemm_gpu/test/layout_transform_ops_test.py
@@ -122,9 +122,7 @@ def test_recat_embedding_grad_output_mixed_D_batch(self, B: int, W: int) -> None
             )
             for i in range(W)
         ]
-        # pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
         dim_sum_per_rank_tensor = torch.cuda.LongTensor(dim_sum_per_rank)
-        # pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
         cumsum_dim_sum_per_rank_tensor = torch.cuda.LongTensor(
             np.cumsum([0] + dim_sum_per_rank)[:-1]
         )
@@ -162,9 +160,7 @@ def test_recat_embedding_grad_output_mixed_D_batch(self, B: int, W: int) -> None
             )
             for i in range(W)
         ]
-        # pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
         dim_sum_per_rank_tensor = torch.cuda.LongTensor(dim_sum_per_rank)
-        # pyre-fixme[16]: Module `cuda` has no attribute `LongTensor`.
         cumsum_dim_sum_per_rank_tensor = torch.cuda.LongTensor(
             np.cumsum([0] + dim_sum_per_rank)[:-1]
         )
diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
index f2b7b9756..fdeb5b661 100644
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -1785,12 +1785,9 @@ def execute_backward_none_(  # noqa C901
         fc2.backward(goc)

         if optimizer is not None:
-            # pyre-ignore[6]
             params = SplitEmbeddingOptimizerParams(weights_dev=cc.weights_dev)
             embedding_args = SplitEmbeddingArgs(
-                # pyre-ignore[6]
                 weights_placements=cc.weights_placements,
-                # pyre-ignore[6]
                 weights_offsets=cc.weights_offsets,
                 max_D=cc.max_D,
             )
@@ -1821,7 +1818,6 @@ def execute_backward_none_(  # noqa C901
                 ref_grad.half() if weights_precision == SparseType.FP16 else ref_grad
             )
         else:
-            # pyre-ignore[16]
             indices = cc.weights_dev.grad._indices().flatten()
             # Select only the part in the table that is updated
             test_tensor = torch.index_select(cc.weights_dev.view(-1, D), 0, indices)
@@ -4620,13 +4616,9 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         # cache status; we use the exact same logic, but still assigning ways in a associative cache can be
         # arbitrary. We compare sum along ways in each set, instead of expecting exact tensor match.
         cache_weights_ref = torch.reshape(
-            # pyre-fixme[6]: For 1st param expected `Tensor` but got
-            #  `Union[Tensor, Module]`.
             cc_ref.lxu_cache_weights,
             [-1, associativity],
         )
-        # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor,
-        #  Module]`.
         cache_weights = torch.reshape(cc.lxu_cache_weights, [-1, associativity])
         torch.testing.assert_close(
             torch.sum(cache_weights_ref, 1),
@@ -4634,26 +4626,16 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
             equal_nan=True,
         )
         torch.testing.assert_close(
-            # pyre-fixme[6]: For 1st param expected `Tensor` but got
-            #  `Union[Tensor, Module]`.
             torch.sum(cc.lxu_cache_state, 1),
-            # pyre-fixme[6]: For 1st param expected `Tensor` but got
-            #  `Union[Tensor, Module]`.
             torch.sum(cc_ref.lxu_cache_state, 1),
             equal_nan=True,
         )
         # lxu_state can be different as time_stamp values can be different.
         # we check the entries with max value.
-        # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor,
-        #  Module]`.
         max_timestamp_ref = torch.max(cc_ref.lxu_state)
-        # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor,
-        #  Module]`.
         max_timestamp_uvm_caching = torch.max(cc.lxu_state)
         x = cc_ref.lxu_state == max_timestamp_ref
         y = cc.lxu_state == max_timestamp_uvm_caching
-        # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[bool,
-        #  Tensor]`.
         torch.testing.assert_close(torch.sum(x, 1), torch.sum(y, 1))

         # int_nbit_split_embedding_uvm_caching_codegen_lookup_function for UVM.