From 5559fd6716a2b1b98f480f5c456a6a7b86ff72a3 Mon Sep 17 00:00:00 2001 From: Giacomo Meanti Date: Tue, 16 Jul 2024 22:15:07 +0200 Subject: [PATCH] Some random type hints --- falkon/kernels/kernel.py | 2 +- falkon/mmv_ops/fmm.py | 2 +- falkon/models/falkon.py | 1 + falkon/models/model_utils.py | 15 +++++++++++---- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/falkon/kernels/kernel.py b/falkon/kernels/kernel.py index 52766c7..023dfe6 100644 --- a/falkon/kernels/kernel.py +++ b/falkon/kernels/kernel.py @@ -141,7 +141,7 @@ def __call__( opt: Optional[FalkonOptions] = None, kwargs_m1: Optional[Dict[str, torch.Tensor]] = None, kwargs_m2: Optional[Dict[str, torch.Tensor]] = None, - ): + ) -> torch.Tensor: """Compute the kernel matrix between ``X1`` and ``X2`` Parameters diff --git a/falkon/mmv_ops/fmm.py b/falkon/mmv_ops/fmm.py index be2a480..e8d79c9 100644 --- a/falkon/mmv_ops/fmm.py +++ b/falkon/mmv_ops/fmm.py @@ -705,7 +705,7 @@ def fmm( X2: Union[torch.Tensor, SparseTensor], kwargs_m1: Optional[Dict[str, torch.Tensor]] = None, kwargs_m2: Optional[Dict[str, torch.Tensor]] = None, -): +) -> torch.Tensor: import falkon.kernels if isinstance(kernel, falkon.kernels.DiffKernel): diff --git a/falkon/models/falkon.py b/falkon/models/falkon.py index 93336c0..3b28da6 100644 --- a/falkon/models/falkon.py +++ b/falkon/models/falkon.py @@ -159,6 +159,7 @@ def init_pc( pc = FalkonPreconditioner(self.penalty, self.kernel, pc_opt) ny_weight_vec = None if self.weight_fn is not None: + assert ny_indices is not None ny_weight_vec = self.weight_fn(Y[ny_indices], X[ny_indices], ny_indices) pc.init(ny_points, weight_vec=ny_weight_vec) return pc diff --git a/falkon/models/model_utils.py b/falkon/models/model_utils.py index 69e5609..bb8b3bd 100644 --- a/falkon/models/model_utils.py +++ b/falkon/models/model_utils.py @@ -77,10 +77,10 @@ def _reset_state(self): def _get_callback_fn( self, - X: _tensor_type, - Y: torch.Tensor, - Xts: _tensor_type, - Yts: torch.Tensor, + X: Optional[_tensor_type], + Y: Optional[torch.Tensor], + Xts: Optional[_tensor_type], + Yts: Optional[torch.Tensor], ny_points: _tensor_type, precond: falkon.preconditioner.Preconditioner, ): @@ -88,8 +88,14 @@ def _get_callback_fn( The callback computes and displays the validation error. """ + assert not (X is None and Xts is None), "At least one of `X` or `Xts` must be specified" + assert not (Y is None and Yts is None), "At least one of `Y` or `Yts` must be specified" + assert self.error_fn is not None, "Error function must be specified for callbacks" def val_cback(it, beta, train_time): + assert self.error_fn is not None + assert self.fit_times_ is not None + assert self.val_errors_ is not None # fit_times_[0] is the preparation (i.e. preconditioner time). # train_time is the cumulative training time (excludes time for this function) self.fit_times_.append(self.fit_times_[0] + train_time) @@ -103,6 +109,7 @@ def val_cback(it, beta, train_time): pred = self._predict(Xts, ny_points, alpha) err = self.error_fn(Yts, pred) else: + assert X is not None and Y is not None pred = self._predict(X, ny_points, alpha) err = self.error_fn(Y, pred) err_name = "error"