Skip to content

Commit

Permalink
Merge pull request #10 from FalkonML/ci-incore
Browse files Browse the repository at this point in the history
v0.6 -- InCoreFalkon, CUDA LAUUM, Bug fixes

The driving change was the implementation of an in-core version of Falkon, suitable for smaller data analyses. Here the data is always kept inside the GPU, thus the model can train much faster. The result is the InCoreFalkon class.

LAUUM was improved to use a CUDA implementation for the inner-loop function.

API changes are limited to the `FalkonOptions` class, where certain option names have changed.

Several bug fixes were also introduced, and edge-cases fixed.
  • Loading branch information
Giodiro authored Aug 31, 2020
2 parents 02cd43e + 1df1556 commit 45d6284
Show file tree
Hide file tree
Showing 71 changed files with 3,263 additions and 1,815 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ install:
- conda install -y flake8 codecov
# check manifest needs conda-forge:
- conda install -y -c conda-forge check-manifest
- pip install -e ./keops/
- pip install -e .
script:
- pytest -lv --cov-report term-missing falkon --cov=falkon --cov-config .coveragerc
- pytest --cov-report=term-missing --cov=falkon --cov-config setup.cfg
- flake8 --count falkon
after_success:
- codecov
8 changes: 4 additions & 4 deletions benchmark/mmv_timings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 10,
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"no_keops": False, "compute_arch_speed": False});',
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"keops_active": "auto", "compute_arch_speed": False});',
},
{
'name': 'varying N - Our 32',
Expand All @@ -81,7 +81,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 10,
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"no_keops": True, "compute_arch_speed": False});'
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"keops_active": "no", "compute_arch_speed": False});'
},
{
'name': 'varying D - KeOps 32',
Expand All @@ -94,7 +94,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 10,
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"no_keops": False, "compute_arch_speed": False});'
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"keops_active": "auto", "compute_arch_speed": False});'
},
{
'name': 'varying D - Our 32',
Expand All @@ -107,7 +107,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 20,
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"no_keops": True, "compute_arch_speed": False});'
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"keops_active": "no", "compute_arch_speed": False});'
},
]

Expand Down
10 changes: 5 additions & 5 deletions benchmark/time_improvements.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ def run(exp_num, dset, show_intermediate_errors: bool = False):
'maxiter': 10,
}
if exp_num == 1:
opt = dataclasses.replace(opt, cpu_preconditioner=True, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=True, keops_active="no")
dtype = DataType.float64
elif exp_num == 2:
opt = dataclasses.replace(opt, cpu_preconditioner=True, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=True, keops_active="no")
dtype = DataType.float32
elif exp_num == 3:
opt = dataclasses.replace(opt, cpu_preconditioner=False, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=False, keops_active="no")
dtype = DataType.float32
elif exp_num == 4:
opt = dataclasses.replace(opt, cpu_preconditioner=False, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=False, keops_active="no")
dtype = DataType.float32
elif exp_num == 5:
opt = dataclasses.replace(opt, cpu_preconditioner=False, no_keops=False)
opt = dataclasses.replace(opt, cpu_preconditioner=False, keops_active="force")
dtype = DataType.float32
else:
raise ValueError("exp num %d not valid" % (exp_num))
Expand Down
7 changes: 7 additions & 0 deletions doc/api_reference/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,10 @@ LogisticFalkon

.. autoclass:: LogisticFalkon
:members:


InCoreFalkon
------------

.. autoclass:: InCoreFalkon
:members:
8 changes: 8 additions & 0 deletions doc/api_reference/outofcore.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
falkon.ooc_ops
==============

The out-of-core algorithms for the Cholesky decomposition and the LAUUM operation are crucial for speeding up our
library. To find out more about how they work, check the source code:

- `Out of core Cholesky <https://github.com/FalkonML/falkon/blob/master/falkon/ooc_ops/multigpu/cuda/multigpu_potrf.cu>`_ (CUDA code)
- `Out of core LAUUM <https://github.com/FalkonML/falkon/blob/master/falkon/ooc_ops/parallel_lauumm.py>`_ (Python code)

The following functions provide a higher-level interface to the two operations.

.. automodule:: falkon.ooc_ops
.. py:currentmodule:: falkon.ooc_ops
Expand Down
6 changes: 4 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import sys

sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../falkon'))
#sys.path.insert(0, os.path.abspath('../falkon'))

# Need mocking to allow everything to be imported even on no-GPU machines
autodoc_mock_imports = [
# "torch",
# "pykeops",
# "numpy",
"falkon.la_helpers.cuda_la_helpers",
"falkon.ooc_ops.cuda",
"falkon.cuda",
"falkon.ooc_ops.multigpu_potrf"
]
Expand Down Expand Up @@ -125,4 +127,4 @@ def get_version(root_dir):
# }
# }

html_sidebars = {'**': ['globaltoc.html', 'localtoc.html', 'searchbox.html']}
html_sidebars = {'**': ['globaltoc.html', 'localtoc.html', 'searchbox.html']}
2 changes: 1 addition & 1 deletion doc/examples/simple_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@
{
"data": {
"text/plain": [
"Falkon(M=404, center_selection=<falkon.center_selection.UniformSel object at 0x7fdf14416a10>, cg_tolerance=1e-07, cholesky_opt=CholeskyOptions(chol_block_size=256, chol_tile_size='auto', chol_force_in_core=False, chol_force_ooc=False, chol_force_parallel=False, chol_par_blk_multiplier=2), compute_arch_speed=False, cpu_preconditioner=False, debug=False, error_every=1, error_fn=None, kernel=GaussianKernel(sigma=5.0), lauum_opt=LauumOptions(lauum_par_blk_multiplier=8), max_cpu_mem=inf, max_gpu_mem=inf, maxiter=20, no_keops=False, no_single_kernel=True, pc_epsilon={torch.float32: 1e-06, torch.float64: 1e-13}, penalty=0.0001, seed=None, use_cpu=True)"
"Falkon(M=404, center_selection=<falkon.center_selection.UniformSel object at 0x7feede906190>, error_every=1, error_fn=None, kernel=GaussianKernel(sigma=5.0), maxiter=20, options=FalkonOptions(keops_acc_dtype='auto', keops_sum_scheme='auto', no_keops=False, chol_force_in_core=False, chol_force_ooc=False, chol_par_blk_multiplier=2, lauum_par_blk_multiplier=8, pc_epsilon_32=1e-05, pc_epsilon_64=1e-13, cpu_preconditioner=False, cg_epsilon_32=1e-07, cg_epsilon_64=1e-15, cg_tolerance=1e-07, cg_full_gradient_every=10, debug=False, use_cpu=False, max_gpu_mem=inf, max_cpu_mem=inf, compute_arch_speed=False, no_single_kernel=True), penalty=0.0001, seed=None)"
]
},
"execution_count": 7,
Expand Down
2 changes: 1 addition & 1 deletion falkon/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.5.0
0.6.0
5 changes: 3 additions & 2 deletions falkon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os

from . import kernels, optim, preconditioner, center_selection
from .models import Falkon, LogisticFalkon
from .options import FalkonOptions
from .models import Falkon, LogisticFalkon, InCoreFalkon

__all__ = ('Falkon', 'LogisticFalkon', 'kernels', 'optim', 'preconditioner', 'center_selection',
__all__ = ('Falkon', 'LogisticFalkon', 'InCoreFalkon',
'kernels', 'optim', 'preconditioner', 'center_selection',
'FalkonOptions')


Expand Down
72 changes: 53 additions & 19 deletions falkon/center_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import torch

from falkon.sparse.sparse_tensor import SparseTensor
from falkon.utils.tensor_helpers import is_f_contig
from falkon.utils.tensor_helpers import create_same_stride

__all__ = ("CenterSelector", "FixedSelector", "UniformSelector")
_tensor_type = Union[torch.Tensor, SparseTensor]


class NySel(ABC):
class CenterSelector(ABC):
def __init__(self, random_gen):
self.random_gen = random_gen

Expand All @@ -20,17 +21,56 @@ def select(self, X, Y, M):
pass


class UniformSel(NySel):
class FixedSelector(CenterSelector):
def __init__(self, centers: _tensor_type, y_centers: Union[torch.Tensor, None] = None,
random_gen=None):
super().__init__(random_gen)
self.centers = centers
self.y_centers = y_centers

def select(self, X, Y, M):
if Y is not None:
return self.centers, self.y_centers
return self.centers


class UniformSelector(CenterSelector):
def __init__(self, random_gen):
super().__init__(random_gen)

def select(self,
X: _tensor_type,
Y: Union[torch.Tensor, None],
M: int) -> Union[_tensor_type, Tuple[_tensor_type, torch.Tensor]]:
"""Select M rows from 2D array `X`, preserving the memory order of `X`.
"""Select M observations from 2D tensor `X`, preserving device and memory order.
The selection strategy is uniformly at random. To control the randomness,
pass an appropriate numpy random generator to this class's constructor.
Parameters
----------
X
N x D tensor containing the whole input dataset. We have that N <= M.
Y
Optional N x T tensor containing the input targets. If `Y` is provided,
the same observations selected for `X` will also be selected from `Y`.
Certain models (such as :class:`falkon.LogisticFalkon`) require centers to be
extracted from both predictors and targets, while others (such as
:class:`falkon.Falkon`) only require the centers from the predictors.
M
The number of observations to choose. M <= N, otherwise M is forcibly set to N
with a warning.
Returns
-------
X_M
The randomly selected centers. They will be in a new, memory-contiguous tensor.
All characteristics of the input tensor will be preserved.
(X_M, Y_M)
If `Y` was different than `None` then the entries of `Y` corresponding to the
selected centers of `X` will also be returned.
"""
N = X.size(0)
N = X.shape[0]
if M > N:
warnings.warn("Number of centers M greater than the "
"number of data-points. Setting M to %d" % (N))
Expand All @@ -42,21 +82,15 @@ def select(self,
centers = X[idx, :].copy()
Xc = SparseTensor.from_scipy(centers)
else:
Xnp = X.numpy() # work on np array
if is_f_contig(X):
order = 'F'
else:
order = 'C'
Xc_np = np.empty((M, Xnp.shape[1]), dtype=Xnp.dtype, order=order)
Xc = torch.from_numpy(np.take(Xnp, idx, axis=0, out=Xc_np, mode='wrap'))
Xc = create_same_stride((M, X.shape[1]), other=X, dtype=X.dtype, device=X.device,
pin_memory=False)
th_idx = torch.from_numpy(idx.astype(np.long)).to(X.device)
torch.index_select(X, dim=0, index=th_idx, out=Xc)

if Y is not None:
Ynp = Y.numpy() # work on np array
if is_f_contig(X):
order = 'F'
else:
order = 'C'
Yc_np = np.empty((M, Ynp.shape[1]), dtype=Ynp.dtype, order=order)
Yc = torch.from_numpy(np.take(Ynp, idx, axis=0, out=Yc_np, mode='wrap'))
Yc = create_same_stride((M, Y.shape[1]), other=Y, dtype=Y.dtype, device=Y.device,
pin_memory=False)
th_idx = torch.from_numpy(idx.astype(np.long)).to(Y.device)
torch.index_select(Y, dim=0, index=th_idx, out=Yc)
return Xc, Yc
return Xc
2 changes: 1 addition & 1 deletion falkon/cuda/cublas_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, status_code):
self.message = CUBLAS_EXCEPTIONS[status_code]
except KeyError:
self.message = "Unknown CuBLAS error %d" % (status_code)
super.__init__(self.message)
super().__init__(self.message)


def cublas_check_status(status):
Expand Down
2 changes: 1 addition & 1 deletion falkon/cuda/cudart_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(self, status_code):
self.message = CUDA_EXCEPTIONS[status_code]
except KeyError:
self.message = "Unknown CUDA error %d" % (status_code)
super.__init__(self.message)
super().__init__(self.message)


def cuda_check_status(status):
Expand Down
2 changes: 1 addition & 1 deletion falkon/cuda/cusolver_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, status_code):
self.message = CUSOLVER_EXCEPTIONS[status_code]
except KeyError:
self.message = "Unknown CuBLAS error %d" % (status_code)
super.__init__(self.message)
super().__init__(self.message)


def cusolver_check_status(status):
Expand Down
8 changes: 4 additions & 4 deletions falkon/kernels/distance_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class GaussianKernel(L2DistanceKernel, KeopsKernelMixin):
Creating a Gaussian kernel with a single length-scale. Operations on this kernel will not
use KeOps.
>>> K = GaussianKernel(sigma=3.0, opt=FalkonOptions(no_keops=True))
>>> K = GaussianKernel(sigma=3.0, opt=FalkonOptions(keops_active="no"))
Creating a Gaussian kernel with a different length-scale per dimension
Expand Down Expand Up @@ -208,7 +208,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
'v = Vj(%d)' % (v.shape[1]),
'g = Pm(1)'
]
other_vars = [torch.tensor([self.gamma]).to(dtype=X1.dtype)]
other_vars = [torch.tensor([self.gamma]).to(device=X1.device, dtype=X1.dtype)]
else:
dim = self.gamma.shape[0]
formula = (
Expand All @@ -222,7 +222,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
'v = Vj(%d)' % (v.shape[1]),
'g = Pm(%d)' % (dim ** 2)
]
other_vars = [self.gamma.reshape(-1).to(dtype=X1.dtype)]
other_vars = [self.gamma.reshape(-1).to(device=X1.device, dtype=X1.dtype)]

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down Expand Up @@ -337,7 +337,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
'v = Vj(%d)' % (v.shape[1]),
'g = Pm(1)'
]
other_vars = [torch.tensor([self.gamma]).to(dtype=X1.dtype)]
other_vars = [torch.tensor([self.gamma]).to(device=X1.device, dtype=X1.dtype)]

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down
12 changes: 6 additions & 6 deletions falkon/kernels/dot_prod_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
'beta = Pm(1)'
]
other_vars = [
torch.tensor([self.gamma]).to(dtype=X1.dtype),
torch.tensor([self.beta]).to(dtype=X1.dtype)
torch.tensor([self.gamma]).to(dtype=X1.dtype, device=X1.device),
torch.tensor([self.beta]).to(dtype=X1.dtype, device=X1.device)
]
return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down Expand Up @@ -173,8 +173,8 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
'beta = Pm(1)'
]
other_vars = [
torch.tensor([self.alpha]).to(dtype=X1.dtype),
torch.tensor([self.beta]).to(dtype=X1.dtype)
torch.tensor([self.alpha]).to(dtype=X1.dtype, device=X1.device),
torch.tensor([self.beta]).to(dtype=X1.dtype, device=X1.device)
]

is_int_pow = self.degree == self.degree.to(dtype=torch.int32)
Expand All @@ -183,7 +183,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
else:
formula = 'Powf((alpha * (X | Y) + beta), degree) * v'
aliases.append('degree = Pm(1)')
other_vars.append(torch.tensor([self.degree]).to(dtype=X1.dtype))
other_vars.append(torch.tensor([self.degree]).to(dtype=X1.dtype, device=X1.device))

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down Expand Up @@ -228,7 +228,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
'v = Vj(%d)' % (v.shape[1]),
'alpha = Pm(1)',
]
other_vars = [torch.tensor([self.alpha]).to(dtype=X1.dtype)]
other_vars = [torch.tensor([self.alpha]).to(dtype=X1.dtype, device=X1.device)]

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down
Loading

0 comments on commit 45d6284

Please sign in to comment.