Skip to content

Commit

Permalink
Some random type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Giacomo Meanti committed Jul 16, 2024
1 parent b4c7814 commit 5559fd6
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion falkon/kernels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __call__(
opt: Optional[FalkonOptions] = None,
kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
):
) -> torch.Tensor:
"""Compute the kernel matrix between ``X1`` and ``X2``
Parameters
Expand Down
2 changes: 1 addition & 1 deletion falkon/mmv_ops/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def fmm(
X2: Union[torch.Tensor, SparseTensor],
kwargs_m1: Optional[Dict[str, torch.Tensor]] = None,
kwargs_m2: Optional[Dict[str, torch.Tensor]] = None,
):
) -> torch.Tensor:
import falkon.kernels

if isinstance(kernel, falkon.kernels.DiffKernel):
Expand Down
1 change: 1 addition & 0 deletions falkon/models/falkon.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def init_pc(
pc = FalkonPreconditioner(self.penalty, self.kernel, pc_opt)
ny_weight_vec = None
if self.weight_fn is not None:
assert ny_indices is not None
ny_weight_vec = self.weight_fn(Y[ny_indices], X[ny_indices], ny_indices)
pc.init(ny_points, weight_vec=ny_weight_vec)
return pc
Expand Down
15 changes: 11 additions & 4 deletions falkon/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,25 @@ def _reset_state(self):

def _get_callback_fn(
self,
X: _tensor_type,
Y: torch.Tensor,
Xts: _tensor_type,
Yts: torch.Tensor,
X: Optional[_tensor_type],
Y: Optional[torch.Tensor],
Xts: Optional[_tensor_type],
Yts: Optional[torch.Tensor],
ny_points: _tensor_type,
precond: falkon.preconditioner.Preconditioner,
):
"""Returns the callback function for CG iterations.
The callback computes and displays the validation error.
"""
assert not (X is None and Xts is None), "At least one of `X` or `Xts` must be specified"
assert not (Y is None and Yts is None), "At least one of `Y` or `Yts` must be specified"
assert self.error_fn is not None, "Error function must be specified for callbacks"

def val_cback(it, beta, train_time):
assert self.error_fn is not None
assert self.fit_times_ is not None
assert self.val_errors_ is not None
# fit_times_[0] is the preparation (i.e. preconditioner time).
# train_time is the cumulative training time (excludes time for this function)
self.fit_times_.append(self.fit_times_[0] + train_time)
Expand All @@ -103,6 +109,7 @@ def val_cback(it, beta, train_time):
pred = self._predict(Xts, ny_points, alpha)
err = self.error_fn(Yts, pred)
else:
assert X is not None and Y is not None
pred = self._predict(X, ny_points, alpha)
err = self.error_fn(Y, pred)
err_name = "error"
Expand Down

0 comments on commit 5559fd6

Please sign in to comment.