diff --git a/fbgemm_gpu/codegen/embedding_common_code_generator.py b/fbgemm_gpu/codegen/embedding_common_code_generator.py index 0a6f4bedb..2a4ea4bc9 100644 --- a/fbgemm_gpu/codegen/embedding_common_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_common_code_generator.py @@ -218,12 +218,16 @@ def make_kernel_arg( TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref), INT_TENSOR: lambda x: int_tensor_arg(x, pass_by_ref=pass_by_ref), LONG_TENSOR: lambda x: long_tensor_arg(x, pass_by_ref=pass_by_ref), - INT: (lambda x: int64_arg(x, default=int(default))) - if default is not None - else int64_arg_no_default, - FLOAT: (lambda x: float_arg(x, default=default)) - if default is not None - else float_arg_no_default, + INT: ( + (lambda x: int64_arg(x, default=int(default))) + if default is not None + else int64_arg_no_default + ), + FLOAT: ( + (lambda x: float_arg(x, default=default)) + if default is not None + else float_arg_no_default + ), }[ty](name) def make_kernel_arg_constructor(ty: int, name: str) -> str: @@ -260,12 +264,16 @@ def make_function_arg( TENSOR: tensor_arg, INT_TENSOR: tensor_arg, LONG_TENSOR: tensor_arg, - INT: (lambda x: int64_arg(x, default=int(default))) - if default is not None - else int64_arg_no_default, - FLOAT: (lambda x: double_arg(x, default=default)) - if default is not None - else double_arg_no_default, + INT: ( + (lambda x: int64_arg(x, default=int(default))) + if default is not None + else int64_arg_no_default + ), + FLOAT: ( + (lambda x: double_arg(x, default=default)) + if default is not None + else double_arg_no_default + ), }[ty](name) def make_function_schema_arg(ty: int, name: str, default: Union[int, float]) -> str: @@ -342,7 +350,7 @@ def make_args_for_compute_device( ] split_arg_spec = [] - for (ty, arg, default) in augmented_arg_spec: + for ty, arg, default in augmented_arg_spec: if ty in (FLOAT, INT): split_arg_spec.append((ty, arg, default)) else: @@ -357,7 +365,7 @@ def make_args_for_compute_device( cpu = make_args_for_compute_device(split_arg_spec) split_arg_spec = [] - for (ty, arg, default) in augmented_arg_spec: + for ty, arg, default in augmented_arg_spec: if ty in (FLOAT, INT): split_arg_spec.append((ty, arg, default)) else: