Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.6 -- InCoreFalkon, CUDA LAUUM, Bug fixes #10

Merged
merged 56 commits into from
Aug 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
b570341
Add c-sources for CyBlas
Jun 29, 2020
3aab9f9
Multiple changes to allow full-CUDA operation:
Jul 2, 2020
10b2ebb
WIP: Allow in-core mmv
Jul 2, 2020
3621d01
WIP: Add some more tests for new functionality.
Jul 2, 2020
4c58ef1
Fixes for cuda prediction
Jul 2, 2020
c38f0b5
Cleanup & comments
Jul 3, 2020
2922d92
WIP: CUDA copy_triang
Jul 3, 2020
9f637a6
WIP: working copy_triang and mul_triang. Need to refactor.
Jul 3, 2020
5533def
Add CUDA functions for 'mul_triang', 'copy_triang'. Refactor.
Jul 3, 2020
d1c4739
Fix CUDA copy_triang
Jul 4, 2020
1bec80e
Fix CUDA mul-triang
Jul 4, 2020
b0e45e5
Fix LogisticFalkon bug
Jul 4, 2020
45c2e2d
Fix bugs in multiGPU POTRF (concurrency)
Jul 6, 2020
0590111
Fix bugs in multiGPU POTRF (concurrency)
Jul 6, 2020
ba19c0e
Add CUDA TRSM wrapper
Jul 9, 2020
7cfb139
Merge branch 'faster-mmv' of github.com:FalkonML/falkon into faster-mmv
Jul 9, 2020
5863d7d
Fix missing () in super call
Jul 10, 2020
c303bd7
Fix cuda transpose and cuda TRSM
Jul 10, 2020
f7c0374
Tests in preparation of in-core falkon
Jul 13, 2020
9ac3859
Support fMM from CUDA inputs
Jul 14, 2020
b6176ec
Merge branch 'master' into faster-mmv
Jul 24, 2020
5dcbfee
fMM with cuda inputs
Jul 26, 2020
340eac7
Fix LAUUM block size calc
Jul 26, 2020
7afb84a
Enable OOC lauum for CUDA inputs (WIP)
Jul 27, 2020
314753b
Add in-core lauum operation
Aug 25, 2020
11e5af4
Fixed lower-F LAUUM impl.
Giodiro Aug 27, 2020
3185978
WIP: small changes to lauum
Giodiro Aug 27, 2020
e93ab37
WIP: 1st working LAUUM_C_LOWER
Giodiro Aug 27, 2020
a7ad092
Working LAUUM (lower_c)
Giodiro Aug 27, 2020
fa4ad62
Added/fixed tests
Giodiro Aug 28, 2020
9eee53e
Add in-core falkon; refactor models.
Aug 28, 2020
ab7846c
Added in-core falkon tests
Aug 28, 2020
7e29126
Bug fixes for in-core falkon.
Giodiro Aug 28, 2020
451b646
Small fixes to parallel lauum:
Giodiro Aug 28, 2020
b531776
Merge branch 'master' into faster-mmv
Aug 28, 2020
1c46d56
version bump 0.5.1
Aug 28, 2020
277b3ff
Flake8
Aug 28, 2020
7b80cff
Fix falkon api change in docs
Aug 28, 2020
f6d8439
Fix incorrect tests
Aug 28, 2020
f686572
Fix a couple of failing tests
Giodiro Aug 28, 2020
3802fb2
Simpler-shorter sparse tests
Aug 29, 2020
20a1fb8
Merge branch 'faster-mmv' of github.com:FalkonML/falkon into faster-mmv
Aug 29, 2020
faa4b0b
Fixes for sparse tests
Aug 29, 2020
92aa378
More simplification for test_fmmv
Aug 29, 2020
d87ce67
More simplification for test_fmmv 2
Aug 29, 2020
b7700f8
Test refactoring/fixes
Aug 29, 2020
3337395
More test refactoring
Aug 29, 2020
2b960d9
Refactor ooc-potrf test
Aug 29, 2020
ffb5e2d
Remove comment
Aug 29, 2020
2411341
Fix test-bugs
Aug 29, 2020
7c9f5ab
Merge branch 'faster-mmv' of github.com:FalkonML/falkon into faster-mmv
Aug 29, 2020
ba5c34f
Bump version to 0.6.0
Aug 29, 2020
184bd77
Updated documentation
Aug 29, 2020
a1d6708
Fixes weird bug with CUDA detection:
Aug 31, 2020
93d316c
Fix incorrect import of ooc-module
Aug 31, 2020
1df1556
Fix travis build
Aug 31, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ install:
- conda install -y flake8 codecov
# check manifest needs conda-forge:
- conda install -y -c conda-forge check-manifest
- pip install -e ./keops/
- pip install -e .
script:
- pytest -lv --cov-report term-missing falkon --cov=falkon --cov-config .coveragerc
- pytest --cov-report=term-missing --cov=falkon --cov-config setup.cfg
- flake8 --count falkon
after_success:
- codecov
8 changes: 4 additions & 4 deletions benchmark/mmv_timings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 10,
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"no_keops": False, "compute_arch_speed": False});',
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"keops_active": "auto", "compute_arch_speed": False});',
},
{
'name': 'varying N - Our 32',
Expand All @@ -81,7 +81,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 10,
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"no_keops": True, "compute_arch_speed": False});'
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"keops_active": "no", "compute_arch_speed": False});'
},
{
'name': 'varying D - KeOps 32',
Expand All @@ -94,7 +94,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 10,
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"no_keops": False, "compute_arch_speed": False});'
'fn': 'kernel._keops_mmv_impl(X1, X2, v, kernel, out=None, opt={"keops_active": "auto", "compute_arch_speed": False});'
},
{
'name': 'varying D - Our 32',
Expand All @@ -107,7 +107,7 @@ def run_experiments(experiments):
'dt': torch.float32,
'timings': [],
'repetitions': 20,
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"no_keops": True, "compute_arch_speed": False});'
'fn': 'kernel.mmv(X1, X2, v, out=None, opt={"keops_active": "no", "compute_arch_speed": False});'
},
]

Expand Down
10 changes: 5 additions & 5 deletions benchmark/time_improvements.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ def run(exp_num, dset, show_intermediate_errors: bool = False):
'maxiter': 10,
}
if exp_num == 1:
opt = dataclasses.replace(opt, cpu_preconditioner=True, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=True, keops_active="no")
dtype = DataType.float64
elif exp_num == 2:
opt = dataclasses.replace(opt, cpu_preconditioner=True, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=True, keops_active="no")
dtype = DataType.float32
elif exp_num == 3:
opt = dataclasses.replace(opt, cpu_preconditioner=False, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=False, keops_active="no")
dtype = DataType.float32
elif exp_num == 4:
opt = dataclasses.replace(opt, cpu_preconditioner=False, no_keops=True)
opt = dataclasses.replace(opt, cpu_preconditioner=False, keops_active="no")
dtype = DataType.float32
elif exp_num == 5:
opt = dataclasses.replace(opt, cpu_preconditioner=False, no_keops=False)
opt = dataclasses.replace(opt, cpu_preconditioner=False, keops_active="force")
dtype = DataType.float32
else:
raise ValueError("exp num %d not valid" % (exp_num))
Expand Down
7 changes: 7 additions & 0 deletions doc/api_reference/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,10 @@ LogisticFalkon

.. autoclass:: LogisticFalkon
:members:


InCoreFalkon
------------

.. autoclass:: InCoreFalkon
:members:
8 changes: 8 additions & 0 deletions doc/api_reference/outofcore.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
falkon.ooc_ops
==============

The out-of-core algorithms for the Cholesky decomposition and the LAUUM operation are crucial for speeding up our
library. To find out more about how they work, check the source code:

- `Out of core Cholesky <https://github.com/FalkonML/falkon/blob/master/falkon/ooc_ops/multigpu/cuda/multigpu_potrf.cu>`_ (CUDA code)
- `Out of core LAUUM <https://github.com/FalkonML/falkon/blob/master/falkon/ooc_ops/parallel_lauumm.py>`_ (Python code)

The following functions provide a higher-level interface to the two operations.

.. automodule:: falkon.ooc_ops
.. py:currentmodule:: falkon.ooc_ops
Expand Down
6 changes: 4 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import sys

sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../falkon'))
#sys.path.insert(0, os.path.abspath('../falkon'))

# Need mocking to allow everything to be imported even on no-GPU machines
autodoc_mock_imports = [
# "torch",
# "pykeops",
# "numpy",
"falkon.la_helpers.cuda_la_helpers",
"falkon.ooc_ops.cuda",
"falkon.cuda",
"falkon.ooc_ops.multigpu_potrf"
]
Expand Down Expand Up @@ -125,4 +127,4 @@ def get_version(root_dir):
# }
# }

html_sidebars = {'**': ['globaltoc.html', 'localtoc.html', 'searchbox.html']}
html_sidebars = {'**': ['globaltoc.html', 'localtoc.html', 'searchbox.html']}
2 changes: 1 addition & 1 deletion doc/examples/simple_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@
{
"data": {
"text/plain": [
"Falkon(M=404, center_selection=<falkon.center_selection.UniformSel object at 0x7fdf14416a10>, cg_tolerance=1e-07, cholesky_opt=CholeskyOptions(chol_block_size=256, chol_tile_size='auto', chol_force_in_core=False, chol_force_ooc=False, chol_force_parallel=False, chol_par_blk_multiplier=2), compute_arch_speed=False, cpu_preconditioner=False, debug=False, error_every=1, error_fn=None, kernel=GaussianKernel(sigma=5.0), lauum_opt=LauumOptions(lauum_par_blk_multiplier=8), max_cpu_mem=inf, max_gpu_mem=inf, maxiter=20, no_keops=False, no_single_kernel=True, pc_epsilon={torch.float32: 1e-06, torch.float64: 1e-13}, penalty=0.0001, seed=None, use_cpu=True)"
"Falkon(M=404, center_selection=<falkon.center_selection.UniformSel object at 0x7feede906190>, error_every=1, error_fn=None, kernel=GaussianKernel(sigma=5.0), maxiter=20, options=FalkonOptions(keops_acc_dtype='auto', keops_sum_scheme='auto', no_keops=False, chol_force_in_core=False, chol_force_ooc=False, chol_par_blk_multiplier=2, lauum_par_blk_multiplier=8, pc_epsilon_32=1e-05, pc_epsilon_64=1e-13, cpu_preconditioner=False, cg_epsilon_32=1e-07, cg_epsilon_64=1e-15, cg_tolerance=1e-07, cg_full_gradient_every=10, debug=False, use_cpu=False, max_gpu_mem=inf, max_cpu_mem=inf, compute_arch_speed=False, no_single_kernel=True), penalty=0.0001, seed=None)"
]
},
"execution_count": 7,
Expand Down
2 changes: 1 addition & 1 deletion falkon/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.5.0
0.6.0
5 changes: 3 additions & 2 deletions falkon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os

from . import kernels, optim, preconditioner, center_selection
from .models import Falkon, LogisticFalkon
from .options import FalkonOptions
from .models import Falkon, LogisticFalkon, InCoreFalkon

__all__ = ('Falkon', 'LogisticFalkon', 'kernels', 'optim', 'preconditioner', 'center_selection',
__all__ = ('Falkon', 'LogisticFalkon', 'InCoreFalkon',
'kernels', 'optim', 'preconditioner', 'center_selection',
'FalkonOptions')


Expand Down
72 changes: 53 additions & 19 deletions falkon/center_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import torch

from falkon.sparse.sparse_tensor import SparseTensor
from falkon.utils.tensor_helpers import is_f_contig
from falkon.utils.tensor_helpers import create_same_stride

__all__ = ("CenterSelector", "FixedSelector", "UniformSelector")
_tensor_type = Union[torch.Tensor, SparseTensor]


class NySel(ABC):
class CenterSelector(ABC):
def __init__(self, random_gen):
self.random_gen = random_gen

Expand All @@ -20,17 +21,56 @@ def select(self, X, Y, M):
pass


class UniformSel(NySel):
class FixedSelector(CenterSelector):
def __init__(self, centers: _tensor_type, y_centers: Union[torch.Tensor, None] = None,
random_gen=None):
super().__init__(random_gen)
self.centers = centers
self.y_centers = y_centers

def select(self, X, Y, M):
if Y is not None:
return self.centers, self.y_centers
return self.centers


class UniformSelector(CenterSelector):
def __init__(self, random_gen):
super().__init__(random_gen)

def select(self,
X: _tensor_type,
Y: Union[torch.Tensor, None],
M: int) -> Union[_tensor_type, Tuple[_tensor_type, torch.Tensor]]:
"""Select M rows from 2D array `X`, preserving the memory order of `X`.
"""Select M observations from 2D tensor `X`, preserving device and memory order.
The selection strategy is uniformly at random. To control the randomness,
pass an appropriate numpy random generator to this class's constructor.
Parameters
----------
X
N x D tensor containing the whole input dataset. We have that N <= M.
Y
Optional N x T tensor containing the input targets. If `Y` is provided,
the same observations selected for `X` will also be selected from `Y`.
Certain models (such as :class:`falkon.LogisticFalkon`) require centers to be
extracted from both predictors and targets, while others (such as
:class:`falkon.Falkon`) only require the centers from the predictors.
M
The number of observations to choose. M <= N, otherwise M is forcibly set to N
with a warning.
Returns
-------
X_M
The randomly selected centers. They will be in a new, memory-contiguous tensor.
All characteristics of the input tensor will be preserved.
(X_M, Y_M)
If `Y` was different than `None` then the entries of `Y` corresponding to the
selected centers of `X` will also be returned.
"""
N = X.size(0)
N = X.shape[0]
if M > N:
warnings.warn("Number of centers M greater than the "
"number of data-points. Setting M to %d" % (N))
Expand All @@ -42,21 +82,15 @@ def select(self,
centers = X[idx, :].copy()
Xc = SparseTensor.from_scipy(centers)
else:
Xnp = X.numpy() # work on np array
if is_f_contig(X):
order = 'F'
else:
order = 'C'
Xc_np = np.empty((M, Xnp.shape[1]), dtype=Xnp.dtype, order=order)
Xc = torch.from_numpy(np.take(Xnp, idx, axis=0, out=Xc_np, mode='wrap'))
Xc = create_same_stride((M, X.shape[1]), other=X, dtype=X.dtype, device=X.device,
pin_memory=False)
th_idx = torch.from_numpy(idx.astype(np.long)).to(X.device)
torch.index_select(X, dim=0, index=th_idx, out=Xc)

if Y is not None:
Ynp = Y.numpy() # work on np array
if is_f_contig(X):
order = 'F'
else:
order = 'C'
Yc_np = np.empty((M, Ynp.shape[1]), dtype=Ynp.dtype, order=order)
Yc = torch.from_numpy(np.take(Ynp, idx, axis=0, out=Yc_np, mode='wrap'))
Yc = create_same_stride((M, Y.shape[1]), other=Y, dtype=Y.dtype, device=Y.device,
pin_memory=False)
th_idx = torch.from_numpy(idx.astype(np.long)).to(Y.device)
torch.index_select(Y, dim=0, index=th_idx, out=Yc)
return Xc, Yc
return Xc
2 changes: 1 addition & 1 deletion falkon/cuda/cublas_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, status_code):
self.message = CUBLAS_EXCEPTIONS[status_code]
except KeyError:
self.message = "Unknown CuBLAS error %d" % (status_code)
super.__init__(self.message)
super().__init__(self.message)


def cublas_check_status(status):
Expand Down
2 changes: 1 addition & 1 deletion falkon/cuda/cudart_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(self, status_code):
self.message = CUDA_EXCEPTIONS[status_code]
except KeyError:
self.message = "Unknown CUDA error %d" % (status_code)
super.__init__(self.message)
super().__init__(self.message)


def cuda_check_status(status):
Expand Down
2 changes: 1 addition & 1 deletion falkon/cuda/cusolver_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, status_code):
self.message = CUSOLVER_EXCEPTIONS[status_code]
except KeyError:
self.message = "Unknown CuBLAS error %d" % (status_code)
super.__init__(self.message)
super().__init__(self.message)


def cusolver_check_status(status):
Expand Down
8 changes: 4 additions & 4 deletions falkon/kernels/distance_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class GaussianKernel(L2DistanceKernel, KeopsKernelMixin):
Creating a Gaussian kernel with a single length-scale. Operations on this kernel will not
use KeOps.
>>> K = GaussianKernel(sigma=3.0, opt=FalkonOptions(no_keops=True))
>>> K = GaussianKernel(sigma=3.0, opt=FalkonOptions(keops_active="no"))
Creating a Gaussian kernel with a different length-scale per dimension
Expand Down Expand Up @@ -208,7 +208,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
'v = Vj(%d)' % (v.shape[1]),
'g = Pm(1)'
]
other_vars = [torch.tensor([self.gamma]).to(dtype=X1.dtype)]
other_vars = [torch.tensor([self.gamma]).to(device=X1.device, dtype=X1.dtype)]
else:
dim = self.gamma.shape[0]
formula = (
Expand All @@ -222,7 +222,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
'v = Vj(%d)' % (v.shape[1]),
'g = Pm(%d)' % (dim ** 2)
]
other_vars = [self.gamma.reshape(-1).to(dtype=X1.dtype)]
other_vars = [self.gamma.reshape(-1).to(device=X1.device, dtype=X1.dtype)]

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down Expand Up @@ -337,7 +337,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt: FalkonOptions):
'v = Vj(%d)' % (v.shape[1]),
'g = Pm(1)'
]
other_vars = [torch.tensor([self.gamma]).to(dtype=X1.dtype)]
other_vars = [torch.tensor([self.gamma]).to(device=X1.device, dtype=X1.dtype)]

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down
12 changes: 6 additions & 6 deletions falkon/kernels/dot_prod_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
'beta = Pm(1)'
]
other_vars = [
torch.tensor([self.gamma]).to(dtype=X1.dtype),
torch.tensor([self.beta]).to(dtype=X1.dtype)
torch.tensor([self.gamma]).to(dtype=X1.dtype, device=X1.device),
torch.tensor([self.beta]).to(dtype=X1.dtype, device=X1.device)
]
return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down Expand Up @@ -173,8 +173,8 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
'beta = Pm(1)'
]
other_vars = [
torch.tensor([self.alpha]).to(dtype=X1.dtype),
torch.tensor([self.beta]).to(dtype=X1.dtype)
torch.tensor([self.alpha]).to(dtype=X1.dtype, device=X1.device),
torch.tensor([self.beta]).to(dtype=X1.dtype, device=X1.device)
]

is_int_pow = self.degree == self.degree.to(dtype=torch.int32)
Expand All @@ -183,7 +183,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
else:
formula = 'Powf((alpha * (X | Y) + beta), degree) * v'
aliases.append('degree = Pm(1)')
other_vars.append(torch.tensor([self.degree]).to(dtype=X1.dtype))
other_vars.append(torch.tensor([self.degree]).to(dtype=X1.dtype, device=X1.device))

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down Expand Up @@ -228,7 +228,7 @@ def _keops_mmv_impl(self, X1, X2, v, kernel, out, opt):
'v = Vj(%d)' % (v.shape[1]),
'alpha = Pm(1)',
]
other_vars = [torch.tensor([self.alpha]).to(dtype=X1.dtype)]
other_vars = [torch.tensor([self.alpha]).to(dtype=X1.dtype, device=X1.device)]

return self.keops_mmv(X1, X2, v, out, formula, aliases, other_vars, opt)

Expand Down
Loading