Skip to content

Commit

Permalink
[WIP] Factor models into multiple functions
Browse files Browse the repository at this point in the history
This is to make it easier to extend the models using X which is not a Tensor
  • Loading branch information
gmeanti committed Apr 18, 2024
1 parent 5d86e37 commit 373c03b
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 166 deletions.
152 changes: 83 additions & 69 deletions falkon/models/falkon.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import dataclasses
import time
import warnings
from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor

import falkon
from falkon.models.model_utils import FalkonBase
from falkon.options import FalkonOptions
from falkon.preconditioner import FalkonPreconditioner
from falkon.sparse import SparseTensor
from falkon.utils import TicToc
from falkon.utils.devices import get_device_info

Expand Down Expand Up @@ -132,6 +136,69 @@ def __init__(
self.weight_fn = weight_fn
self._init_cuda()
self.beta_ = None
self.precond = None

def _reset_state(self):
    """Clear all fitted state: base-class attributes plus the CG solution
    vector ``beta_`` and the cached preconditioner."""
    super()._reset_state()
    self.precond = None
    self.beta_ = None

def init_pc(
    self,
    ny_points: Union[Tensor, SparseTensor],
    use_cuda_pc: bool,
    X: Tensor,
    Y: Tensor,
    ny_indices: Optional[Tensor] = None,
) -> FalkonPreconditioner:
    """Build and initialize the Falkon preconditioner on the Nystrom centers.

    Parameters
    ----------
    ny_points
        The selected Nystrom centers (dense or sparse), shape ``(m, d)``.
    use_cuda_pc
        If True the preconditioner runs on the GPU(s), otherwise on CPU.
    X, Y
        Full training inputs and targets; only indexed by `ny_indices` when a
        sample-weighting function (``self.weight_fn``) is set.
    ny_indices
        Indices of the Nystrom centers into `X`/`Y`. Must be provided when
        ``self.weight_fn`` is not None; ignored otherwise.

    Returns
    -------
    FalkonPreconditioner
        The preconditioner, already initialized on `ny_points`.
    """
    num_centers = ny_points.shape[0]
    # NOTE: fixed typo in the debug message ("Calcuating" -> "Calculating").
    with TicToc(f"Calculating Preconditioner of size {num_centers}", debug=self.options.debug):
        # Force CPU execution unless CUDA was explicitly requested for the PC.
        pc_opt: FalkonOptions = dataclasses.replace(self.options, use_cpu=not use_cuda_pc)
        if pc_opt.debug:
            pc_dev_str = "CPU" if pc_opt.use_cpu else f"{self.num_gpus} GPUs"
            print(f"Preconditioner will run on {pc_dev_str}")
        pc = FalkonPreconditioner(self.penalty, self.kernel, pc_opt)
        ny_weight_vec = None
        if self.weight_fn is not None:
            # Weighted Falkon: per-sample weights evaluated at the centers.
            ny_weight_vec = self.weight_fn(Y[ny_indices], X[ny_indices], ny_indices)
        pc.init(ny_points, weight_vec=ny_weight_vec)
    return pc

def init_kernel_matrix(self, X: Tensor, ny_pts: Tensor) -> falkon.kernels.Kernel:
    """Choose between the model's kernel and a precomputed kernel matrix.

    When enough CPU RAM is available to hold the full ``n x m`` kernel matrix,
    it is computed once and wrapped in a :class:`PrecomputedKernel` so that
    later computations reuse it. Otherwise the original kernel is returned
    unchanged and the matrix is recomputed on the fly.
    """
    cpu_opt = dataclasses.replace(self.options, use_cpu=True)
    cpu_info = get_device_info(cpu_opt)
    # Keep a 10% safety margin on the usable CPU memory estimate.
    free_ram = min(cpu_opt.max_cpu_mem, cpu_info[-1].free_memory) * 0.9
    if not self._can_store_knm(X, ny_pts, free_ram):
        return self.kernel
    knm = self.kernel(X, ny_pts, opt=self.options)
    return falkon.kernels.PrecomputedKernel(knm, opt=self.options)

def run_solver(
    self,
    use_cuda: bool,
    kernel: falkon.kernels.Kernel,
    X: Tensor,
    Y: Tensor,
    ny_pts: Tensor,
    warm_start: Optional[Tensor],
    cb: Callable,
) -> Tuple[Tensor, Tensor]:
    """Run the conjugate-gradient iterations and return ``(alpha, beta)``.

    ``beta`` is the CG solution in the preconditioned space; ``alpha`` is the
    de-preconditioned coefficient vector used for prediction. `warm_start`
    optionally provides an initial solution, and `cb` is invoked after each
    iteration (e.g. to report validation error).
    """
    with TicToc("Computing Falkon iterations", debug=self.options.debug):
        solver_opt: FalkonOptions = dataclasses.replace(self.options, use_cpu=not use_cuda)
        if solver_opt.debug:
            device_desc = "CPU" if solver_opt.use_cpu else f"{self.num_gpus} GPUs"
            print(f"Optimizer will run on {device_desc}", flush=True)
        solver = falkon.optim.FalkonConjugateGradient(kernel, self.precond, solver_opt, weight_fn=self.weight_fn)
        beta = solver.solve(
            X, ny_pts, Y, self.penalty, initial_solution=warm_start, max_iter=self.maxiter, callback=cb
        )
        alpha = self.precond.apply(beta)
    return alpha, beta

def fit(
self,
Expand Down Expand Up @@ -182,109 +249,56 @@ def fit(
The fitted model
"""
X, Y, Xts, Yts = self._check_fit_inputs(X, Y, Xts, Yts)
dtype = X.dtype
self.fit_times_ = []
self.val_errors_ = []
self.ny_points_ = None
self.alpha_ = None
self.beta_ = None
self._reset_state()

# Start training timer
t_s = time.time()

with torch.autograd.inference_mode():
# Pick Nystrom centers
if self.weight_fn is not None:
# noinspection PyTupleAssignmentBalance
ny_points, ny_indices = self.center_selection.select_indices(X, None)
if self.weight_fn is None: # don't need indices.
ny_points, ny_indices = self.center_selection.select(X, None), None
else:
# noinspection PyTypeChecker
ny_points: Union[torch.Tensor, falkon.sparse.SparseTensor] = self.center_selection.select(X, None)
ny_indices = None
ny_points, ny_indices = self.center_selection.select_indices(X, None)
num_centers = ny_points.shape[0]

# Decide whether to use CUDA for preconditioning and iterations, based on number of centers
# Decide whether to use CUDA for preconditioning and iterations
_use_cuda_preconditioner = (
self.use_cuda_
and (not self.options.cpu_preconditioner)
and num_centers >= get_min_cuda_preconditioner_size(dtype, self.options)
and num_centers >= get_min_cuda_preconditioner_size(X.dtype, self.options)
)
tot_mmv_mem_usage = X.shape[0] * X.shape[1] * num_centers
_use_cuda_mmv = self.use_cuda_ and tot_mmv_mem_usage / self.num_gpus >= get_min_cuda_mmv_size(
dtype, self.options
X.dtype, self.options
)

if self.use_cuda_:
ny_points = ny_points.pin_memory()

with TicToc(f"Calcuating Preconditioner of size {num_centers}", debug=self.options.debug):
pc_opt: FalkonOptions = dataclasses.replace(self.options, use_cpu=not _use_cuda_preconditioner)
if pc_opt.debug:
print("Preconditioner will run on %s" % ("CPU" if pc_opt.use_cpu else ("%d GPUs" % self.num_gpus)))
precond = falkon.preconditioner.FalkonPreconditioner(self.penalty, self.kernel, pc_opt)
self.precond = precond
ny_weight_vec = None
if self.weight_fn is not None:
ny_weight_vec = self.weight_fn(Y[ny_indices], X[ny_indices], ny_indices)
precond.init(ny_points, weight_vec=ny_weight_vec)
self.precond = self.init_pc(ny_points, _use_cuda_preconditioner, X, Y, ny_indices)

if _use_cuda_mmv:
# Cache must be emptied to ensure enough memory is visible to the optimizer
torch.cuda.empty_cache()
X = X.pin_memory()

# K_NM storage decision
k_opt = dataclasses.replace(self.options, use_cpu=True)
cpu_info = get_device_info(k_opt)
available_ram = min(k_opt.max_cpu_mem, cpu_info[-1].free_memory) * 0.9
Knm = None
if self._can_store_knm(X, ny_points, available_ram):
Knm = self.kernel(X, ny_points, opt=self.options)
calc_kernel = self.init_kernel_matrix(X, ny_points)
self.fit_times_.append(time.time() - t_s) # Preparation time

# Here we define the callback function which will run at the end
# of conjugate gradient iterations. This function computes and
# displays the validation error.
# Define the callback function which runs after each CG iteration. Optionally computes
# and displays the validation error.
validation_cback = None
if self.error_fn is not None and self.error_every is not None:
validation_cback = self._get_callback_fn(X, Y, Xts, Yts, ny_points, precond)

# Start with the falkon algorithm
with TicToc("Computing Falkon iterations", debug=self.options.debug):
o_opt: FalkonOptions = dataclasses.replace(self.options, use_cpu=not _use_cuda_mmv)
if o_opt.debug:
optim_dev_str = "CPU" if o_opt.use_cpu else f"{self.num_gpus} GPUs"
print(f"Optimizer will run on {optim_dev_str}", flush=True)
optim = falkon.optim.FalkonConjugateGradient(self.kernel, precond, o_opt, weight_fn=self.weight_fn)
if Knm is not None:
beta = optim.solve(
Knm,
None,
Y,
self.penalty,
initial_solution=warm_start,
max_iter=self.maxiter,
callback=validation_cback,
)
else:
beta = optim.solve(
X,
ny_points,
Y,
self.penalty,
initial_solution=warm_start,
max_iter=self.maxiter,
callback=validation_cback,
)

self.alpha_ = precond.apply(beta)
self.beta_ = beta
self.ny_points_ = ny_points
validation_cback = self._get_callback_fn(X, Y, Xts, Yts, ny_points, self.precond)

alpha, beta = self.run_solver(_use_cuda_mmv, calc_kernel, X, Y, ny_points, warm_start, validation_cback)
self.alpha_, self.beta_, self.ny_points_ = alpha, beta, ny_points
return self

def _predict(self, X, ny_points, alpha: torch.Tensor) -> torch.Tensor:
with torch.autograd.inference_mode():
if ny_points is None:
warnings.warn("This code-path is deprecated and may be removed. Nys_points must be specified.")
# Then X is the kernel itself
return X @ alpha
num_centers = alpha.shape[0]
Expand Down
Loading

0 comments on commit 373c03b

Please sign in to comment.