Callback #73

Merged: 13 commits, Aug 31, 2017
30 changes: 22 additions & 8 deletions gpflowopt/acquisition/acquisition.py
@@ -396,27 +396,31 @@ class MCMCAcquistion(AcquisitionSum):
"""
Apply MCMC over the hyperparameters of an acquisition function (= over the hyperparameters of the contained models).

The models passed into an object of this class are optimized with MLE, and then further sampled with HMC.
These hyperparameter samples are then set in copies of the acquisition.
The models passed into an object of this class are optimized with MLE (fast burn-in), and then further sampled with
HMC. These hyperparameter samples are then set in copies of the acquisition.

For evaluating the underlying acquisition function, the predictions of the acquisition copies are averaged.
"""
def __init__(self, acquisition, n_slices, **kwargs):
assert isinstance(acquisition, Acquisition)
assert n_slices > 0

copies = [copy.deepcopy(acquisition) for _ in range(n_slices - 1)]
for c in copies:
c.optimize_restarts = 0

# the call to the constructor of the parent classes, will optimize acquisition, so it obtains the MLE solution.
super(MCMCAcquistion, self).__init__([acquisition] + copies)
super(MCMCAcquistion, self).__init__([acquisition]*n_slices)
self._needs_new_copies = True
self._sample_opt = kwargs

def _optimize_models(self):
# Optimize model #1
self.operands[0]._optimize_models()

# Copy it again if needed due to changed free state
if self._needs_new_copies:
new_copies = [copy.deepcopy(self.operands[0]) for _ in range(len(self.operands) - 1)]
for c in new_copies:
c.optimize_restarts = 0
self.operands = ParamList([self.operands[0]] + new_copies)
self._needs_new_copies = False

# Draw samples using HMC
# Sample each model of the acquisition function - results in a list of 2D ndarrays.
hypers = np.hstack([model.sample(len(self.operands), **self._sample_opt) for model in self.models])
@@ -440,3 +444,13 @@ def set_data(self, X, Y):
def build_acquisition(self, Xcand):
# Average the predictions of the copies.
return 1. / len(self.operands) * super(MCMCAcquistion, self).build_acquisition(Xcand)
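
In other words, with the operands holding N hyperparameter draws theta_1, ..., theta_N, the marginalized acquisition is the sample average alpha(x) = (1/N) * sum_{i=1..N} alpha(x; theta_i) over the copies.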

def _kill_autoflow(self):
"""
Flag for recreation on next optimize.

Following the recompilation of models, the free state might have changed. This means updating the samples can
cause inconsistencies and errors.
"""
super(MCMCAcquistion, self)._kill_autoflow()
self._needs_new_copies = True
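
In practice this wrapper is typically created by BayesianOptimizer when hyper_draws is set (see bo.py below). A minimal usage sketch, in which domain, model and objective are placeholders for a user-defined problem:

    import gpflowopt

    # domain, model and objective are placeholders, not part of this PR
    acquisition = gpflowopt.acquisition.ExpectedImprovement(model)
    optimizer = gpflowopt.BayesianOptimizer(domain, acquisition, hyper_draws=10)
    result = optimizer.optimize(objective, n_iter=15)  # each score is averaged over the 10 HMC draws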
50 changes: 46 additions & 4 deletions gpflowopt/bo.py
@@ -16,12 +16,40 @@

import numpy as np
from scipy.optimize import OptimizeResult
import tensorflow as tf
from gpflow.gpr import GPR

from .acquisition import Acquisition, MCMCAcquistion
from .optim import Optimizer, SciPyOptimizer
from .objective import ObjectiveWrapper
from .design import Design, EmptyDesign
from .objective import ObjectiveWrapper
from .optim import Optimizer, SciPyOptimizer
from .pareto import non_dominated_sort
from .models import ModelWrapper


def jitchol_callback(models):
"""
Increase the likelihood variance in case of Cholesky failures.

This is similar to the use of jitchol in GPy. Default callback for BayesianOptimizer.
Only usable on GPR models, other types are ignored.
"""
for m in np.atleast_1d(models):
if isinstance(m, ModelWrapper):
jitchol_callback(m.wrapped) # pragma: no cover

if not isinstance(m, GPR):
continue
Contributor: maybe show a warning?


s = m.get_free_state()
eKdiag = np.mean(np.diag(m.kern.compute_K_symm(m.X.value)))
for e in [0] + [10**ex for ex in range(-6,-1)]:
try:
m.likelihood.variance = m.likelihood.variance.value + e * eKdiag
m.optimize(maxiter=5)
break
except tf.errors.InvalidArgumentError: # pragma: no cover
m.set_state(s)
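
One way the warning suggested in the review comment above could look, as a sketch using Python's standard warnings module (the helper name is illustrative):

    import warnings
    import numpy as np
    from gpflow.gpr import GPR

    def warn_on_non_gpr(models):
        # Emit a warning for every model that jitchol_callback silently ignores,
        # instead of skipping it without feedback.
        for m in np.atleast_1d(models):
            if not isinstance(m, GPR):
                warnings.warn("jitchol_callback ignores %s: only GPR models are supported"
                              % type(m).__name__)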


class BayesianOptimizer(Optimizer):
@@ -32,7 +60,8 @@ class BayesianOptimizer(Optimizer):
Additionally, it is configured with a separate optimizer for the acquisition function.
"""

def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=True, hyper_draws=None):
def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=True, hyper_draws=None,
callback=jitchol_callback):
"""
:param Domain domain: The optimization space.
:param Acquisition acquisition: The acquisition function to optimize over the domain.
@@ -51,6 +80,12 @@ def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=Tr
are obtained using Hamiltonian MC.
(see `GPflow documentation <https://gpflow.readthedocs.io/en/latest//>`_ for details) for each model.
The acquisition score is computed for each draw, and averaged.
:param callable callback: (optional) this function or object is called after the data of all models has been
updated, with as argument the models as retrieved by acquisition.models (without the wrapping models that
handle scaling). This allows custom model optimization strategies to be implemented.
All manipulations of GPflow models are permitted. Combined with the optimize_restarts parameter of
:class:`~.Acquisition` this allows several scenarios: do the optimization manually from the callback
(optimize_restarts equals 0), or choose the starting point plus some random restarts (optimize_restarts > 0).
"""
assert isinstance(acquisition, Acquisition)
assert hyper_draws is None or hyper_draws > 0
@@ -69,6 +104,8 @@ def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=Tr
initial = initial or EmptyDesign(domain)
self.set_initial(initial.generate())

self._model_callback = callback
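
As an example of the optimize_restarts = 0 scenario from the docstring, a sketch of a user-supplied callback that takes over the hyperparameter optimization (domain and acquisition are assumed to be constructed already, and maxiter is illustrative):

    import gpflowopt

    def manual_optimization_callback(models):
        # With optimize_restarts = 0 on the acquisition the built-in MLE step is skipped,
        # so the callback itself fits the (unwrapped) GPflow models.
        for m in models:
            m.optimize(maxiter=50)

    acquisition.optimize_restarts = 0
    optimizer = gpflowopt.BayesianOptimizer(domain, acquisition,
                                            callback=manual_optimization_callback)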

@Optimizer.domain.setter
def domain(self, dom):
assert self.domain.size == dom.size
@@ -86,6 +123,8 @@ def _update_model_data(self, newX, newY):
assert self.acquisition.data[0].shape[1] == newX.shape[-1]
assert self.acquisition.data[1].shape[1] == newY.shape[-1]
assert newX.shape[0] == newY.shape[0]
if newX.size == 0:
return
X = np.vstack((self.acquisition.data[0], newX))
Y = np.vstack((self.acquisition.data[1], newY))
self.acquisition.set_data(X, Y)
@@ -174,7 +213,6 @@ def _optimize(self, fx, n_iter):
:param n_iter: number of iterations to run
:return: OptimizeResult object
"""

assert isinstance(fx, ObjectiveWrapper)

# Evaluate and add the initial design (if any)
Expand All @@ -190,6 +228,10 @@ def inverse_acquisition(x):

# Optimization loop
for i in range(n_iter):
# If a callback is specified, and acquisition has the setup flag enabled (indicating an upcoming
# compilation), run the callback.
if self._model_callback and self.acquisition._needs_setup:
self._model_callback([m.wrapped for m in self.acquisition.models])
result = self.optimizer.optimize(inverse_acquisition)
self._update_model_data(result.x, fx(result.x))

19 changes: 16 additions & 3 deletions testing/test_acquisition.py
@@ -146,7 +146,6 @@ def test_object_integrity(self, acquisition):
for oper in acquisition.operands:
self.assertTrue(isinstance(oper, gpflowopt.acquisition.Acquisition),
msg="All operands should be an acquisition object")

self.assertTrue(all(isinstance(m, gpflowopt.models.ModelWrapper) for m in acquisition.models))

@parameterized.expand(list(zip(aggregations)))
@@ -218,9 +217,23 @@ def test_marginalized_score(self, acquisition):
ei_mcmc = acquisition.evaluate(Xt)
np.testing.assert_almost_equal(ei_mle, ei_mcmc, decimal=5)

@parameterized.expand(list(zip([aggregations[2]])))
def test_mcmc_acq_models(self, acquisition):
def test_mcmc_acq(self):
acquisition = gpflowopt.acquisition.MCMCAcquistion(
gpflowopt.acquisition.ExpectedImprovement(create_parabola_model(domain)), 10)
for oper in acquisition.operands:
self.assertListEqual(acquisition.models, oper.models)
self.assertEqual(acquisition.operands[0], oper)
self.assertTrue(acquisition._needs_new_copies)
acquisition._optimize_models()
self.assertListEqual(acquisition.models, acquisition.operands[0].models)
for oper in acquisition.operands[1:]:
self.assertNotEqual(acquisition.operands[0], oper)
self.assertFalse(acquisition._needs_new_copies)
acquisition._setup()
Xt = np.random.rand(20, 2) * 2 - 1
ei_mle = acquisition.operands[0].evaluate(Xt)
ei_mcmc = acquisition.evaluate(Xt)
np.testing.assert_almost_equal(ei_mle, ei_mcmc, decimal=5)


class TestJointAcquisition(unittest.TestCase):
70 changes: 68 additions & 2 deletions testing/test_optimizers.py
@@ -214,8 +214,8 @@ def test_optimize_multi_objective(self):
result = optimizer.optimize(vlmop2, n_iter=2)
self.assertTrue(result.success)
self.assertEqual(result.nfev, 2, "Only 2 evaluations permitted")
self.assertTupleEqual(result.x.shape, (9, 2))
self.assertTupleEqual(result.fun.shape, (9, 2))
self.assertTupleEqual(result.x.shape, (7, 2))
self.assertTupleEqual(result.fun.shape, (7, 2))
_, dom = gpflowopt.pareto.non_dominated_sort(result.fun)
self.assertTrue(np.all(dom==0))

@@ -288,6 +288,71 @@ def test_mcmc(self):
self.assertTrue(np.allclose(result.x, 0), msg="Optimizer failed to find optimum")
self.assertTrue(np.allclose(result.fun, 0), msg="Incorrect function value returned")

def test_callback(self):
class DummyCallback(object):
def __init__(self):
self.counter = 0

def __call__(self, models):
self.counter += 1
Contributor: Let's think about the callback signature some more. Is there any information we want to pass that might be useful for model building?

For instance, to let the model-building strategy depend on the iteration number (we could stop optimizing the hyperparameters after a while, as in the MES paper). Although we can also look at the data set size.

What about model-building strategies that change model.X and model.Y (like replacing clusters, etc.)? Not sure if that fits here or is even relevant (the GPflow model should be able to cope with it).

Member Author: I think the model contains all the data you need to accomplish something. I believe X and Y can even be updated in this callback as long as the model supports it (all models in GPflow do). If at some point some information is really missing, this can be added.


c = DummyCallback()
optimizer = gpflowopt.BayesianOptimizer(self.domain, self.acquisition, callback=c)
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=2)
self.assertEqual(c.counter, 2)
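
Following up on the callback-signature discussion above, a data-modifying strategy can also live in the callback. A sketch, assuming GPflow 0.x models whose X and Y data holders accept new arrays of matching dimensionality (the class and its parameters are illustrative):

    class SubsampleCallback(object):
        # Before the models are (re)optimized, keep only the most recent points,
        # e.g. to bound the cost of GP inference.
        def __init__(self, max_points=50):
            self.max_points = max_points

        def __call__(self, models):
            for m in models:
                X, Y = m.X.value, m.Y.value
                if X.shape[0] > self.max_points:
                    m.X = X[-self.max_points:, :]
                    m.Y = Y[-self.max_points:, :]

As with DummyCallback, an instance would be passed as BayesianOptimizer(self.domain, self.acquisition, callback=SubsampleCallback()).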

def test_callback_recompile(self):
class DummyCallback(object):
def __init__(self):
self.recompile = False

def __call__(self, models):
c = np.random.randint(2, 10)
models[0].kern.variance.prior = gpflow.priors.Gamma(c, 1./c)
self.recompile = models[0]._needs_recompile

c = DummyCallback()
optimizer = gpflowopt.BayesianOptimizer(self.domain, self.acquisition, callback=c)
self.acquisition.evaluate(np.zeros((1,2))) # Make sure it has run and is set up, so the first iteration skips setup
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertFalse(c.recompile)
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertTrue(c.recompile)
self.assertFalse(self.acquisition.models[0]._needs_recompile)

def test_callback_recompile_mcmc(self):
class DummyCallback(object):
def __init__(self):
self.no_models = 0

def __call__(self, models):
c = np.random.randint(2, 10)
models[0].kern.variance.prior = gpflow.priors.Gamma(c, 1. / c)
self.no_models = len(models)

c = DummyCallback()
optimizer = gpflowopt.BayesianOptimizer(self.domain, self.acquisition, hyper_draws=5, callback=c)
opers = optimizer.acquisition.operands
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertEqual(c.no_models, 1)
self.assertEqual(id(opers[0]), id(optimizer.acquisition.operands[0]))
for op1, op2 in zip(opers[1:], optimizer.acquisition.operands[1:]):
self.assertNotEqual(id(op1), id(op2))
opers = optimizer.acquisition.operands
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertEqual(id(opers[0]), id(optimizer.acquisition.operands[0]))
for op1, op2 in zip(opers[1:], optimizer.acquisition.operands[1:]):
self.assertNotEqual(id(op1), id(op2))

def test_nongpr_model(self):
design = gpflowopt.design.LatinHyperCube(16, self.domain)
X, Y = design.generate(), parabola2d(design.generate())[0]
m = gpflow.vgp.VGP(X, Y, gpflow.kernels.RBF(2, ARD=True), likelihood=gpflow.likelihoods.Gaussian())
acq = gpflowopt.acquisition.ExpectedImprovement(m)
optimizer = gpflowopt.BayesianOptimizer(self.domain, acq)
result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
self.assertTrue(result.success)


class TestSilentOptimization(unittest.TestCase):
@contextmanager
@@ -323,3 +388,4 @@ def _optimize(self, objective):
opt.optimize(None)
output = out.getvalue().strip()
self.assertEqual(output, '')