Skip to content

Commit

Permalink
Merge pull request #64 from Balandat/catch_lincg_test_warnings
Browse files Browse the repository at this point in the history
Catch some warnings in linear_cg tests
  • Loading branch information
Balandat authored May 30, 2023
2 parents 46b08fc + 66dc5ec commit f020146
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions test/utils/test_linear_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import os
import random
import unittest
import warnings

import torch

from linear_operator.utils.linear_cg import linear_cg
from linear_operator.utils.warnings import NumericalWarning


class TestLinearCG(unittest.TestCase):
Expand Down Expand Up @@ -69,15 +71,17 @@ def test_cg_with_tridiag(self):
matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))

rhs = torch.randn(size, 50, dtype=torch.float64)
solves, t_mats = linear_cg(
matrix.matmul,
rhs=rhs,
n_tridiag=5,
max_tridiag_iter=10,
max_iter=size,
tolerance=0,
eps=1e-15,
)
with warnings.catch_warnings(record=True) as ws:
solves, t_mats = linear_cg(
matrix.matmul,
rhs=rhs,
n_tridiag=5,
max_tridiag_iter=10,
max_iter=size,
tolerance=0,
eps=1e-15,
)
self.assertTrue(any(issubclass(w.category, NumericalWarning) for w in ws))

# Check cg
matrix_chol = torch.linalg.cholesky(matrix)
Expand Down Expand Up @@ -115,15 +119,17 @@ def test_batch_cg_with_tridiag(self):
matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))

rhs = torch.randn(batch, size, 10, dtype=torch.float64)
solves, t_mats = linear_cg(
matrix.matmul,
rhs=rhs,
n_tridiag=8,
max_iter=size,
max_tridiag_iter=10,
tolerance=0,
eps=1e-30,
)
with warnings.catch_warnings(record=True) as ws:
solves, t_mats = linear_cg(
matrix.matmul,
rhs=rhs,
n_tridiag=8,
max_iter=size,
max_tridiag_iter=10,
tolerance=0,
eps=1e-30,
)
self.assertTrue(any(issubclass(w.category, NumericalWarning) for w in ws))

# Check cg
matrix_chol = torch.linalg.cholesky(matrix)
Expand Down

0 comments on commit f020146

Please sign in to comment.