Add type error suppressions for upcoming upgrade (#3118)
Summary:
Pull Request resolved: #3118

X-link: facebookresearch/FBGEMM#206

Reviewed By: MaggieMoss

Differential Revision: D62550144

fbshipit-source-id: 792f706e61bc2cec8703ca4102afabbc26848a3b
generatedunixname89002005307016 authored and facebook-github-bot committed Sep 12, 2024
1 parent 528556c commit bcd01db
Showing 9 changed files with 76 additions and 4 deletions.
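The changes below are generated `# pyre-fixme[CODE]` suppression comments: each one silences a single Pyre error code on the line that immediately follows it, leaving runtime behavior unchanged and the suppressed debt greppable until it can be fixed properly. A minimal sketch of the mechanism (the function and values are illustrative, not taken from this diff):

def parse_port(raw: str) -> int:
    return int(raw)

# pyre-fixme[6]: For 1st argument expected `str` but got `int`.
port = parse_port(8080)  # fine at runtime, since int(8080) == 8080, but a type error
print(port)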
28 changes: 27 additions & 1 deletion fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -287,13 +287,21 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:

# Benchmark forward
time_ref, output_ref = benchmark_torch_function(
torch.index_select, (input, 0, offset_indices), **bench_kwargs
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
torch.index_select,
(input, 0, offset_indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

input_group = input.split(batch_size, 0)
time, output_group = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(input_group, indices_group),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
@@ -306,13 +314,19 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
time_ref, _ = benchmark_torch_function(
functools.partial(output_ref.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
# typing.Tuple[Tensor, ...]]` but got `Tensor`.
cat_output = torch.cat(output_group)
time, _ = benchmark_torch_function(
functools.partial(cat_output.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
@@ -714,6 +728,8 @@ def batch_group_index_select_bwd(
time_pyt, out_pyt = benchmark_torch_function(
index_select_fwd_ref,
(inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

@@ -726,12 +742,16 @@ def batch_group_index_select_bwd(
input_rows,
input_columns,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

time_gis, out_gis = benchmark_torch_function(
group_index_select_fwd,
(gis_inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

@@ -746,6 +766,8 @@ def batch_group_index_select_bwd(
time_bwd_pyt, _ = benchmark_torch_function(
index_select_bwd_ref,
(out_pyt, grads),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

@@ -756,6 +778,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_batch,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

@@ -766,6 +790,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_group,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

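Most of the suppressions in this file sit on calls that forward `**bench_kwargs` into `benchmark_torch_function`. The likely trigger is that `bench_kwargs` is a plain `Dict[str, int]`: when a homogeneous dict is unpacked with `**`, Pyre checks its value type against every keyword parameter of the callee, including the `bool`- and `str`-typed ones named in the messages. A self-contained sketch of that shape, where the benchmark helper and its parameters are stand-ins rather than the real API:

import time
from typing import Any, Callable, Dict, Tuple

def benchmark(
    fn: Callable[..., Any],
    args: Tuple[Any, ...],
    iters: int = 10,
    check_median: bool = False,  # bool keyword parameter
    device: str = "cpu",         # str keyword parameter
) -> float:
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / max(iters, 1)

bench_kwargs: Dict[str, int] = {"iters": 5}
# The dict's value type (int) is checked against `check_median: bool` and
# `device: str` as well, so Pyre reports one error-code-6 message per
# mismatched keyword parameter at this call site:
# pyre-fixme[6]: expected `bool` but got `int`.
# pyre-fixme[6]: expected `str` but got `int`.
t = benchmark(sum, (range(1000),), **bench_kwargs)
print(f"{t * 1e6:.1f} us per iteration")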
10 changes: 10 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -646,6 +646,9 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N

time_per_iter = benchmark_requests(
requests_uvm,
# pyre-fixme[6]: For 2nd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> Tensor` but got `(indices: Tensor, offsets: Tensor,
# per_sample_weights: Tensor) -> None`.
run_bench,
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
@@ -1934,6 +1937,9 @@ def nbit_uvm(
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb_mixed.forward(
indices,
offsets,
@@ -2421,6 +2427,9 @@ def nbit_cache( # noqa C901
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb.forward(
indices,
offsets,
@@ -3061,6 +3070,7 @@ def device_with_spec( # noqa C901
reuse=reuse,
alpha=alpha,
weighted=weighted,
# pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
)
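Two different situations are suppressed in this benchmark. The `[6]` comments cover callbacks passed to `benchmark_requests` whose return type (`Tensor` versus `None`) does not match the annotation on the callback parameter, and the `[61]` comment covers `sigma_Ls`, which is assigned only inside the `if use_variable_bag_sizes:` branch, so Pyre cannot prove it is defined at the later, identically guarded use. A minimal sketch of that second case, with illustrative names:

from typing import List, Optional

def pick_sigma(use_variable_bag_sizes: bool, t: int) -> Optional[int]:
    if use_variable_bag_sizes:
        sigma_Ls: List[int] = [4, 8, 16]
    # The guard below repeats the same condition, but the checker does not
    # connect the two, so `sigma_Ls` counts as "not always defined" here.
    # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
    return sigma_Ls[t] if use_variable_bag_sizes else None

print(pick_sigma(True, 1))   # 8
print(pick_sigma(False, 0))  # None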
5 changes: 4 additions & 1 deletion fbgemm_gpu/codegen/genscript/jinja_environment.py
@@ -404,7 +404,10 @@ def replace_pta_namespace(pta_str_list: List[str]) -> List[str]:


def replace_placeholder_types(
arg_str_list: List[str], type_combo: Optional[Dict[str, TensorType]]
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
arg_str_list: List[str],
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
type_combo: Optional[Dict[str, TensorType]],
) -> List[str]:
"""
Replace the placeholder types with the primitive types
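The `pyre-fixme[11]` comments mean Pyre cannot resolve `TensorType` to a class when it appears in an annotation, which typically happens when the name reaches the module through an import path the checker does not follow, or is produced at runtime rather than by a plain class definition. A hypothetical reproduction of the same error class, with stand-in names and a trivial body:

from typing import Dict, List, Optional

# Built dynamically, so a static checker only sees a value of type `type`,
# not a class it can accept inside annotations.
TensorType = type("TensorType", (), {})

def replace_placeholder_types(
    arg_str_list: List[str],
    # Pyre rejects this annotation (error code 11) even though it runs fine.
    type_combo: Optional[Dict[str, TensorType]],
) -> List[str]:
    return list(arg_str_list)

print(replace_placeholder_types(["grad_t", "cache_t"], None))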
21 changes: 20 additions & 1 deletion fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -166,7 +166,11 @@ def schema_sym_int_arg_no_default(name: str) -> str:


def make_kernel_arg(
ty: ArgType, name: str, default: Union[int, float, None], pass_by_ref: bool = False
# pyre-fixme[11]: Annotation `ArgType` is not defined as a type.
ty: ArgType,
name: str,
default: Union[int, float, None],
pass_by_ref: bool = False,
) -> str:
return {
ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref),
@@ -318,6 +322,7 @@ class OptimizerArgs:
split_variables: List[str]
split_ref_kernel_args: List[str]
placeholder_tensor_names: List[str]
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]]

@staticmethod
@@ -345,6 +350,7 @@ def create(
else:
ph_combos = [None]

# pyre-fixme[28]: Unexpected keyword argument `placeholder_type_combos`.
return OptimizerArgs(
# GPU kernel args
split_kernel_args=[
@@ -434,6 +440,7 @@ def create_optim_args(
split_arg_spec = []
for s in arg_spec:
if s.ty in (ArgType.FLOAT, ArgType.INT, ArgType.SYM_INT):
# pyre-fixme[19]: Expected 1 positional argument.
split_arg_spec.append(OptimItem(s.ty, s.name, s.default))
else:
assert s.ty in (ArgType.TENSOR, ArgType.PLACEHOLDER_TENSOR)
@@ -446,8 +453,11 @@ def extend_for_cpu(spec: OptimItem) -> List[OptimItem]:
name = spec.name
default = spec.default
return [
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.TENSOR, f"{name}_host", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
]

@@ -458,9 +468,13 @@ def extend_for_cuda(spec: OptimItem) -> List[OptimItem]:
ty = spec.ty
ph_tys = spec.ph_tys
return [
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_dev", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_uvm", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
]

@@ -471,10 +485,15 @@ def extend_for_any(spec: OptimItem) -> List[OptimItem]:
ty = spec.ty
ph_tys = spec.ph_tys
return [
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.TENSOR, f"{name}_host", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_dev", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ty, f"{name}_uvm", default, ph_tys),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
# pyre-fixme[19]: Expected 1 positional argument.
OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
]

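The `[19]` ("Expected 1 positional argument") and `[28]` ("Unexpected keyword argument") suppressions all sit on constructor calls for `OptimItem` and `OptimizerArgs`. The pattern suggests the checker is losing the synthesized dataclass constructors for these records, plausibly cascading from the unresolved `ArgType`/`TensorType` annotations flagged just above, so each call site is suppressed rather than restructured. A stand-in sketch of the calling pattern, not the real classes:

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union

class ArgType(Enum):
    TENSOR = 0
    INT_TENSOR = 1
    LONG_TENSOR = 2

@dataclass(frozen=True)
class OptimItem:
    ty: ArgType
    name: str
    default: Union[int, float, None] = None
    ph_tys: Optional[List[ArgType]] = None

def extend_for_cpu(name: str, default: Union[int, float, None]) -> List[OptimItem]:
    # With a resolvable ArgType these positional calls type-check cleanly; in
    # the generated diff each equivalent call carries a `# pyre-fixme[19]`.
    return [
        OptimItem(ArgType.TENSOR, f"{name}_host", default),
        OptimItem(ArgType.INT_TENSOR, f"{name}_placements", default),
        OptimItem(ArgType.LONG_TENSOR, f"{name}_offsets", default),
    ]

print(extend_for_cpu("momentum1", 0))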
6 changes: 5 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -50,7 +50,11 @@


def none_throws(
optional: Optional[TypeVar("_T")], message: str = "Unexpected `None`"
# pyre-fixme[31]: Expression `typing.Optional[typing.TypeVar("_T")]` is not a
# valid type.
optional: Optional[TypeVar("_T")],
message: str = "Unexpected `None`",
# pyre-fixme[31]: Expression `typing.TypeVar("_T")` is not a valid type.
) -> TypeVar("_T"):
if optional is None:
raise AssertionError(message)
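The `[31]` suppressions flag `TypeVar("_T")` written directly inside an annotation. PEP 484 requires a type variable to be bound to a name before use, and each inline call would otherwise denote a fresh, unrelated variable, so the checker refuses it as a type expression. A sketch of the conventional generic spelling that needs no suppression, using a stand-in for the local `none_throws` helper:

from typing import Optional, TypeVar

_T = TypeVar("_T")

def none_throws(optional: Optional[_T], message: str = "Unexpected `None`") -> _T:
    # Declaring _T once at module scope lets the checker link the parameter
    # and return types.
    if optional is None:
        raise AssertionError(message)
    return optional

print(none_throws(5) + 1)  # 6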
@@ -162,10 +162,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):

embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
cache_miss_counter: torch.Tensor
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `weights_offsets` is never initialized.
weights_offsets: torch.Tensor
# pyre-fixme[13]: Attribute `weights_placements` is never initialized.
weights_placements: torch.Tensor

def __init__( # noqa C901
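The `[13]` suppressions cover tensor attributes that are declared at class scope but only ever created indirectly in `__init__`, in these TBE modules typically through `register_buffer`, which static analysis does not treat as an assignment to the annotated attribute. A minimal, hypothetical module showing the same shape:

import torch
import torch.nn as nn

class CacheCounters(nn.Module):
    # Declared for typing, but never assigned with a plain `self.x = ...`.
    # pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
    cache_miss_counter: torch.Tensor

    def __init__(self) -> None:
        super().__init__()
        # The assignment happens inside register_buffer via setattr, which
        # the checker cannot see.
        self.register_buffer("cache_miss_counter", torch.zeros(2, dtype=torch.int64))

m = CacheCounters()
print(m.cache_miss_counter)  # tensor([0, 0])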
@@ -343,9 +343,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
uuid: str
# pyre-fixme[13]: Attribute `last_uvm_cache_print_state` is never initialized.
last_uvm_cache_print_state: torch.Tensor
_vbe_B_offsets: Optional[torch.Tensor]
_vbe_max_B: int
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -1431,6 +1431,7 @@ def forward(
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
feature_requires_grad: Optional[Tensor] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> Tensor:
indices, offsets, per_sample_weights = self.prepare_inputs(
indices, offsets, per_sample_weights
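The `[7]` suppression here (and the matching one in `dequantize_embs` below) marks a function annotated to return `Tensor` while at least one path falls off the end without a `return`, which Pyre reads as an implicit `None`. Suppressing keeps the annotation while deferring the real fix, an explicit return or raise on that path. A small illustrative case, not the real `forward`:

import torch

# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
def scale_if_enabled(x: torch.Tensor, enabled: bool) -> torch.Tensor:
    if enabled:
        return x * 2.0
    # falling off the end returns None, contradicting the annotation

print(scale_if_enabled(torch.ones(2), True))   # tensor([2., 2.])
print(scale_if_enabled(torch.ones(2), False))  # None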
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py
@@ -79,6 +79,7 @@ def dequantize_embs(
weight_ty: SparseType,
use_cpu: bool,
fp8_config: Optional[FP8QuantizationConfig] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> torch.Tensor:
print(f"weight_ty: {weight_ty}")
assert (
