Merge pull request #15 from ziatdinovmax/dev
Dev
ziatdinovmax committed Apr 14, 2020
2 parents 2232901 + 47cef71 commit c6ff323
Showing 10 changed files with 200 additions and 69 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -11,9 +11,8 @@

## What is GPim?

-GPim is a Python package that provides a systematic and easy way to apply Gaussian processes (GP)
-to images and hyperspectral data in [Pyro](https://pyro.ai/) and [GPyTorch](https://gpytorch.ai/) frameworks
-(without a need to learn those frameworks).
+GPim is a Python package that provides an easy way to apply Gaussian processes (GP) and GP-based Bayesian optimization
+to images and hyperspectral data in [Pyro](https://pyro.ai/) and [GPyTorch](https://gpytorch.ai/) without a need to learn those frameworks.

For the examples, see our papers:

18 changes: 15 additions & 3 deletions gpim/gpbayes/acqfunc.py
Original file line number Diff line number Diff line change
@@ -23,11 +23,15 @@ def confidence_bound(gpmodel, X_full, **kwargs):
:math:`\\alpha` coefficient in :math:`\\alpha \\mu + \\beta \\sigma`
**beta (float):
:math:`\\beta` coefficient in :math:`\\alpha \\mu + \\beta \\sigma`
+Returns:
+Acquisition function and GP prediction (mean + standard deviation)
"""
alpha = kwargs.get("alpha", 0)
beta = kwargs.get("beta", 1)
mean, sd = gpmodel.predict(X_full, verbose=0)
-return alpha * mean + beta * sd
+acq = alpha * mean + beta * sd
+return acq, (mean, sd)
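With this change, `confidence_bound` returns the acquisition values together with the GP prediction instead of discarding the latter. For illustration, a minimal self-contained sketch of the confidence-bound rule itself, using toy NumPy arrays in place of the model's `predict` output:

```python
import numpy as np

# Toy stand-ins for the (mean, sd) returned by gpmodel.predict
mean = np.array([0.1, 0.5, 0.3, 0.9])
sd = np.array([0.2, 0.05, 0.4, 0.1])

alpha, beta = 0, 1  # the defaults read from kwargs above
acq = alpha * mean + beta * sd  # with alpha=0 this is pure uncertainty-driven exploration
next_idx = int(np.argmax(acq))  # index of the next point to query
```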


def expected_improvement(gpmodel, X_full, X_sparse, **kwargs):
@@ -44,6 +48,9 @@ def expected_improvement(gpmodel, X_full, X_sparse, **kwargs):
Sparse grid indices
**xi (float):
xi constant value
+Returns:
+Acquisition function and GP prediction (mean + standard deviation)
"""
xi = kwargs.get("xi", 0.01)
mean, sd = gpmodel.predict(X_full, verbose=0)
@@ -52,7 +59,8 @@ def expected_improvement(gpmodel, X_full, X_sparse, **kwargs):
mean_sample_opt = np.nanmax(mean_sample)
imp = mean - mean_sample_opt - xi
z = imp / sd
-return imp * norm.cdf(z) + sd * norm.pdf(z)
+acq = imp * norm.cdf(z) + sd * norm.pdf(z)
+return acq, (mean, sd)
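The body above implements the standard expected-improvement criterion: with imp = μ − μ⁺ − ξ and z = imp/σ, EI = imp·Φ(z) + σ·φ(z), where μ⁺ is the best posterior mean over the already-measured (sparse) points. A standalone sketch with toy values in place of the GP posterior:

```python
import numpy as np
from scipy.stats import norm

mean = np.array([0.2, 0.7, 0.4])  # toy GP posterior mean on the full grid
sd = np.array([0.3, 0.1, 0.25])   # toy GP posterior standard deviation
mean_sample_opt = 0.6             # best mean over already-measured points
xi = 0.01                         # the default exploration constant above

imp = mean - mean_sample_opt - xi
z = imp / sd
ei = imp * norm.cdf(z) + sd * norm.pdf(z)  # expected improvement per grid point
```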


def probability_of_improvement(gpmodel, X_full, X_sparse, **kwargs):
@@ -69,6 +77,9 @@ def probability_of_improvement(gpmodel, X_full, X_sparse, **kwargs):
Sparse grid indices
**xi (float):
xi constant value
+Returns:
+Acquisition function and GP prediction (mean + standard deviation)
"""
xi = kwargs.get("xi", 0.01)
mean, sd = gpmodel.predict(X_full, verbose=0)
@@ -77,4 +88,5 @@ def probability_of_improvement(gpmodel, X_full, X_sparse, **kwargs):
mean_sample_opt = np.nanmax(mean_sample)
z = mean - mean_sample_opt - xi
z = z / sd
-return norm.cdf(z)
+acq = norm.cdf(z)
+return acq, (mean, sd)
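Probability of improvement is the simpler Φ(z) with the same z. Since all three acquisition functions now return `(acq, (mean, sd))`, callers unpack two values; a short sketch of the pattern (the toy posterior reuses the arrays above, and the commented call assumes a `gpmodel` exposing the `predict` interface used in this module):

```python
import numpy as np
from scipy.stats import norm

# PoI on the same toy posterior:
mean = np.array([0.2, 0.7, 0.4])
sd = np.array([0.3, 0.1, 0.25])
z = (mean - 0.6 - 0.01) / sd
poi = norm.cdf(z)  # probability that each point beats the current best by at least xi

# New two-value return signature (sketch; gpmodel, X_full, X_sparse assumed):
# acq, (mu, sigma) = acqfunc.expected_improvement(gpmodel, X_full, X_sparse, xi=0.01)
```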
54 changes: 37 additions & 17 deletions gpim/gpbayes/boptim.py
@@ -30,12 +30,12 @@ class boptimizer:
Args:
X_seed (ndarray):
-Seed sparse grid indices with dimensions :math:`c \\times N \\times M`
+Seeded sparse grid indices with dimensions :math:`c \\times N \\times M`
or :math:`c \\times N \\times M \\times L`
where *c* is equal to the number of coordinates
(for example, for *xyz* coordinates, *c* = 3)
y_seed (ndarray):
-Seed sparse "observations" (data points) with dimensions
+Seeded sparse "observations" (data points) with dimensions
:math:`N \\times M` or :math:`N \\times M \\times L`.
Typically, for 2D image *N* and *M* are image height and width,
whereas for 3D hyperspectral data *N* and *M* are spatial dimensions
@@ -86,19 +86,22 @@ class boptimizer:
**beta (float or int):
beta coefficient in the 'confidence bound' acquisition function
(Default: 1)
**xi (float):
xi coefficient in 'expected improvement'
and 'probability of improvement' acquisition functions
**use_gpu (bool):
Uses GPU hardware accelerator when set to 'True'.
Notice that for large datasets training model without GPU
is extremely slow.
-**verbose (int):
-Level of verbosity (0, 1, or 2)
**lscale (float):
Lengthscale determining the separation (Euclidean)
distance between query points. Defaults to the kernel
lengthscale at a given step
+**extent (list of lists):
+Define multi-dimensional data bounds. For example, for 2D data,
+the extent parameter is [[xmin, xmax], [ymin, ymax]]
+**verbose (int):
+Level of verbosity (0, 1, or 2)
"""
def __init__(self,
X_seed,
@@ -133,24 +136,29 @@ def __init__(self,
X_seed, y_seed, X_full, kernel, lengthscale, sparse, indpoints,
learning_rate, iterations, self.use_gpu, self.verbose, seed)

-self.X_sparse = X_seed
-self.y_sparse = y_seed
+self.X_sparse = X_seed.copy()
+self.y_sparse = y_seed.copy()
self.X_full = X_full
if self.use_gpu and torch.cuda.is_available():
self.X_sparse = self.X_sparse.cuda()
self.y_sparse = self.y_sparse.cuda()

self.target_function = target_function
self.acquisition_function = acquisition_function
self.exploration_steps = exploration_steps
self.batch_update = batch_update
self.batch_size = batch_size
+self.simulate_measurement = kwargs.get("simulate_measurement", False)
+if self.simulate_measurement:
+    self.y_true = kwargs.get("y_true")
+    if self.y_true is None:
+        raise AssertionError(
+            "To simulate measurements, add ground truth ('y_true')")
+self.extent = kwargs.get("extent", None)
self.alpha = kwargs.get("alpha", 0)
self.beta = kwargs.get("beta", 1)
self.xi = kwargs.get("xi", 0.01)
self.lscale = kwargs.get("lscale", None)
self.indices_all, self.vals_all = [], []
self.target_func_vals_all = [y_seed.copy()]
self.gp_predictions = []

def update_posterior(self):
"""
@@ -170,9 +178,20 @@ def evaluate_function(self, indices):
Evaluates target function in the new point(s)
"""
indices = [indices] if not self.batch_update else indices
-for idx in indices:
-    self.y_sparse[tuple(idx)] = self.target_function(idx)
-self.X_sparse = gprutils.get_sparse_grid(self.y_sparse)
+if self.simulate_measurement:
+    for idx in indices:
+        self.y_sparse[tuple(idx)] = self.y_true[tuple(idx)]
+else:
+    for idx in indices:
+        if self.extent is not None:
+            _idx = []
+            for i, e in zip(idx, self.extent):
+                _idx.append(i + e[0])
+            _idx = tuple(_idx)
+        else:
+            _idx = tuple(idx)
+        self.y_sparse[tuple(idx)] = self.target_function(_idx)
+self.X_sparse = gprutils.get_sparse_grid(self.y_sparse, self.extent)
self.target_func_vals_all.append(self.y_sparse.copy())
return
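When `extent` is provided, `evaluate_function` shifts each grid index by the lower bound of the corresponding axis, so the target function receives physical coordinates while measurements are stored at grid indices. A minimal sketch of that mapping with a hypothetical extent:

```python
# Hypothetical bounds: x in [-5, 5], y in [0, 10]
extent = [[-5, 5], [0, 10]]
idx = (3, 7)  # index into the measurement grid

# Same shift as the loop above: add each axis minimum to the grid index
phys = tuple(i + e[0] for i, e in zip(idx, extent))
print(phys)  # (-2, 7) -- passed to target_function; y_sparse is still indexed by idx
```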

@@ -182,20 +201,21 @@ def next_point(self):
"""
indices_list, vals_list = [], []
if self.acquisition_function == 'cb':
-acq = acqfunc.confidence_bound(
+acq, pred = acqfunc.confidence_bound(
self.surrogate_model, self.X_full,
alpha=self.alpha, beta=self.beta)
elif self.acquisition_function == 'ei':
-acq = acqfunc.expected_improvement(
+acq, pred = acqfunc.expected_improvement(
self.surrogate_model, self.X_full,
self.X_sparse, xi=self.xi)
elif self.acquisition_function == 'poi':
-acq = acqfunc.probability_of_improvement(
+acq, pred = acqfunc.probability_of_improvement(
self.surrogate_model, self.X_full,
self.X_sparse, xi=self.xi)
else:
raise NotImplementedError(
"Choose between 'cb', 'ei', and 'poi' acquisition functions")
+self.gp_predictions.append(pred)
for i in range(self.batch_size):
amax_idx = [i[0] for i in np.where(acq == acq.max())]
indices_list.append(amax_idx)
@@ -286,7 +306,7 @@ def single_step(self, *args):
e = args[0]
if self.verbose:
print("\nExploration step {} / {}".format(
-e, self.exploration_steps))
+e+1, self.exploration_steps))
# train with seeded data
if e == 0:
self.surrogate_model.train()
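Taken together, the boptim.py changes let the optimizer replay a known ground truth instead of calling a live measurement function. A hypothetical driver sketch: the keyword names below come from this diff, but the positional signature and the exploration entry point are not shown here, so treat them as assumptions:

```python
import numpy as np
from gpim.gpbayes import boptim

# Toy ground truth and a sparsely seeded copy (NaN = not yet measured)
y_true = np.random.randn(50, 50)
y_seed = np.full_like(y_true, np.nan)
rows, cols = np.random.randint(0, 50, 20), np.random.randint(0, 50, 20)
y_seed[rows, cols] = y_true[rows, cols]

X_full = np.mgrid[0:50, 0:50].astype(float)          # c x N x M grid of indices
X_seed = np.where(np.isnan(y_seed), np.nan, X_full)  # NaN out unmeasured locations

boto = boptim.boptimizer(
    X_seed, y_seed, X_full,
    target_function=None,        # assumption: unused when measurements are simulated
    acquisition_function='cb',   # 'cb', 'ei', or 'poi' (see next_point above)
    exploration_steps=30,
    simulate_measurement=True,   # new kwarg: draw "measurements" from y_true
    y_true=y_true,
    extent=[[0, 50], [0, 50]],   # new kwarg: physical data bounds per axis
    beta=1)
# The exploration loop itself (cf. single_step above) then runs for exploration_steps steps.
```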
28 changes: 20 additions & 8 deletions gpim/gpreg/gpr.py
@@ -66,6 +66,8 @@ class reconstructor:
seed (int):
for reproducibility
**amplitude (float): kernel variance or amplitude squared
+**precision (str):
+Choose between single ('single') and double ('double') precision
"""
def __init__(self,
X,
@@ -85,6 +87,13 @@ def __init__(self,
Initiates reconstructor parameters
and pre-processes training and test data arrays
"""
+self.precision = kwargs.get("precision", "double")
+if self.precision == 'single':
+    self.tensor_type = torch.FloatTensor
+    self.tensor_type_gpu = torch.cuda.FloatTensor
+else:
+    self.tensor_type = torch.DoubleTensor
+    self.tensor_type_gpu = torch.cuda.DoubleTensor
self.verbose = verbose
torch.manual_seed(seed)
pyro.set_rng_seed(seed)
@@ -94,28 +103,30 @@ def __init__(self,
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
-torch.set_default_tensor_type(torch.cuda.DoubleTensor)
+torch.set_default_tensor_type(self.tensor_type_gpu)
use_gpu = True
else:
-torch.set_default_tensor_type(torch.DoubleTensor)
+torch.set_default_tensor_type(self.tensor_type)
use_gpu = False
input_dim = np.ndim(y)
-self.X, self.y = gprutils.prepare_training_data(X, y)
+self.X, self.y = gprutils.prepare_training_data(
+    X, y, precision=self.precision)
self.do_sparse = sparse
-if lengthscale is None and kwargs.get("isotropic") is None:
+if lengthscale is None and not kwargs.get("isotropic"):
lengthscale = [[0. for l in range(input_dim)],
[np.mean(y.shape) / 2 for l in range(input_dim)]] # TODO Make separate lscale for each dim
elif lengthscale is None and kwargs.get("isotropic"):
lengthscale = [0., np.mean(y.shape) / 2]
kernel = pyro_kernels.get_kernel(
kernel, input_dim, lengthscale, use_gpu,
-amplitude=kwargs.get('amplitude'))
+amplitude=kwargs.get('amplitude'), precision=self.precision)
if Xtest is not None:
self.fulldims = Xtest.shape[1:]
else:
self.fulldims = X.shape[1:]
if Xtest is not None:
-self.Xtest = gprutils.prepare_test_data(Xtest)
+self.Xtest = gprutils.prepare_test_data(
+    Xtest, precision=self.precision)
else:
self.Xtest = Xtest
if use_gpu:
@@ -216,7 +227,8 @@ def predict(self, Xtest=None, **kwargs):
UserWarning)
self.Xtest = self.X
elif Xtest is not None:
-self.Xtest = gprutils.prepare_test_data(Xtest)
+self.Xtest = gprutils.prepare_test_data(
+    Xtest, precision=self.precision)
self.fulldims = Xtest.shape[1:]
if next(self.model.parameters()).is_cuda:
self.Xtest = self.Xtest.cuda()
@@ -255,7 +267,7 @@ def run(self, **kwargs):
mean, sd = self.predict()
if next(self.model.parameters()).is_cuda:
self.model.cpu()
-torch.set_default_tensor_type(torch.DoubleTensor)
+torch.set_default_tensor_type(self.tensor_type)
self.X, self.y = self.X.cpu(), self.y.cpu()
self.Xtest = self.Xtest.cpu()
torch.cuda.empty_cache()
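The new `precision` kwarg threads a single dtype choice through data preparation, the kernel, and torch's default tensor type. A small sketch of the mechanism, assuming the torch 1.x API used throughout this file:

```python
import torch

precision = 'single'  # or 'double' (this file's default)
tensor_type = torch.FloatTensor if precision == 'single' else torch.DoubleTensor

torch.set_default_tensor_type(tensor_type)  # every new tensor now uses this dtype
x = torch.zeros(3)
print(x.dtype)  # torch.float32 for 'single'; torch.float64 for 'double'
```

Single precision roughly halves memory use, which presumably matters most for the GPU path above.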
31 changes: 21 additions & 10 deletions gpim/gpreg/skgpr.py
@@ -70,6 +70,8 @@ class skreconstructor:
Number of batches for splitting the Xtest array
(for large datasets, you may not have enough GPU memory
to process the entire dataset at once)
+**precision (str):
+Choose between single ('single') and double ('double') precision
"""
def __init__(self,
X,
@@ -88,21 +90,28 @@ def __init__(self,
Initiates reconstructor parameters
and pre-processes training and test data arrays
"""
+self.precision = kwargs.get("precision", "double")
+if self.precision == 'single':
+    self.tensor_type = torch.FloatTensor
+    self.tensor_type_gpu = torch.cuda.FloatTensor
+else:
+    self.tensor_type = torch.DoubleTensor
+    self.tensor_type_gpu = torch.cuda.DoubleTensor
torch.manual_seed(seed)
if use_gpu and torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
-torch.set_default_tensor_type(torch.cuda.DoubleTensor)
+torch.set_default_tensor_type(self.tensor_type_gpu)
input_dim = np.ndim(y)
if Xtest is not None:
self.fulldims = Xtest.shape[1:]
else:
self.fulldims = X.shape[1:]
-X, y = gprutils.prepare_training_data(X, y)
+X, y = gprutils.prepare_training_data(X, y, precision=self.precision)
if Xtest is not None:
-Xtest = gprutils.prepare_test_data(Xtest)
+Xtest = gprutils.prepare_test_data(Xtest, precision=self.precision)
self.X, self.y, self.Xtest = X, y, Xtest
self.do_ski = sparse
self.toeplitz = gpytorch.settings.use_toeplitz(True)
@@ -114,12 +123,12 @@ def __init__(self,
self.Xtest = self.Xtest.cuda()
self.toeplitz = gpytorch.settings.use_toeplitz(False)
else:
-torch.set_default_tensor_type(torch.DoubleTensor)
+torch.set_default_tensor_type(self.tensor_type)
self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
isotropic = kwargs.get("isotropic")
_kernel = gpytorch_kernels.get_kernel(
-kernel, input_dim, use_gpu,
-lengthscale=lengthscale, isotropic=isotropic)
+kernel, input_dim, use_gpu, lengthscale=lengthscale,
+isotropic=isotropic, precision=self.precision)
grid_points_ratio = kwargs.get("grid_points_ratio", 1.)
self.model = skgprmodel(self.X, self.y,
_kernel, self.likelihood, input_dim,
@@ -216,7 +225,8 @@ def predict(self, Xtest=None, **kwargs):
UserWarning)
self.Xtest = self.X
elif Xtest is not None:
-self.Xtest = gprutils.prepare_test_data(Xtest)
+self.Xtest = gprutils.prepare_test_data(
+    Xtest, precision=self.precision)
self.fulldims = Xtest.shape[1:]
if next(self.model.parameters()).is_cuda:
self.Xtest = self.Xtest.cuda()
@@ -229,8 +239,9 @@ def predict(self, Xtest=None, **kwargs):
self.model.eval()
self.likelihood.eval()
batch_range = len(self.Xtest) // self.num_batches
-mean = np.zeros((self.Xtest.shape[0]))
-sd = np.zeros((self.Xtest.shape[0]))
+dtype_ = np.float32 if self.precision == 'single' else np.float64
+mean = np.zeros((self.Xtest.shape[0]), dtype_)
+sd = np.zeros((self.Xtest.shape[0]), dtype_)
if self.verbose:
print('Calculating predictive mean and uncertainty...')
for i in range(self.num_batches):
@@ -255,7 +266,7 @@ def run(self):
mean, sd = self.predict()
if next(self.model.parameters()).is_cuda:
self.model.cpu()
-torch.set_default_tensor_type(torch.DoubleTensor)
+torch.set_default_tensor_type(self.tensor_type)
self.X, self.y = self.X.cpu(), self.y.cpu()
self.Xtest = self.Xtest.cpu()
torch.cuda.empty_cache()
Expand Down