Skip to content

Commit

Permalink
apply Black 2024 style in fbcode (10/16)
Browse files Browse the repository at this point in the history
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447733

fbshipit-source-id: 11ac742489579bb1dfec025514aa956159cf4959
  • Loading branch information
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent 228c65c commit 42753de
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions fbgemm_gpu/codegen/embedding_common_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,16 @@ def make_kernel_arg(
TENSOR: lambda x: acc_cache_tensor_arg(x, pass_by_ref=pass_by_ref),
INT_TENSOR: lambda x: int_tensor_arg(x, pass_by_ref=pass_by_ref),
LONG_TENSOR: lambda x: long_tensor_arg(x, pass_by_ref=pass_by_ref),
INT: (lambda x: int64_arg(x, default=int(default)))
if default is not None
else int64_arg_no_default,
FLOAT: (lambda x: float_arg(x, default=default))
if default is not None
else float_arg_no_default,
INT: (
(lambda x: int64_arg(x, default=int(default)))
if default is not None
else int64_arg_no_default
),
FLOAT: (
(lambda x: float_arg(x, default=default))
if default is not None
else float_arg_no_default
),
}[ty](name)

def make_kernel_arg_constructor(ty: int, name: str) -> str:
Expand Down Expand Up @@ -260,12 +264,16 @@ def make_function_arg(
TENSOR: tensor_arg,
INT_TENSOR: tensor_arg,
LONG_TENSOR: tensor_arg,
INT: (lambda x: int64_arg(x, default=int(default)))
if default is not None
else int64_arg_no_default,
FLOAT: (lambda x: double_arg(x, default=default))
if default is not None
else double_arg_no_default,
INT: (
(lambda x: int64_arg(x, default=int(default)))
if default is not None
else int64_arg_no_default
),
FLOAT: (
(lambda x: double_arg(x, default=default))
if default is not None
else double_arg_no_default
),
}[ty](name)

def make_function_schema_arg(ty: int, name: str, default: Union[int, float]) -> str:
Expand Down Expand Up @@ -342,7 +350,7 @@ def make_args_for_compute_device(
]

split_arg_spec = []
for (ty, arg, default) in augmented_arg_spec:
for ty, arg, default in augmented_arg_spec:
if ty in (FLOAT, INT):
split_arg_spec.append((ty, arg, default))
else:
Expand All @@ -357,7 +365,7 @@ def make_args_for_compute_device(
cpu = make_args_for_compute_device(split_arg_spec)

split_arg_spec = []
for (ty, arg, default) in augmented_arg_spec:
for ty, arg, default in augmented_arg_spec:
if ty in (FLOAT, INT):
split_arg_spec.append((ty, arg, default))
else:
Expand Down

0 comments on commit 42753de

Please sign in to comment.