Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 19, 2024
1 parent 5a4966c commit e414cb2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
8 changes: 2 additions & 6 deletions test/test_grad/test_dedr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@
tol = 1e-8


def gradchecker(
dtype: torch.dtype, name: str
) -> tuple[
def gradchecker(dtype: torch.dtype, name: str) -> tuple[
Callable[[Tensor], Tensor], # autograd function
Tensor, # differentiable variables
]:
Expand Down Expand Up @@ -92,9 +90,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE)


def gradchecker_batch(
dtype: torch.dtype, name1: str, name2: str
) -> tuple[
def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
Callable[[Tensor], Tensor], # autograd function
Tensor, # differentiable variables
]:
Expand Down
8 changes: 2 additions & 6 deletions test/test_grad/test_dqdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@
tol = 1e-8


def gradchecker(
dtype: torch.dtype, name: str
) -> tuple[
def gradchecker(dtype: torch.dtype, name: str) -> tuple[
Callable[[Tensor], Tensor], # autograd function
Tensor, # differentiable variables
]:
Expand Down Expand Up @@ -92,9 +90,7 @@ def test_gradgradcheck(dtype: torch.dtype, name: str) -> None:
assert dgradgradcheck(func, diffvars, atol=tol, fast_mode=FAST_MODE)


def gradchecker_batch(
dtype: torch.dtype, name1: str, name2: str
) -> tuple[
def gradchecker_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
Callable[[Tensor], Tensor], # autograd function
Tensor, # differentiable variables
]:
Expand Down

0 comments on commit e414cb2

Please sign in to comment.