Delete special case for when a precomputed kernel is provided
gmeanti committed Apr 18, 2024
1 parent 373c03b commit 8cf5c2a
Showing 3 changed files with 89 additions and 34 deletions.
42 changes: 11 additions & 31 deletions falkon/optim/conjgrad.py
@@ -6,7 +6,6 @@
 import torch
 
 import falkon
-from falkon.mmv_ops.fmmv_incore import incore_fdmmv, incore_fmmv
 from falkon.options import ConjugateGradientOptions, FalkonOptions
 from falkon.utils import TicToc
 from falkon.utils.tensor_helpers import copy_same_stride, create_same_stride
@@ -240,18 +239,14 @@ def __init__(
 
         self.weight_fn = weight_fn
 
-    def falkon_mmv(self, sol, penalty, X, M, Knm):
-        n = Knm.shape[0] if Knm is not None else X.shape[0]
+    def falkon_mmv(self, sol, penalty, X, M, n: int):
         prec = self.preconditioner
 
         with TicToc("MMV", False):
             v = prec.invA(sol)
             v_t = prec.invT(v)
 
-            if Knm is not None:
-                cc = incore_fdmmv(Knm, v_t, w=None, out=None, opt=self.params)
-            else:
-                cc = self.kernel.dmmv(X, M, v_t, None, opt=self.params)
+            cc = self.kernel.dmmv(X, M, v_t, None, opt=self.params)
 
             # AT^-1 @ (TT^-1 @ (cc / n) + penalty * v)
             cc_ = cc.div_(n)
@@ -260,20 +255,15 @@ def falkon_mmv(self, sol, penalty, X, M, Knm):
             out = prec.invAt(cc_)
         return out
 
-    def weighted_falkon_mmv(self, sol, penalty, X, M, Knm, y_weights):
-        n = Knm.shape[0] if Knm is not None else X.shape[0]
+    def weighted_falkon_mmv(self, sol, penalty, X, M, y_weights, n: int):
         prec = self.preconditioner
 
         with TicToc("MMV", False):
             v = prec.invA(sol)
             v_t = prec.invT(v)
 
-            if Knm is not None:
-                cc = incore_fmmv(Knm, v_t, None, opt=self.params).mul_(y_weights)
-                cc = incore_fmmv(Knm.T, cc, None, opt=self.params)
-            else:
-                cc = self.kernel.mmv(X, M, v_t, None, opt=self.params).mul_(y_weights)
-                cc = self.kernel.mmv(M, X, cc, None, opt=self.params)
+            cc = self.kernel.mmv(X, M, v_t, None, opt=self.params).mul_(y_weights)
+            cc = self.kernel.mmv(M, X, cc, None, opt=self.params)
 
             # AT^-1 @ (TT^-1 @ (cc / n) + penalty * v)
             cc_ = cc.div_(n)
@@ -283,14 +273,9 @@ def weighted_falkon_mmv(self, sol, penalty, X, M, Knm, y_weights):
         return out
 
     def solve(self, X, M, Y, _lambda, initial_solution, max_iter, callback=None):
-        n = X.size(0)
-        if M is None:
-            Knm = X
-        else:
-            Knm = None
-
-        cuda_inputs: bool = X.is_cuda
-        device = X.device
+        n = Y.size(0)
+        cuda_inputs: bool = Y.is_cuda
+        device = Y.device
 
         stream = None
         if cuda_inputs:
@@ -308,18 +293,13 @@ def solve(self, X, M, Y, _lambda, initial_solution, max_iter, callback=None):
             y_over_n.mul_(y_weights)  # This can be in-place since we own y_over_n
 
         # Compute the right hand side
-        if Knm is not None:
-            B = incore_fmmv(Knm, y_over_n, None, transpose=True, opt=self.params)
-        else:
-            B = self.kernel.mmv(M, X, y_over_n, opt=self.params)
+        B = self.kernel.mmv(M, X, y_over_n, opt=self.params)
         B = self.preconditioner.apply_t(B)
 
         if self.is_weighted:
-            mmv = functools.partial(
-                self.weighted_falkon_mmv, penalty=_lambda, X=X, M=M, Knm=Knm, y_weights=y_weights
-            )
+            mmv = functools.partial(self.weighted_falkon_mmv, penalty=_lambda, X=X, M=M, y_weights=y_weights, n=n)
         else:
-            mmv = functools.partial(self.falkon_mmv, penalty=_lambda, X=X, M=M, Knm=Knm)
+            mmv = functools.partial(self.falkon_mmv, penalty=_lambda, X=X, M=M, n=n)
         # Run the conjugate gradient solver
         beta = self.optimizer.solve(initial_solution, B, mmv, max_iter, callback)
 
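The deleted branches only re-implemented, via incore_fmmv/incore_fdmmv, what a kernel object can provide itself: once the precomputed matrix is wrapped in a kernel, the generic kernel.mmv/kernel.dmmv calls cover the in-core case too (the weighted variant simply composes two mmv calls to apply Knm.T @ diag(y_weights) @ Knm). Below is a minimal sketch of that idea; it is not falkon's actual PrecomputedKernel, whose internals this commit does not show, and inferring the product direction from v's shape is an assumption made only for illustration.

import torch

class PrecomputedKernelSketch:
    # Wraps an already-materialized (n x m) kernel matrix behind the
    # mmv/dmmv interface the solver calls, so no Knm special case is needed.
    def __init__(self, knm: torch.Tensor, opt=None):
        self.knm = knm
        self.params = opt

    def mmv(self, X, M, v, out=None, opt=None):
        # X and M are unused: the kernel matrix is fixed. The direction is
        # inferred from v's leading dimension here, which is ambiguous when
        # n == m; a real implementation would decide from its arguments.
        if v.shape[0] == self.knm.shape[1]:
            return self.knm @ v   # Knm @ v, shape (n, t)
        return self.knm.T @ v     # Knm.T @ v, shape (m, t)

    def dmmv(self, X, M, v, w, opt=None):
        # Fused double MMV: Knm.T @ (Knm @ v + w), matching the
        # kernel.dmmv(X, M, v_t, None) call in falkon_mmv above.
        cc = self.knm @ v
        if w is not None:
            cc = cc + w
        return self.knm.T @ cc
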
8 changes: 5 additions & 3 deletions falkon/tests/test_conjgrad.py
@@ -5,7 +5,7 @@
 import torch
 
 from falkon.center_selection import UniformSelector
-from falkon.kernels import GaussianKernel
+from falkon.kernels import GaussianKernel, PrecomputedKernel
 from falkon.optim.conjgrad import ConjugateGradient, FalkonConjugateGradient
 from falkon.options import FalkonOptions
 from falkon.preconditioner import FalkonPreconditioner
@@ -154,7 +154,8 @@ def test_restarts(self, data, centers, kernel, preconditioner, knm, kmm, vec_rhs
     def test_precomputed_kernel(self, data, centers, kernel, preconditioner, knm, kmm, vec_rhs, device):
         preconditioner = preconditioner.to(device)
         options = dataclasses.replace(self.basic_opt, use_cpu=device == "cpu")
-        opt = FalkonConjugateGradient(kernel, preconditioner, opt=options)
+        calc_kernel = PrecomputedKernel(knm, options)
+        opt = FalkonConjugateGradient(calc_kernel, preconditioner, opt=options)
 
         # Solve (knm.T @ knm + lambda*n*kmm) x = knm.T @ b
         rhs = knm.T @ vec_rhs
@@ -164,7 +165,8 @@ def test_precomputed_kernel(self, data, centers, kernel, preconditioner, knm, km
         knm = move_tensor(knm, device)
         vec_rhs = move_tensor(vec_rhs, device)
 
-        beta = opt.solve(X=knm, M=None, Y=vec_rhs, _lambda=self.penalty, initial_solution=None, max_iter=200)
+        # We still need to pass X and M in for shape checks to pass in MMV
+        beta = opt.solve(X=None, M=None, Y=vec_rhs, _lambda=self.penalty, initial_solution=None, max_iter=200)
         alpha = preconditioner.apply(beta)
 
         assert str(beta.device) == device, "Device has changed unexpectedly"
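As a sanity check on what this test asserts, the system named in its comment can also be solved directly and compared against the CG output. The following is an illustrative sketch using plain torch on the same fixture objects (knm, kmm, vec_rhs, and the test's penalty); it is not part of the commit:

import torch

def direct_alpha(knm, kmm, vec_rhs, penalty):
    # Solve (Knm.T @ Knm + penalty * n * Kmm) @ alpha = Knm.T @ b exactly.
    n = knm.shape[0]
    lhs = knm.T @ knm + penalty * n * kmm
    rhs = knm.T @ vec_rhs
    return torch.linalg.solve(lhs, rhs)

# The preconditioned CG result can then be compared, e.g.:
# torch.testing.assert_close(alpha, direct_alpha(knm, kmm, vec_rhs, penalty), rtol=1e-4, atol=1e-4)
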
73 changes: 73 additions & 0 deletions falkon/tests/test_falkon.py
@@ -164,6 +164,29 @@ def error_fn(t, p):
 
     np.testing.assert_allclose(flk_cpu.alpha_.numpy(), flk_gpu.alpha_.numpy())
 
+    def test_precompute_kernel(self, reg_data):
+        Xtr, Ytr, Xts, Yts = reg_data
+        kernel = kernels.GaussianKernel(20.0)
+
+        def error_fn(t, p):
+            return torch.sqrt(torch.mean((t - p) ** 2)).item(), "RMSE"
+
+        opt = FalkonOptions(
+            use_cpu=True,
+            keops_active="no",
+            debug=True,
+            never_store_kernel=False,
+            store_kernel_d_threshold=Xtr.shape[1] - 1,
+        )
+        flk = Falkon(kernel=kernel, penalty=1e-6, M=Xtr.shape[0] // 2, seed=10, options=opt, maxiter=10)
+        flk.fit(Xtr, Ytr, Xts=Xts, Yts=Yts)
+
+        assert flk.predict(Xts).shape == (Yts.shape[0], 1)
+        ts_err = error_fn(flk.predict(Xts), Yts)[0]
+        tr_err = error_fn(flk.predict(Xtr), Ytr)[0]
+        assert tr_err < ts_err
+        assert ts_err < 2.5
+
 
 class TestWeightedFalkon:
     @pytest.mark.parametrize(
@@ -220,6 +243,29 @@ def weight_fn(y, x, indices):
         assert err_weight_m1 < err_m1, "Error of weighted class is higher than without weighting"
         assert err_weight_p1 >= err_p1, "Error of unweighted class is lower than in flk with no weights"
 
+    def test_precompute_kernel(self, reg_data):
+        Xtr, Ytr, Xts, Yts = reg_data
+        kernel = kernels.GaussianKernel(20.0)
+
+        def error_fn(t, p):
+            return torch.sqrt(torch.mean((t - p) ** 2)).item(), "RMSE"
+
+        opt = FalkonOptions(
+            use_cpu=True,
+            keops_active="no",
+            debug=True,
+            never_store_kernel=False,
+            store_kernel_d_threshold=Xtr.shape[1] - 1,
+        )
+        flk = Falkon(kernel=kernel, penalty=1e-6, M=Xtr.shape[0] // 2, seed=10, options=opt, maxiter=10)
+        flk.fit(Xtr, Ytr, Xts=Xts, Yts=Yts)
+
+        assert flk.predict(Xts).shape == (Yts.shape[0], 1)
+        ts_err = error_fn(flk.predict(Xts), Yts)[0]
+        tr_err = error_fn(flk.predict(Xtr), Ytr)[0]
+        assert tr_err < ts_err
+        assert ts_err < 2.5
+
 
 @pytest.mark.skipif(not decide_cuda(), reason="No GPU found.")
 class TestIncoreFalkon:
@@ -257,6 +303,33 @@ def error_fn(t, p):
         err = error_fn(cpreds, Yc)[0]
         assert err < 5
 
+    def test_precompute_kernel(self, cls_data):
+        X, Y = cls_data
+        Xc = X.cuda()
+        Yc = Y.cuda()
+        kernel = kernels.GaussianKernel(2.0)
+        torch.manual_seed(13)
+        np.random.seed(13)
+
+        def error_fn(t, p):
+            return 100 * torch.sum(t * p <= 0).to(torch.float32) / t.shape[0], "c-err"
+
+        opt = FalkonOptions(
+            use_cpu=False,
+            keops_active="no",
+            debug=True,
+            never_store_kernel=False,
+            store_kernel_d_threshold=X.shape[1] - 1,
+        )
+        M = 500
+        flkc = InCoreFalkon(kernel=kernel, penalty=1e-6, M=M, seed=10, options=opt, maxiter=20, error_fn=error_fn)
+        flkc.fit(Xc, Yc)
+
+        cpreds = flkc.predict(Xc)
+        assert cpreds.device == Xc.device
+        err = error_fn(cpreds, Yc)[0]
+        assert err < 5
+
 
 def _runner_str(fname_X, fname_Y, fname_out, num_centers, num_rep, max_iter, gpu_num):
     run_str = f"""
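All three new tests force the stored-kernel path the same way, by setting store_kernel_d_threshold just below the data dimensionality. The reading below is inferred from the option's name and its use in these tests, not from documentation, so treat it as an assumption:

from falkon.options import FalkonOptions

# Assumption: when the data dimension d exceeds store_kernel_d_threshold,
# Falkon materializes Knm once and keeps it, instead of recomputing kernel
# blocks on the fly; fit/predict then exercise the stored-kernel path that
# now flows through the generic kernel.mmv/kernel.dmmv calls.
d = 10  # e.g. the data dimensionality (Xtr.shape[1] in the tests)
opt = FalkonOptions(
    use_cpu=True,
    keops_active="no",
    never_store_kernel=False,        # allow the kernel to be stored
    store_kernel_d_threshold=d - 1,  # d - 1 < d => store the kernel
)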
