Multiobjective Multifidelity BO using the Service API #2514

Open
Abrikosoff opened this issue Jun 12, 2024 · 11 comments

@Abrikosoff

Abrikosoff commented Jun 12, 2024

Dear Ax Team,

I am currently trying to run a multi-objective multi-fidelity (MOMF) use case with the Service API. I have mostly been consulting the BoTorch tutorials for this, and so far I have come up with the following repro, which runs, but I am not sure it is working correctly (details below):

import os

import torch
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient, ObjectiveProperties
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.knowledge_gradient import qMultiFidelityKnowledgeGradient
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.models.cost import AffineFidelityCostModel
from botorch.test_functions.multi_objective_multi_fidelity import MOMFBraninCurrin

tkwargs = {  # tkwargs is a dictionary containing the tensor dtype and device
    "dtype": torch.double,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}
SMOKE_TEST = os.environ.get("SMOKE_TEST")

# Branin-Currin with a fidelity input: 2 design dims + 1 fidelity dim, 2 objectives
BC = MOMFBraninCurrin(negate=True).to(**tkwargs)
dim_x = BC.dim
dim_y = BC.num_objectives

ref_point = torch.zeros(dim_y, **tkwargs)
target_fidelities = {2: 1.0}  # parameter index 2 (x3) is the fidelity; target value 1.0
cost_intercept = 0.0

cost_model = AffineFidelityCostModel(fixed_cost=5.0)
cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)


def project(X):
    return project_to_target_fidelity(X=X, target_fidelities=target_fidelities)


generation_strategy = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=1,  # https://github.com/facebook/Ax/issues/922
            min_trials_observed=1,
            max_parallelism=6,
            model_kwargs={"seed": 9999},
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                "botorch_acqf_class": qMultiFidelityKnowledgeGradient,
            },
            model_gen_kwargs={
                "model_gen_options": {
                    "acqf_kwargs": {
                        "cost_intercept": cost_intercept,
                        "num_fantasies": 2,
                        # "num_pareto": 1,
                        # "current_value": 1,
                        "cost_aware_utility": cost_aware_utility,
                        # "target_fidelities": normalized_target_fidelities,
                        "project": project,
                    },
                },
            },
        ),
    ]
)

ax_client = AxClient(generation_strategy=generation_strategy)

ax_client.create_experiment(
    name="momf_branin_currin_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
            "is_fidelity": True,
            "target_value": 1.0,
        },
    ],
    # Multi-objective optimization of MOMF Branin-Currin (2 design dims + 1 fidelity dim).
    objectives={
        "a": ObjectiveProperties(minimize=False, threshold=float(BC.ref_point[0])),
        "b": ObjectiveProperties(minimize=False, threshold=float(BC.ref_point[1])),
    },
)


# Multiobjective optimization
def evaluate(parameters):
    x = torch.tensor(
        [parameters.get("x1"), parameters.get("x2"), parameters.get("x3")], **tkwargs
    )
    evaluation = BC(x)
    # In our case, standard error is 0, since we are computing a synthetic function.
    # Set standard error to None if the noise level is unknown.
    return {"a": (evaluation[0].item(), 0.0), "b": (evaluation[1].item(), 0.0)}


for i in range(25):
    parameterization, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameterization))

I have four questions regarding this repro:

  1. It runs, but is this due to some mistakes in the setup such that I have inadvertently simplified this problem to be a non-MF MOBO problem?
  2. I would have expected that for a MFMOBO problem, the correct acqf to use would have been the qMultiFidelityHypervolumeKnowledgeGradient, but here it seems that qMultiFidelityKnowledgeGradient works as well. Why is that?
  3. In the current definition of the evaluate function, I think that the fidelity parameter ('x3') is not being taken into account in the whole optimization process. What is actually the correct way to deal with this parameter?
  4. Also, is it the case that currently, only continuous fidelities are supported for a Service API workflow (related to #2475)?

Thanks in advance, and thank you as well for the replies on my previous questions!

@Abrikosoff Abrikosoff changed the title from "Multiobjective Multifidelity BO" to "Multiobjective Multifidelity BO using the Service API" on Jun 12, 2024
@mgrange1998
Contributor

Hi, thank you for opening this issue.

  1. Could you give some details on what leads you to believe the experiment has been simplified to a non MF MOBO problem?
  2. qMultiFidelityKnowledgeGradient is for single-objective cases, and qMultiFidelityHypervolumeKnowledgeGradient is for multi-objective cases. In your run, does using qMultiFidelityKnowledgeGradient only optimize the first objective?
  3. Could you provide some logs/results which demonstrate that 'x3' is not being taken into account for the optimization process?
  4. For the continuous fidelity question, I will ask @saitcakmak to help answer
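To make the single- vs. multi-objective distinction concrete: hypervolume-based acquisition functions such as qMultiFidelityHypervolumeKnowledgeGradient value candidates by how much they grow the region dominated by the Pareto front and bounded below by the reference point (the objective thresholds). The pure-Python sketch below computes that hypervolume for two maximized objectives; it is an illustration of the quantity only, not BoTorch's implementation.

```python
def hypervolume_2d(points, ref_point):
    """Hypervolume dominated by `points` (both objectives maximized) above `ref_point`."""
    r1, r2 = ref_point
    # Only points that strictly dominate the reference point contribute.
    candidates = [(y1, y2) for y1, y2 in points if y1 > r1 and y2 > r2]
    # Sweep from the largest first objective to the smallest.
    candidates.sort(key=lambda p: p[0], reverse=True)
    volume = 0.0
    best_y2 = r2
    for y1, y2 in candidates:
        if y2 > best_y2:
            # New horizontal slab: width (y1 - r1), height (y2 - best_y2).
            volume += (y1 - r1) * (y2 - best_y2)
            best_y2 = y2
    return volume


front = [(1.0, 3.0), (2.0, 2.0), (3.0, 1.0)]
print(hypervolume_2d(front, (0.0, 0.0)))  # 6.0
```

Adding a dominated point such as (1.0, 1.0) leaves the value at 6.0, which is why a hypervolume-based acquisition function assigns such a point no improvement, whereas a single-objective criterion only looks at one output.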

@saitcakmak
Contributor

Hi @Abrikosoff. I will start by saying we don't actively use MF BO internally, so this functionality is not fully battle tested.

I have more detailed answers below, but after looking into this, I can't recommend that you keep using it. If you are comfortable digging around with the debugger to make sure the correct arguments are constructed and passed in, you can get it to work.

It runs, but is this due to some mistakes in the setup such that I have inadvertently simplified this problem to be a non-MF MOBO problem?

Using some mocks, I see that the MFKG acquisition function is indeed being used under the hood. Whether it is utilizing all the arguments you provided is another question.

You can use a mock like this to trigger an exception & use the debugger to inspect each argument to the acquisition function:

from unittest import mock
with mock.patch.object(qMultiFidelityKnowledgeGradient, "__init__", side_effect=Exception):
    parameterization, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(
        trial_index=trial_index, raw_data=evaluate(parameterization)
    )

I haven't spent too much time investigating these args, but I do see these being passed into the model.

'cost_aware_utility': InverseCostWeightedUtility(
  (cost_model): AffineFidelityCostModel()
  (cost_objective): IdentityMCObjective()
), 'expand': <function construct_inputs_mf_base.<locals>.<lambda> at 0x7f8a8437c700>, 'project': <function construct_inputs_mf_base.<locals>.<lambda> at 0x7f8a8437ca60>,

I am pretty sure these are not the arguments you are passing in though. I believe these are the defaults constructed by the input constructor: https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L1244-L1254
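As an alternative to stopping in the debugger, one option is to patch __init__ with a spy that records the keyword arguments and then calls the original initializer, so generation still completes. The sketch uses a hypothetical stand-in class, FakeAcqf, and a hypothetical build_acqf in place of qMultiFidelityKnowledgeGradient and the Ax code that constructs it:

```python
from unittest import mock


class FakeAcqf:
    """Hypothetical stand-in for an acquisition function class."""

    def __init__(self, model, **kwargs):
        self.model = model
        self.kwargs = kwargs


def build_acqf(**kwargs):
    # Stand-in for the library code that constructs the acquisition function.
    return FakeAcqf(**kwargs)


captured = {}
original_init = FakeAcqf.__init__


def spy_init(self, *args, **kwargs):
    captured.update(kwargs)  # record what the caller actually passed
    original_init(self, *args, **kwargs)


with mock.patch.object(FakeAcqf, "__init__", spy_init):
    acqf = build_acqf(model="gp", project="default_project", num_fantasies=2)

print(sorted(captured))  # ['model', 'num_fantasies', 'project']
```

With the real class, patching qMultiFidelityKnowledgeGradient.__init__ the same way and calling ax_client.get_next_trial() inside the with block should show whether the user-supplied cost_aware_utility and project, or the input-constructor defaults, reached the acquisition function.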

Also, is it the case that currently, only continuous fidelities are supported for a Service API workflow (related to #2475 (comment))?

I don't think this is necessarily the case. It looks like you were getting an error with Choice parameters. There shouldn't be any issue with using integer-valued Range parameters, though. You can update the definition of x3 like this to make it an integer range parameter:

        {
            "name": "x3",
            "type": "range",
            "value_type": "int",
            "bounds": [0, 1],
            "is_fidelity": True,
            "target_value": 1,  
        },

I would have expected that for a MFMOBO problem, the correct acqf to use would have been the qMultiFidelityHypervolumeKnowledgeGradient, but here it seems that qMultiFidelityKnowledgeGradient works as well. Why is that?

Looks like we end up fitting a SingleTaskMultiFidelityGP with 2 outputs. Somewhere in the acquisition function, the samples must be getting reduced. I do see that a ScalarizedPosteriorTransform is being passed to the acquisition function, which simply sums up the two model outputs. This appears to be a bug.
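The effect of such a sum-scalarization shows up on a toy example: three mutually non-dominated points all lie on the Pareto front, but summing the two objectives singles out just one of them, so an acquisition function driven by the sum has no incentive to improve the rest of the front. A minimal pure-Python sketch, with made-up points:

```python
def pareto_front(points):
    """Return the points not dominated by any other point (both objectives maximized)."""
    front = []
    for p in points:
        dominated = any(q[0] >= p[0] and q[1] >= p[1] and q != p for q in points)
        if not dominated:
            front.append(p)
    return front


def best_by_sum(points, weights=(1.0, 1.0)):
    """Point maximizing a weighted sum of the two objectives (a scalarization)."""
    return max(points, key=lambda p: weights[0] * p[0] + weights[1] * p[1])


points = [(0.9, 0.1), (0.5, 0.6), (0.1, 0.95), (0.4, 0.5)]
print(pareto_front(points))  # [(0.9, 0.1), (0.5, 0.6), (0.1, 0.95)]
print(best_by_sum(points))  # (0.5, 0.6)
```

Any fixed positive weights pick out a single point; which one depends on the weights, not on the shape of the front.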

In the current definition of the evaluate function, I think that the fidelity parameter ('x3') is not being taken into account in the whole optimization process. What is actually the correct way to deal with this parameter?

While optimizing the acquisition function, x3 is being optimized along with the other parameters. It is not treated any differently in the optimizer. The acquisition function is responsible for assigning a value to each candidate that includes x1, x2 & x3.

If you wanted to fix x3 to a constant during generation, you could specify this using the fixed_features argument to AxClient.get_next_trial.
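The cost-aware trade-off over x3, and the effect of fixing it, can be sketched with a toy grid search. Everything here is hypothetical: toy_value stands in for the knowledge-gradient increase, and the cost uses the affine form fixed_cost + weight * x3 (with fixed_cost=5.0 as in the repro), divided out the way an inverse-cost-weighted utility would. This is not how Ax or BoTorch optimize the acquisition function.

```python
from itertools import product


def affine_cost(x3, fixed_cost=5.0, weight=1.0):
    # Affine fidelity cost: a fixed cost plus a term that grows with fidelity.
    return fixed_cost + weight * x3


def toy_value(x1, x3):
    # Hypothetical value of information: peaks at x1 = 0.5 and grows with fidelity.
    return (1.0 - (x1 - 0.5) ** 2) * (0.5 + x3)


def best_candidate(grid, fixed_features=None):
    """Maximize value / cost over (x1, x3), optionally holding some features fixed."""
    fixed_features = fixed_features or {}
    best, best_score = None, float("-inf")
    for x1, x3 in product(grid, grid):
        x1 = fixed_features.get("x1", x1)
        x3 = fixed_features.get("x3", x3)
        score = toy_value(x1, x3) / affine_cost(x3)
        if score > best_score:
            best, best_score = (x1, x3), score
    return best


grid = [0.0, 0.5, 1.0]
print(best_candidate(grid))  # (0.5, 1.0)
print(best_candidate(grid, fixed_features={"x3": 0.0}))  # (0.5, 0.0)
```

Left free, x3 is searched jointly with x1 and scored through the cost-weighted value; fixing it removes that dimension from the search.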

saitcakmak added a commit to saitcakmak/Ax that referenced this issue Jun 13, 2024
Summary:
I believe this was originally designed based on the legacy MF models, but it appears to be incompatible with what MBM evolved to be. Here are a few reasons for removing it:
- Some of the kwargs that are produced by `MultiFidelityAcquisition.compute_model_dependencies` (e.g., `cost_aware_utility`) are not accepted by the MFKG input constructor.
- The main kwargs that are needed by the MF acquisition functions (`cost_aware_utility`, `project`, `expand`) are readily constructed by the MF input constructors: https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L1221
- We strongly discourage subclassing MBM components and rather encourage expanding the base components to support new use cases. In this case, the MF input constructors seem to have replaced what `MultiFidelityAcquisition` originally aimed to do, while being compatible with the design philosophy.

Discovered while investigating facebook#2514

Differential Revision: D58560934
@Abrikosoff
Author

Abrikosoff commented Jun 14, 2024

@saitcakmak Dear Sait, thanks a lot for the information! I have some follow-ups regarding the points you raised:

Using some mocks, I see that the MFKG acquisition function is indeed being used under the hood. Whether it is utilizing all the arguments you provided is another question.
I am pretty sure these are not the arguments you are passing in though. I believe these are the defaults constructed by the input constructor: https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L1244-L1254

The reason I've been passing these in explicitly rather than relying on the defaults is that the docstring for qMultiFidelityKnowledgeGradient (code link) states (among other things):

A version of qKnowledgeGradient that supports multi-fidelity optimization via a CostAwareUtility and the project and expand operators. If none of these are set, this acquisition function reduces to qKnowledgeGradient.

Hence I was under the impression that if I leave these out the acqf automatically reverts to a non-MF one. Is this not the case? If it is, and since you found that MFKG is indeed being used, can I assume that these have been passed in? (Of course, it might be working because the defaults were invoked and my actual params were NOT passed in, which raises the question of how to pass them correctly.)

I don't think this is necessarily the case. It looks like you were getting an error with Choice parameters. There shouldn't be any issue with using integer-valued Range parameters, though. You can update the definition of x3 like this to make it an integer range parameter:

        {
            "name": "x3",
            "type": "range",
            "value_type": "int",
            "bounds": [0, 1],
            "is_fidelity": True,
            "target_value": 1,  
        },

Thank you very much for this! Really a palm-to-the-forehead moment for me :(

Looks like we end up fitting a SingleTaskMultiFidelityGP with 2 outputs. Somewhere in the acquisition function, the samples must be getting reduced. I do see that a ScalarizedPosteriorTransform is being passed to the acquisition function, which simply sums up the two model outputs. This appears to be a bug.

Actually, to my (very limited) knowledge, isn't this how MOBO is supposed to work? If you look at the BoTorch documentation for MOBO, especially where the model is initialized, you find:

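import torch
from botorch.models import FixedNoiseGP, ModelListGP
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import normalize
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood

# `problem` and `NOISE_SE` are defined earlier in the tutorial.
# FixedNoiseGP is deprecated in newer BoTorch releases; SingleTaskGP with
# train_Yvar replaces it.
def initialize_model(train_x, train_obj):
    # define models for objective and constraint
    train_x = normalize(train_x, problem.bounds)
    models = []
    for i in range(train_obj.shape[-1]):
        train_y = train_obj[..., i : i + 1]
        train_yvar = torch.full_like(train_y, NOISE_SE[i] ** 2)
        models.append(
            FixedNoiseGP(
                train_x, train_y, train_yvar, outcome_transform=Standardize(m=1)
            )
        )
    model = ModelListGP(*models)
    mll = SumMarginalLogLikelihood(model.likelihood, model)
    return mll, model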

(In our case we have SingleTaskMFGPs, but I don't think this changes the nature of the problem.) If this tracks, I would think that a ModelList containing two SingleTaskMFGPs would be the way to do MFMOBO as well, no? This very much needs clarification, as I think it is a core issue here. I have also left a comment on the bug issue.

While optimizing the acquisition function, x3 is being optimized along with the other parameters. It is not treated any differently in the optimizer. The acquisition function is responsible for assigning a value to each candidate that includes x1, x2 & x3.

If you wanted to fix x3 to a constant during generation, you could specify this using the fixed_features argument to AxClient.get_next_trial.

Thanks for this!

Edit: I have actually tried to use qMultiFidelityHypervolumeKnowledgeGradient in the GenerationStrategy definition as follows:

from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
    qMultiFidelityHypervolumeKnowledgeGradient,
)

generation_strategy = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=1,  # https://github.com/facebook/Ax/issues/922
            min_trials_observed=1,
            max_parallelism=6,
            model_kwargs={"seed": 9999},
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                # "botorch_acqf_class": qMultiFidelityKnowledgeGradient,
                "botorch_acqf_class": qMultiFidelityHypervolumeKnowledgeGradient,
            },
            model_gen_kwargs={
                "model_gen_options": {
                    "acqf_kwargs": {
                        "cost_intercept": cost_intercept,
                        "num_fantasies": 2,
                        "cost_aware_utility": cost_aware_utility,
                        "project": project,
                    },
                },
            },
        ),
    ]
)

and ran the repro shown above, but I get the error
RuntimeError: Input constructor for acquisition class qMultiFidelityHypervolumeKnowledgeGradient not registered. Use the @acqf_input_constructor decorator to register a new method.

Perhaps there is a quick fix for this?

Edit 2: I found the tutorial on registering a new acquisition function here, and adapted it as follows:

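from typing import Any, Dict

from ax.models.torch.botorch_modular.model import BoTorchModel
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from botorch.acquisition.input_constructors import MaybeDict, acqf_input_constructor
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
    qMultiFidelityHypervolumeKnowledgeGradient,
)
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


# 1. Add input constructor
@acqf_input_constructor(qMultiFidelityHypervolumeKnowledgeGradient)
def construct_inputs_my_acqf(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    objective_thresholds: Tensor,
    **kwargs: Any,
) -> Dict[str, Any]:
    # I'm not sure how to correctly do this here.
    # As written, the body is empty, so this function returns None.
    pass


# 2. Register default optimizer options
@optimizer_argparse.register(qMultiFidelityHypervolumeKnowledgeGradient)
def _argparse_my_acqf(
    acqf: qMultiFidelityHypervolumeKnowledgeGradient, sequential: bool = True
) -> dict:
    return {
        "sequential": sequential
    }  # default to sequentially optimizing batches of queries


# 3-4. Specifying `botorch_acqf_class` and `acquisition_options`
BoTorchModel(
    botorch_acqf_class=qMultiFidelityHypervolumeKnowledgeGradient,
    acquisition_options={
        "alpha": 10**-6,
        # The sub-dict by the key "optimizer_options" can be passed
        # to propagate options to `optimize_acqf`, used in
        # `Acquisition.optimize`, to add/override the default
        # optimizer options registered above.
        "optimizer_options": {"sequential": False},
    },
)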

And now I am getting the error

TypeError: botorch.acquisition.multi_objective.hypervolume_knowledge_gradient.qMultiFidelityHypervolumeKnowledgeGradient() argument after ** must be a mapping, not NoneType

which makes sense as I have not yet actually defined construct_inputs_my_acqf, but I am not sure how to do this correctly.

@Abrikosoff
Author

@sgbaird in case you are interested in updates

facebook-github-bot pushed a commit that referenced this issue Jun 14, 2024
Summary:
Pull Request resolved: #2520

I believe this was originally designed based on the legacy MF models, but it appears to be incompatible with what MBM evolved to be. Here are a few reasons for removing it:
- Some of the kwargs that are produced by `MultiFidelityAcquisition.compute_model_dependencies` (e.g., `cost_aware_utility`) are not accepted by the MFKG input constructor.
- The main kwargs that are needed by the MF acquisition functions (`cost_aware_utility`, `project`, `expand`) are readily constructed by the MF input constructors: https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L1221
- We strongly discourage subclassing MBM components and rather encourage expanding the base components to support new use cases. In this case, the MF input constructors seem to have replaced what `MultiFidelityAcquisition` originally aimed to do, while being compatible with the design philosophy.

Discovered while investigating #2514

Reviewed By: Balandat

Differential Revision: D58560934

fbshipit-source-id: fc58675eff4ff81dc0a4a93084e01f8a4c8e0efc
@saitcakmak
Contributor

Hence I was under the impression that if I leave these out the acqf automatically reverts to a non-MF one. Is this not the case?

The way Ax constructs BoTorch acquisition functions involves using acquisition function input constructors (e.g., this one) to convert the data available on the Ax experiment to the inputs expected by the acquisition function. The input constructors often define default behaviors. In the case of MFKG, this involves constructing the cost utility, expand & project arguments. The search space includes a fidelity parameter with a target fidelity, so this is used to figure out what target fidelity to project to etc.
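Roughly, these defaults amount to an affine cost over the fidelity parameters and a projection that moves candidates to the target fidelity. The sketch below mimics that logic in plain Python on lists. build_mf_defaults and project_to_target are hypothetical names, and defaulting fidelity_weights to 1.0 per fidelity parameter is an assumption based on reading the BoTorch MF input constructor, not a documented guarantee.

```python
def project_to_target(x, target_fidelities):
    """Return a copy of x with every fidelity entry set to its target value."""
    projected = list(x)
    for index, target in target_fidelities.items():
        projected[index] = target
    return projected


def build_mf_defaults(target_fidelities, fidelity_weights=None, cost_intercept=1.0):
    """Hypothetical mimic of the MF defaults: an affine cost and a projection."""
    if fidelity_weights is None:
        # Assumed default: unit weight on every fidelity parameter.
        fidelity_weights = {index: 1.0 for index in target_fidelities}

    def cost(x):
        # cost_intercept plays the role of the fixed cost in the affine model.
        return cost_intercept + sum(w * x[i] for i, w in fidelity_weights.items())

    def project(x):
        return project_to_target(x, target_fidelities)

    return {"cost": cost, "project": project}


defaults = build_mf_defaults(target_fidelities={2: 1.0}, cost_intercept=5.0)
x = [0.2, 0.7, 0.25]  # x1, x2, x3 (x3 is the fidelity)
print(defaults["project"](x))  # [0.2, 0.7, 1.0]
print(defaults["cost"](x))  # 5.25
```

Both pieces need only the fidelity index and target value, which is why they can be derived from a search space containing a fidelity parameter with a target_value.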

MOBO modeling behavior

Addressed in the other issue.

Design of acqf input constructor

The job of the input constructor is to take the inputs passed in from Ax (this is done here) and convert them into the inputs required by qMultiFidelityHypervolumeKnowledgeGradient.__init__. This is an active area of development, and we don't have clear documentation for it, so it may take a bit of digging around in the code to figure out which inputs are available. But the input constructor of MFKG should provide a good starting point: https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L1291
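Part of that conversion is deciding which of the available inputs the target __init__ actually accepts. A hypothetical helper using only the standard-library inspect module (not the Ax or BoTorch code path) makes that subset visible:

```python
import inspect


def accepted_kwargs(cls, **available):
    """Keep only the entries of `available` that cls.__init__ can accept by name."""
    params = inspect.signature(cls.__init__).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(available)  # **kwargs accepts everything
    return {k: v for k, v in available.items() if k in params and k != "self"}


class ToyHVKG:
    """Hypothetical stand-in with a few HVKG-like arguments."""

    def __init__(self, model, ref_point, num_fantasies=8, cost_aware_utility=None):
        self.model = model
        self.ref_point = ref_point


inputs_from_ax = {"model": "gp", "ref_point": [0.0, 0.0], "bounds": [(0, 1)] * 3}
print(accepted_kwargs(ToyHVKG, **inputs_from_ax))
# {'model': 'gp', 'ref_point': [0.0, 0.0]}
```

If the target accepts **kwargs, this check cannot catch misspelled or unused keys, so it is a diagnostic aid rather than a validation step.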

@Abrikosoff
Author

@saitcakmak Thanks a million! I'm keeping this open for now, as I plan to come back and post a working repro later, at which point I'll close it; in the meantime I might have some more questions :-)

@Abrikosoff
Author

Hence I was under the impression that if I leave these out the acqf automatically reverts to a non-MF one. Is this not the case?

The way Ax constructs BoTorch acquisition functions involves using acquisition function input constructors (e.g., this one) to convert the data available on the Ax experiment to the inputs expected by the acquisition function. The input constructors often define default behaviors. In the case of MFKG, this involves constructing the cost utility, expand & project arguments. The search space includes a fidelity parameter with a target fidelity, so this is used to figure out what target fidelity to project to etc.

MOBO modeling behavior

Addressed in the other issue.

Design of acqf input constructor

The job of the input constructor is to take the inputs passed in from Ax (this is done here) and convert them into the inputs required by qMultiFidelityHypervolumeKnowledgeGradient.__init__. This is an active area of development, and we don't have clear documentation for it, so it may take a bit of digging around in the code to figure out which inputs are available. But the input constructor of MFKG should provide a good starting point: https://github.com/pytorch/botorch/blob/main/botorch/acquisition/input_constructors.py#L1291

@saitcakmak So I went and took a look at the code, which I am reproducing here:

@acqf_input_constructor(qMultiFidelityKnowledgeGradient)
def construct_inputs_qMFKG(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    bounds: List[Tuple[float, float]],
    target_fidelities: Dict[int, Union[int, float]],
    objective: Optional[MCAcquisitionObjective] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    fidelity_weights: Optional[Dict[int, float]] = None,
    cost_intercept: float = 1.0,
    num_trace_observations: int = 0,
    num_fantasies: int = 64,
) -> Dict[str, Any]:
    r"""Construct kwargs for `qMultiFidelityKnowledgeGradient` constructor."""

    inputs_mf = construct_inputs_mf_base(
        target_fidelities=target_fidelities,
        fidelity_weights=fidelity_weights,
        cost_intercept=cost_intercept,
        num_trace_observations=num_trace_observations,
    )

    inputs_kg = construct_inputs_qKG(
        model=model,
        training_data=training_data,
        bounds=bounds,
        objective=objective,
        posterior_transform=posterior_transform,
        num_fantasies=num_fantasies,
    )

    return {**inputs_mf, **inputs_kg}

My understanding is that a hypothetical construct_inputs_qMFHVKG function will require something along the lines of

inputs_mf = construct_inputs_mf_base(
    target_fidelities=target_fidelities,
    fidelity_weights=fidelity_weights,
    cost_intercept=cost_intercept,
    num_trace_observations=num_trace_observations,
)

inputs_hvkg = construct_inputs_qHVKG(
    model=model,
    training_data=training_data,
    bounds=bounds,
    objective=objective,
    posterior_transform=posterior_transform,
    num_fantasies=num_fantasies,
)

return {**inputs_mf, **inputs_hvkg}

which implies that I need to define my own construct_inputs_qHVKG function (I looked; no such function exists in the Ax codebase). But here comes a question: the qHVKG acqf is directly useable in a GS definition (already tested this), but it does not have a construct_inputs_qHVKG acqf registration function. How does this track? Am I missing something, and I actually DON'T need a construct_inputs_qHVKG (maybe construct_inputs_qKG works for qHVKG as well)?
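One detail to watch in such a combined constructor: with `{**inputs_mf, **inputs_hvkg}`, any key present in both dicts takes its value from inputs_hvkg. If the HVKG-side constructor also returns, say, cost_aware_utility=None, it silently replaces the utility built by construct_inputs_mf_base. A minimal sketch of the pitfall and one way around it; merge_acqf_kwargs is a hypothetical helper and the dict values are placeholders:

```python
def merge_acqf_kwargs(*parts):
    """Merge kwargs dicts left to right; a None value never overwrites an existing entry."""
    merged = {}
    for part in parts:
        for key, value in part.items():
            if value is None and key in merged:
                continue  # keep the earlier, non-default value
            merged[key] = value
    return merged


inputs_mf = {"cost_aware_utility": "utility_from_mf", "project": "project_fn"}
inputs_hvkg = {"model": "gp", "num_fantasies": 2, "cost_aware_utility": None}

naive = {**inputs_mf, **inputs_hvkg}
print(naive["cost_aware_utility"])  # None
print(merge_acqf_kwargs(inputs_mf, inputs_hvkg)["cost_aware_utility"])  # utility_from_mf
```

Simply reversing the order, `{**inputs_hvkg, **inputs_mf}`, also works when the MF-side entries should always win.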

@saitcakmak
Contributor

the qHVKG acqf is directly useable in a GS definition (already tested this), but it does not have a construct_inputs_qHVKG acqf registration function. How does this track?

This piece of code errors out for me:

from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import qMultiFidelityHypervolumeKnowledgeGradient
from botorch.acquisition.input_constructors import get_acqf_input_constructor
get_acqf_input_constructor(qMultiFidelityHypervolumeKnowledgeGradient)

saying that there is no registered input constructor for it. This will be executed when trying to generate candidates with it, so even though you can construct a GS (which doesn't involve any checks on model kwargs), it will not work for generating candidates.
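Whether a generation strategy can be built and whether it can generate depends on when the registry lookup happens. A self-contained sketch of the decorator-plus-registry pattern (all names hypothetical; it mirrors the idea, not BoTorch's code):

```python
_INPUT_CONSTRUCTORS = {}


def register_input_constructor(acqf_cls):
    """Decorator that records the decorated function as the constructor for acqf_cls."""
    def decorator(fn):
        _INPUT_CONSTRUCTORS[acqf_cls] = fn
        return fn
    return decorator


def get_input_constructor(acqf_cls):
    try:
        return _INPUT_CONSTRUCTORS[acqf_cls]
    except KeyError:
        raise RuntimeError(
            f"Input constructor for acquisition class {acqf_cls.__name__} not registered."
        ) from None


class ToyKG:
    def __init__(self, model):
        self.model = model


class ToyHVKG:
    def __init__(self, model):
        self.model = model


@register_input_constructor(ToyKG)
def construct_inputs_toy_kg(model, **kwargs):
    return {"model": model}


def generate(acqf_cls, model):
    # The lookup only happens here, at generation time.
    inputs = get_input_constructor(acqf_cls)(model=model)
    return acqf_cls(**inputs)


strategy = {"acqf_class": ToyHVKG}  # storing the class performs no lookup
print(type(generate(ToyKG, model="gp")).__name__)  # ToyKG
```

Calling generate(ToyHVKG, model="gp") raises the RuntimeError. Likewise, if a registered constructor falls through without a return statement, it returns None, and acqf_cls(**None) fails with the "argument after ** must be a mapping, not NoneType" TypeError reported earlier in the thread.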

I'd like to help further here but I will not be around for the next two weeks and I have other things to wrap up before I leave. Let's leave this issue open so we can follow up and fix any remaining gaps for MF support in MBM.
cc @esantorella in case you have any inputs on setting up the input constructor here

@saitcakmak saitcakmak self-assigned this Jun 14, 2024
@Abrikosoff
Author

Abrikosoff commented Jun 17, 2024

@saitcakmak @esantorella Hi guys, I went ahead and tried to set up the input constructor as follows:


from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from ax.models.torch.botorch_modular.model import BoTorchModel
from botorch.acquisition.cost_aware import CostAwareUtility
from botorch.acquisition.input_constructors import (
    MaybeDict,
    TOptimizeObjectiveKwargs,
    _get_dataset_field,
    acqf_input_constructor,
    construct_inputs_mf_base,
    optimize_objective,
)
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
    qHypervolumeKnowledgeGradient,
    qMultiFidelityHypervolumeKnowledgeGradient,
)
from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
from botorch.acquisition.objective import (
    MCAcquisitionObjective,
    PosteriorTransform,
    ScalarizedPosteriorTransform,
)
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.datasets import SupervisedDataset
from torch import Tensor


@acqf_input_constructor(qHypervolumeKnowledgeGradient)
def construct_inputs_qHVKG(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    objective_thresholds: Tensor,
    bounds: List[Tuple[float, float]],
    num_fantasies: int = 64,
    num_pareto: int = 10,
    objective: Optional[MCAcquisitionObjective] = None,
    sampler: Optional[MCSampler] = None,
    inner_sampler: Optional[MCSampler] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    X_pending: Optional[Tensor] = None,
    use_posterior_mean: bool = False,
    mc_samples: int = 128,
    qmc: bool = True,
    alpha: Optional[float] = None,
    cost_aware_utility: Optional[CostAwareUtility] = None,
    **optimize_objective_kwargs: TOptimizeObjectiveKwargs,
) -> Dict[str, Any]:
    r"""Construct kwargs for `qHypervolumeKnowledgeGradient` constructor."""

    X = _get_dataset_field(training_data, "X", first_only=True, assert_shared=True)

    # NOTE: this overrides any passed-in posterior_transform with an equal-weight
    # scalarization; it is only used for the `current_value` computation below.
    posterior_transform = ScalarizedPosteriorTransform(
        weights=torch.tensor([0.5, 0.5], dtype=X.dtype, device=X.device)
    )

    # compute posterior mean (for ref point computation ref pareto frontier)
    with torch.no_grad():
        Y_pmean = model.posterior(X).mean

    # this part is redundant? At least I don't see where I should use Y
    if objective is None:
        ref_point = objective_thresholds
        Y = Y_pmean
    elif isinstance(objective, RiskMeasureMCObjective):
        ref_point = objective.preprocessing_function(objective_thresholds)
        Y = objective.preprocessing_function(Y_pmean)
    else:
        ref_point = objective(objective_thresholds)
        Y = objective(Y_pmean)

    # if sampler is None and isinstance(model, GPyTorchModel):
    #     sampler = _get_sampler(mc_samples=mc_samples, qmc=qmc)

    num_objectives = objective_thresholds.shape[0]

    # this part seems redundant as well
    alpha = (
        get_default_partitioning_alpha(num_objectives=num_objectives)
        if alpha is None
        else alpha
    )

    _bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device)

    _, current_value = optimize_objective(
        model=model,
        bounds=_bounds.t(),
        q=1,
        objective=objective,
        posterior_transform=posterior_transform,
        **optimize_objective_kwargs,
    )

    # HVKG requires the following params
    return {
        "model": model,
        "ref_point": ref_point,
        "num_fantasies": num_fantasies,
        "num_pareto": num_pareto,
        "sampler": sampler,
        "objective": objective,
        "inner_sampler": inner_sampler,
        "X_pending": X_pending,
        # keep current_value on the model's device (no .cpu())
        "current_value": current_value.detach().max(),
        "use_posterior_mean": use_posterior_mean,
        "cost_aware_utility": cost_aware_utility,
    }

# 1. Add input constructor
@acqf_input_constructor(qMultiFidelityHypervolumeKnowledgeGradient)
def construct_inputs_qMFHVKG(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    bounds: List[Tuple[float, float]],
    # ref_point: Tensor,
    objective_thresholds: Tensor,
    target_fidelities: Dict[int, Union[int, float]] = target_fidelities,
    posterior_transform: Optional[PosteriorTransform] = None,
    objective: Optional[MCAcquisitionObjective] = None,
    fidelity_weights: Optional[Dict[int, float]] = None,
    cost_intercept: float = 1.0,
    num_trace_observations: int = 0,
    num_fantasies: int = 64,
    **kwargs: Any,
) -> Dict[str, Any]:
    r"""Construct kwargs for `qMultiFidelityHypervolumeKnowledgeGradient` constructor."""

    inputs_mf = construct_inputs_mf_base(
        target_fidelities=target_fidelities,
        fidelity_weights=fidelity_weights,
        cost_intercept=cost_intercept,
        num_trace_observations=num_trace_observations,
    )

    # Might need to change this to qHVKG
    inputs_hvkg = construct_inputs_qHVKG(
        model=model,
        training_data=training_data,
        objective_thresholds=objective_thresholds,
        bounds=bounds,
        objective=objective,
        posterior_transform=posterior_transform,
        num_fantasies=num_fantasies,
    )

    # Merge the MF entries last so that a `cost_aware_utility=None` returned by
    # construct_inputs_qHVKG does not overwrite the one built by construct_inputs_mf_base.
    return {**inputs_hvkg, **inputs_mf, "target_fidelities": target_fidelities}
    
# 3-4. Specifying `botorch_acqf_class` and `acquisition_options`
BoTorchModel(
    botorch_acqf_class=qMultiFidelityHypervolumeKnowledgeGradient,
    acquisition_options={
        "alpha": 10**-6,
        "target_fidelities": target_fidelities,
        "optimizer_options": {"sequential": False},
    },
)

and defined the generation strategy (GS) as follows:

generation_strategy = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=1,  # https://github.com/facebook/Ax/issues/922
            min_trials_observed=1,
            max_parallelism=6,
            model_kwargs={"seed": 9999},
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=-1,
            model_kwargs={
                # "botorch_acqf_class": qMultiFidelityKnowledgeGradient,
                # Actual MOBO acquisition function for MF
                "botorch_acqf_class": qMultiFidelityHypervolumeKnowledgeGradient,
                # "surrogate": Surrogate(ModelListGP(SingleTaskGP, SingleTaskGP)),
            },
            model_gen_kwargs={
                "model_gen_options": {
                    "acqf_kwargs": {
                        "cost_intercept": cost_intercept,
                        "num_fantasies": 2,
                        # "num_pareto": 1,
                        # "current_value": 1,
                        "cost_aware_utility": cost_aware_utility,
                        "target_fidelities": target_fidelities,
                        # "target_fidelities": normalized_target_fidelities,
                        # The key must be the string "project", not the bare name
                        "project": project,
                    },
                },
            },
        ),
    ]
)
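A note on the project entry in acqf_kwargs: in a dict literal, a bare name used as a key (project: project) is evaluated as a variable. The key is therefore the function object itself rather than the string "project", so the entry can never match a keyword argument named project. A small stdlib illustration, where project and make_acqf are hypothetical stand-ins:

```python
def project(X):
    """Hypothetical stand-in for a fidelity-projection function."""
    return X


def make_acqf(**kwargs):
    """Hypothetical stand-in for an acquisition constructor taking kwargs."""
    return kwargs


# Bare name as a key: the key is the function object, not the string "project".
bad_kwargs = {project: project}
print("project" in bad_kwargs)  # False

# Unpacking a dict with a non-string key as keyword arguments fails.
try:
    make_acqf(**bad_kwargs)
    raised = False
except TypeError:
    raised = True
print(raised)  # True

# Quoted key: the entry arrives as the keyword argument `project`.
good_kwargs = {"project": project}
print(make_acqf(**good_kwargs)["project"] is project)  # True
```

I haven't verified whether Ax surfaces this as a TypeError or silently ignores the entry; that depends on how it forwards acqf_kwargs to the input constructor.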

But now I am getting this error: ValueError: qMultiFidelityHypervolumeKnowledgeGradient requires using a ModelList. I am not sure how to set up a ModelList in the context of the Ax Service API. I feel I am getting close to getting this to work, so any help at all would be very much appreciated!

Edit: So I did the naivest possible thing and modified the BoTorch step in the GS above to the following:

GenerationStep(
    model=Models.BOTORCH_MODULAR,
    num_trials=-1,
    model_kwargs={
        # "botorch_acqf_class": qMultiFidelityKnowledgeGradient,
        # Actual MOBO acquisition function for MF
        "botorch_acqf_class": qMultiFidelityHypervolumeKnowledgeGradient,
        "surrogate": Surrogate(ModelListGP),
    },
    # model_gen_kwargs unchanged from the generation strategy above
),

This throws an error I wasn't expecting: TypeError: ModelListGP.__init__() got an unexpected keyword argument 'train_X'. This is a bit baffling, because no argument named train_X is defined anywhere in my code.
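A likely explanation for the train_X error, assuming Ax's Surrogate builds the model by calling botorch_model_class with keyword inputs derived from the training data (such as train_X and train_Y). ModelListGP.__init__ only accepts already-constructed sub-models as positional arguments (ModelListGP(*gp_models)), so any train_X keyword is rejected. The toy classes below are hypothetical and only mimic these two signatures:

```python
class ToyModelList:
    """Mimics ModelListGP's signature: sub-models only, passed positionally."""

    def __init__(self, *gp_models):
        self.models = list(gp_models)


class ToyGP:
    """Mimics a SingleTaskGP-style signature that takes training data."""

    def __init__(self, train_X, train_Y):
        self.train_X, self.train_Y = train_X, train_Y


def toy_surrogate_construct(model_cls, train_X, train_Y):
    """Hypothetical stand-in for a surrogate passing data as keywords."""
    return model_cls(train_X=train_X, train_Y=train_Y)


X, Y = [[0.1, 0.2, 1.0]], [[1.5, -0.3]]

gp = toy_surrogate_construct(ToyGP, X, Y)  # works: ToyGP accepts these keywords

message = ""
try:
    toy_surrogate_construct(ToyModelList, X, Y)
except TypeError as err:
    # Same failure mode as the ModelListGP error above.
    message = str(err)
print(message)
```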

Edit 2: Following the tutorial, I think what should go here is a custom ModelListGP implementation, similar to the SimpleCustomGP defined in that tutorial, but I am not sure how to achieve that, or even whether this intuition is correct. I would appreciate any help!

@Balandat
Contributor

I looked into this and hacked my way around the ModelList requirement (basically, I just had this function return True locally):

def use_model_list(

However, the next problem (and it seems more serious) is that this throws NotImplementedError("Trace observations are not currently supported.") here: https://github.com/pytorch/botorch/blob/bf529df0428aee1bddb29b3c5e7c23682cbc56b7/botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py#L389-L390

@sdaulton, you know that part of the codebase best: can you shed some light on what the limitations are here, why this isn't supported, and what would be needed to support it?

@sdaulton
Contributor

sdaulton commented Jul 1, 2024

Re: limitations with trace observations

We just haven't tested MF-HVKG with trace observations, so we would need to do some due diligence to make sure the shapes are correct. In principle, MF-HVKG works fine with trace observations, but in practice we should make sure the implementation supports them.

@Abrikosoff, in the meantime you can get around this simply by setting inputs_mf["expand"] = None. Currently, the inputs_mf["expand"] returned from construct_inputs_mf_base is just the identity function, so it isn't needed.
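To make the suggested workaround concrete: in the question's construct_inputs_qMFHVKG, it amounts to one extra line right after the construct_inputs_mf_base call, namely inputs_mf["expand"] = None. The sketch below mimics that flow with hypothetical stand-ins (not Ax/BoTorch code). The guard reproduces only what the thread implies: a non-None expand triggers the NotImplementedError.

```python
def hypothetical_mf_base():
    """Stand-in for construct_inputs_mf_base with num_trace_observations=0."""
    return {
        "cost_aware_utility": "placeholder_utility",
        "project": lambda X: X,  # placeholder projection
        "expand": lambda X: X,  # identity: no trace observations are added
    }


def hypothetical_mfhvkg_guard(**kwargs):
    """Mimics the check implied by the thread: a non-None expand is rejected."""
    if kwargs.get("expand") is not None:
        raise NotImplementedError("Trace observations are not currently supported.")
    return kwargs


# Without the workaround, even the identity `expand` trips the guard.
try:
    hypothetical_mfhvkg_guard(**hypothetical_mf_base())
    raised_before = False
except NotImplementedError:
    raised_before = True

inputs_mf = hypothetical_mf_base()
X = [[0.2, 0.7, 1.0]]
print(inputs_mf["expand"](X) == X)  # True: dropping it loses nothing here

inputs_mf["expand"] = None  # the suggested workaround
acqf_kwargs = hypothetical_mfhvkg_guard(**inputs_mf)
print(raised_before, acqf_kwargs["expand"])  # True None
```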

5 participants