From 8cf5c2aa468b9ad19c141715eafe5ca2a340ba5f Mon Sep 17 00:00:00 2001
From: gmeanti
Date: Thu, 18 Apr 2024 22:40:04 +0200
Subject: [PATCH] Delete special case for when precomputed-kernel is provided

---
 falkon/optim/conjgrad.py      | 42 ++++++--------------
 falkon/tests/test_conjgrad.py |  8 ++--
 falkon/tests/test_falkon.py   | 73 +++++++++++++++++++++++++++++++++++
 3 files changed, 89 insertions(+), 34 deletions(-)

diff --git a/falkon/optim/conjgrad.py b/falkon/optim/conjgrad.py
index 17cac93..a6b8157 100644
--- a/falkon/optim/conjgrad.py
+++ b/falkon/optim/conjgrad.py
@@ -6,7 +6,6 @@
 import torch

 import falkon
-from falkon.mmv_ops.fmmv_incore import incore_fdmmv, incore_fmmv
 from falkon.options import ConjugateGradientOptions, FalkonOptions
 from falkon.utils import TicToc
 from falkon.utils.tensor_helpers import copy_same_stride, create_same_stride
@@ -240,18 +239,14 @@ def __init__(

         self.weight_fn = weight_fn

-    def falkon_mmv(self, sol, penalty, X, M, Knm):
-        n = Knm.shape[0] if Knm is not None else X.shape[0]
+    def falkon_mmv(self, sol, penalty, X, M, n: int):
         prec = self.preconditioner

         with TicToc("MMV", False):
             v = prec.invA(sol)
             v_t = prec.invT(v)

-            if Knm is not None:
-                cc = incore_fdmmv(Knm, v_t, w=None, out=None, opt=self.params)
-            else:
-                cc = self.kernel.dmmv(X, M, v_t, None, opt=self.params)
+            cc = self.kernel.dmmv(X, M, v_t, None, opt=self.params)

             # AT^-1 @ (TT^-1 @ (cc / n) + penalty * v)
             cc_ = cc.div_(n)
@@ -260,20 +255,15 @@ def falkon_mmv(self, sol, penalty, X, M, Knm):
             out = prec.invAt(cc_)
         return out

-    def weighted_falkon_mmv(self, sol, penalty, X, M, Knm, y_weights):
-        n = Knm.shape[0] if Knm is not None else X.shape[0]
+    def weighted_falkon_mmv(self, sol, penalty, X, M, y_weights, n: int):
         prec = self.preconditioner

         with TicToc("MMV", False):
             v = prec.invA(sol)
             v_t = prec.invT(v)

-            if Knm is not None:
-                cc = incore_fmmv(Knm, v_t, None, opt=self.params).mul_(y_weights)
-                cc = incore_fmmv(Knm.T, cc, None, opt=self.params)
-            else:
-                cc = self.kernel.mmv(X, M, v_t, None, opt=self.params).mul_(y_weights)
-                cc = self.kernel.mmv(M, X, cc, None, opt=self.params)
+            cc = self.kernel.mmv(X, M, v_t, None, opt=self.params).mul_(y_weights)
+            cc = self.kernel.mmv(M, X, cc, None, opt=self.params)

             # AT^-1 @ (TT^-1 @ (cc / n) + penalty * v)
             cc_ = cc.div_(n)
@@ -283,14 +273,9 @@ def weighted_falkon_mmv(self, sol, penalty, X, M, Knm, y_weights):
         return out

     def solve(self, X, M, Y, _lambda, initial_solution, max_iter, callback=None):
-        n = X.size(0)
-        if M is None:
-            Knm = X
-        else:
-            Knm = None
-
-        cuda_inputs: bool = X.is_cuda
-        device = X.device
+        n = Y.size(0)
+        cuda_inputs: bool = Y.is_cuda
+        device = Y.device

         stream = None
         if cuda_inputs:
@@ -308,18 +293,13 @@ def solve(self, X, M, Y, _lambda, initial_solution, max_iter, callback=None):
             y_over_n.mul_(y_weights)  # This can be in-place since we own y_over_n

         # Compute the right hand side
-        if Knm is not None:
-            B = incore_fmmv(Knm, y_over_n, None, transpose=True, opt=self.params)
-        else:
-            B = self.kernel.mmv(M, X, y_over_n, opt=self.params)
+        B = self.kernel.mmv(M, X, y_over_n, opt=self.params)
         B = self.preconditioner.apply_t(B)

         if self.is_weighted:
-            mmv = functools.partial(
-                self.weighted_falkon_mmv, penalty=_lambda, X=X, M=M, Knm=Knm, y_weights=y_weights
-            )
+            mmv = functools.partial(self.weighted_falkon_mmv, penalty=_lambda, X=X, M=M, y_weights=y_weights, n=n)
         else:
-            mmv = functools.partial(self.falkon_mmv, penalty=_lambda, X=X, M=M, Knm=Knm)
+            mmv = functools.partial(self.falkon_mmv, penalty=_lambda, X=X, M=M, n=n)

         # Run the conjugate gradient solver
         beta = self.optimizer.solve(initial_solution, B, mmv, max_iter, callback)
diff --git a/falkon/tests/test_conjgrad.py b/falkon/tests/test_conjgrad.py
index 43c3bc2..aa49cf5 100644
--- a/falkon/tests/test_conjgrad.py
+++ b/falkon/tests/test_conjgrad.py
@@ -5,7 +5,7 @@
 import torch

 from falkon.center_selection import UniformSelector
-from falkon.kernels import GaussianKernel
+from falkon.kernels import GaussianKernel, PrecomputedKernel
 from falkon.optim.conjgrad import ConjugateGradient, FalkonConjugateGradient
 from falkon.options import FalkonOptions
 from falkon.preconditioner import FalkonPreconditioner
@@ -154,7 +154,8 @@ def test_restarts(self, data, centers, kernel, preconditioner, knm, kmm, vec_rhs
     def test_precomputed_kernel(self, data, centers, kernel, preconditioner, knm, kmm, vec_rhs, device):
         preconditioner = preconditioner.to(device)
         options = dataclasses.replace(self.basic_opt, use_cpu=device == "cpu")
-        opt = FalkonConjugateGradient(kernel, preconditioner, opt=options)
+        calc_kernel = PrecomputedKernel(knm, options)
+        opt = FalkonConjugateGradient(calc_kernel, preconditioner, opt=options)

         # Solve (knm.T @ knm + lambda*n*kmm) x = knm.T @ b
         rhs = knm.T @ vec_rhs
@@ -164,7 +165,8 @@ def test_precomputed_kernel(self, data, centers, kernel, preconditioner, knm, km
         knm = move_tensor(knm, device)
         vec_rhs = move_tensor(vec_rhs, device)

-        beta = opt.solve(X=knm, M=None, Y=vec_rhs, _lambda=self.penalty, initial_solution=None, max_iter=200)
+        # We still need to pass X and M in for shape checks to pass in MMV
+        beta = opt.solve(X=None, M=None, Y=vec_rhs, _lambda=self.penalty, initial_solution=None, max_iter=200)
         alpha = preconditioner.apply(beta)

         assert str(beta.device) == device, "Device has changed unexpectedly"
diff --git a/falkon/tests/test_falkon.py b/falkon/tests/test_falkon.py
index 7ebed47..18f4dca 100644
--- a/falkon/tests/test_falkon.py
+++ b/falkon/tests/test_falkon.py
@@ -164,6 +164,29 @@ def error_fn(t, p):

         np.testing.assert_allclose(flk_cpu.alpha_.numpy(), flk_gpu.alpha_.numpy())

+    def test_precompute_kernel(self, reg_data):
+        Xtr, Ytr, Xts, Yts = reg_data
+        kernel = kernels.GaussianKernel(20.0)
+
+        def error_fn(t, p):
+            return torch.sqrt(torch.mean((t - p) ** 2)).item(), "RMSE"
+
+        opt = FalkonOptions(
+            use_cpu=True,
+            keops_active="no",
+            debug=True,
+            never_store_kernel=False,
+            store_kernel_d_threshold=Xtr.shape[1] - 1,
+        )
+        flk = Falkon(kernel=kernel, penalty=1e-6, M=Xtr.shape[0] // 2, seed=10, options=opt, maxiter=10)
+        flk.fit(Xtr, Ytr, Xts=Xts, Yts=Yts)
+
+        assert flk.predict(Xts).shape == (Yts.shape[0], 1)
+        ts_err = error_fn(flk.predict(Xts), Yts)[0]
+        tr_err = error_fn(flk.predict(Xtr), Ytr)[0]
+        assert tr_err < ts_err
+        assert ts_err < 2.5
+

 class TestWeightedFalkon:
     @pytest.mark.parametrize(
@@ -220,6 +243,29 @@ def weight_fn(y, x, indices):
         assert err_weight_m1 < err_m1, "Error of weighted class is higher than without weighting"
         assert err_weight_p1 >= err_p1, "Error of unweighted class is lower than in flk with no weights"

+    def test_precompute_kernel(self, reg_data):
+        Xtr, Ytr, Xts, Yts = reg_data
+        kernel = kernels.GaussianKernel(20.0)
+
+        def error_fn(t, p):
+            return torch.sqrt(torch.mean((t - p) ** 2)).item(), "RMSE"
+
+        opt = FalkonOptions(
+            use_cpu=True,
+            keops_active="no",
+            debug=True,
+            never_store_kernel=False,
+            store_kernel_d_threshold=Xtr.shape[1] - 1,
+        )
+        flk = Falkon(kernel=kernel, penalty=1e-6, M=Xtr.shape[0] // 2, seed=10, options=opt, maxiter=10)
+        flk.fit(Xtr, Ytr, Xts=Xts, Yts=Yts)
+
+        assert flk.predict(Xts).shape == (Yts.shape[0], 1)
+        ts_err = error_fn(flk.predict(Xts), Yts)[0]
+        tr_err = error_fn(flk.predict(Xtr), Ytr)[0]
+        assert tr_err < ts_err
+        assert ts_err < 2.5
+

 @pytest.mark.skipif(not decide_cuda(), reason="No GPU found.")
 class TestIncoreFalkon:
@@ -257,6 +303,33 @@ def error_fn(t, p):
         err = error_fn(cpreds, Yc)[0]
         assert err < 5

+    def test_precompute_kernel(self, cls_data):
+        X, Y = cls_data
+        Xc = X.cuda()
+        Yc = Y.cuda()
+        kernel = kernels.GaussianKernel(2.0)
+        torch.manual_seed(13)
+        np.random.seed(13)
+
+        def error_fn(t, p):
+            return 100 * torch.sum(t * p <= 0).to(torch.float32) / t.shape[0], "c-err"
+
+        opt = FalkonOptions(
+            use_cpu=False,
+            keops_active="no",
+            debug=True,
+            never_store_kernel=False,
+            store_kernel_d_threshold=X.shape[1] - 1,
+        )
+        M = 500
+        flkc = InCoreFalkon(kernel=kernel, penalty=1e-6, M=M, seed=10, options=opt, maxiter=20, error_fn=error_fn)
+        flkc.fit(Xc, Yc)
+
+        cpreds = flkc.predict(Xc)
+        assert cpreds.device == Xc.device
+        err = error_fn(cpreds, Yc)[0]
+        assert err < 5
+

 def _runner_str(fname_X, fname_Y, fname_out, num_centers, num_rep, max_iter, gpu_num):
     run_str = f"""
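Usage note: with the special case removed, callers that previously passed a precomputed Knm matrix as `X` with `M=None` should instead wrap it in a `PrecomputedKernel`, as the updated `test_precomputed_kernel` does. A minimal sketch assembled from the test code above; `knm`, `preconditioner`, `options`, `vec_rhs`, and `penalty` are placeholders for objects the caller already has:

    from falkon.kernels import PrecomputedKernel
    from falkon.optim.conjgrad import FalkonConjugateGradient

    # Wrap the precomputed n x m kernel matrix; the solver no longer branches
    # on it, and routes every matrix-vector product through the kernel object.
    calc_kernel = PrecomputedKernel(knm, options)
    solver = FalkonConjugateGradient(calc_kernel, preconditioner, opt=options)

    # X and M may now be None: n is taken from Y rather than from X or Knm.
    beta = solver.solve(X=None, M=None, Y=vec_rhs, _lambda=penalty,
                        initial_solution=None, max_iter=200)
    alpha = preconditioner.apply(beta)

This mirrors the deleted `incore_fdmmv`/`incore_fmmv` branches: the precomputed-matrix products formerly inlined in `falkon_mmv` and `solve` are now the kernel object's responsibility.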