Skip to content

Commit

Permalink
Merge pull request #30 from ziatdinovmax/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
ziatdinovmax authored Sep 15, 2020
2 parents a207995 + de567da commit 5029232
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
33 changes: 28 additions & 5 deletions gpim/gpbayes/boptim.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ class boptimizer:
**extent(list of lists):
Define bounds for multi-dimensional data. For example, for 2D data,
the extent parameter is [[xmin, xmax], [ymin, ymax]]
**save_checkpoints (bool):
save results to disk after each iteration
**filename (str):
name of a file to save checkpoints to
**verbose (int):
Level of verbosity (0, 1, or 2)
"""
Expand Down Expand Up @@ -171,7 +175,6 @@ def __init__(self,
learning_rate = kwargs.get("learning_rate", 5e-2)
jitter = kwargs.get("jitter", 1.0e-6)
isotropic = kwargs.get("isotropic", False)

self.precision = kwargs.get("precision", "double")

if self.use_gpu and torch.cuda.is_available():
Expand Down Expand Up @@ -213,6 +216,8 @@ def __init__(self,
self.points_mem = kwargs.get("memory", 10)
self.exit_strategy = kwargs.get("exit_strategy", 0)
self.mask = kwargs.get("mask", None)
self.save_checkpoints = kwargs.get("save_checkpoints", False)
self.filename = kwargs.get("filename", "./boptim_results")
self.indices_all, self.vals_all = [], []
self.target_func_vals, self.gp_predictions = [y_seed.copy()], []

Expand Down Expand Up @@ -375,23 +380,23 @@ def dist(idx):

dscale_ = 0 if self.dscale is None else self.dscale
_idx = 0
if self.verbose:
if self.verbose == 2:
print('Acquisition function max value {} at {}'.format(
val_list[_idx], idx_list[_idx]))
if len(self.indices_all) == 0:
return idx_list[_idx], val_list[_idx]
while (1 in [1 for a in self.indices_all if a == idx_list[_idx]]
or dist(idx_list[_idx])):
if self.verbose:
if self.verbose == 2:
print("Finding the next max point...")
_idx = _idx + 1
if _idx == len(idx_list):
_idx = np.random.randint(0, len(idx_list)) if self.exit_strategy else -1
if self.verbose:
if self.verbose == 2:
print('Index out of list. Exiting with acquisition function value {} at {}'.format(
val_list[_idx], idx_list[_idx]))
break
if self.verbose:
if self.verbose == 2:
print('Acquisition function max value {} at {}'.format(
val_list[_idx], idx_list[_idx]))
return idx_list[_idx], val_list[_idx]
Expand Down Expand Up @@ -426,6 +431,24 @@ def run(self):
"""
for i in range(self.exploration_steps):
self.single_step(i)
if self.save_checkpoints:
self.save_results()
self.save_results()
if self.verbose:
print("\nExploration completed")
return

def save_results(self, *args):
"""
Save indermediary and/or final results
"""
try:
filename = args[0]
except IndexError:
filename = self.filename
results = {}
results['gp_pred'] = self.gp_predictions
results['func_val'] = self.target_func_vals
results['inds_all'] = np.array(self.indices_all)
results['vals_all'] = np.array(self.vals_all)
np.save(filename+".npy", results)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__author__ = "Maxim Ziatdinov"
__copyright__ = "Copyright Maxim Ziatdinov (2020)"
__version__ = "0.3.4"
__version__ = "0.3.5"
__maintainer__ = "Maxim Ziatdinov"
__email__ = "[email protected]"
__date__ = "05/20/2020"
Expand Down

0 comments on commit 5029232

Please sign in to comment.