Skip to content

Commit

Permalink
Remove more mocks from MBM Surrogate tests
Browse files Browse the repository at this point in the history
Summary: I'm hoping the code is self-explanatory. I reorganized test_fit to make it clearer.

Differential Revision: D68358008
  • Loading branch information
esantorella authored and facebook-github-bot committed Jan 21, 2025
1 parent c10376c commit 1d2e66a
Showing 1 changed file with 81 additions and 110 deletions.
191 changes: 81 additions & 110 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import dataclasses
import math
import warnings
from collections import OrderedDict
from contextlib import ExitStack
from copy import copy
from itertools import product
from typing import Any
Expand Down Expand Up @@ -43,6 +43,7 @@
from ax.utils.testing.mock import mock_botorch_optimize
from ax.utils.testing.torch_stubs import get_torch_test_data
from ax.utils.testing.utils import generic_equals
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
Expand Down Expand Up @@ -1524,32 +1525,37 @@ def test_init(self) -> None:
):
self.surrogate.model

@patch(f"{SURROGATE_PATH}.fit_botorch_model")
@patch.object(
MultiTaskGP,
"construct_inputs",
wraps=MultiTaskGP.construct_inputs,
)
def test_construct_per_outcome_options(
self, mock_MTGP_construct_inputs: Mock, mock_fit: Mock
) -> None:
@mock_botorch_optimize
def test_construct_per_outcome_options(self) -> None:
self.surrogate.surrogate_spec.model_configs[0].model_options.update(
{"output_tasks": [2]}
)
for fixed_noise in (False, True):
mock_fit.reset_mock()
mock_MTGP_construct_inputs.reset_mock()
self.surrogate.fit(
datasets=(
self.fixed_noise_training_data
if fixed_noise
else self.supervised_training_data
),
search_space_digest=dataclasses.replace(
self.multi_task_search_space_digest,
task_features=self.task_features,
),
)
datasets = (
self.fixed_noise_training_data
if fixed_noise
else self.supervised_training_data
)
search_space_digest = dataclasses.replace(
self.multi_task_search_space_digest,
task_features=self.task_features,
)
with ExitStack() as es:
mock_fit = es.enter_context(
patch(
f"{SURROGATE_PATH}.fit_botorch_model", wraps=fit_botorch_model
)
)
mock_MTGP_construct_inputs = es.enter_context(
patch.object(
MultiTaskGP,
"construct_inputs",
wraps=MultiTaskGP.construct_inputs,
)
)
self.surrogate.fit(
datasets=datasets, search_space_digest=search_space_digest
)
# Should construct inputs for MTGP twice.
self.assertEqual(len(mock_MTGP_construct_inputs.call_args_list), 2)
self.assertEqual(mock_fit.call_count, 2)
Expand All @@ -1573,67 +1579,60 @@ def test_construct_per_outcome_options(
},
)

@patch(
f"{CURRENT_PATH}.SaasFullyBayesianMultiTaskGP.load_state_dict",
return_value=None,
)
@patch(
f"{CURRENT_PATH}.SaasFullyBayesianSingleTaskGP.load_state_dict",
return_value=None,
)
@patch(f"{CURRENT_PATH}.Model.load_state_dict", return_value=None)
@patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood")
@patch(f"{UTILS_PATH}.fit_gpytorch_mll")
@patch(f"{UTILS_PATH}.fit_fully_bayesian_model_nuts")
def test_fit(
self,
mock_fit_nuts: Mock,
mock_fit_gpytorch: Mock,
mock_MLL: Mock,
mock_state_dict: Mock,
mock_state_dict_saas: Mock,
mock_state_dict_saas_mtgp: Mock,
) -> None:
@mock_botorch_optimize
def test_fit(self) -> None:
default_class = self.botorch_model_class
surrogates = [
Surrogate(
cases = {
"default": Surrogate(
botorch_model_class=default_class,
mll_class=ExactMarginalLogLikelihood,
# Check that empty lists also work fine.
outcome_transform_classes=[],
input_transform_classes=[],
),
Surrogate(botorch_model_class=SaasFullyBayesianSingleTaskGP),
Surrogate(botorch_model_class=SaasFullyBayesianMultiTaskGP),
Surrogate( # Batch model
"bayesian_stgp": Surrogate(
botorch_model_class=SaasFullyBayesianSingleTaskGP
),
"bayesian_mtgp": Surrogate(
botorch_model_class=SaasFullyBayesianMultiTaskGP
),
"batch": Surrogate(
botorch_model_class=SingleTaskGP, mll_class=ExactMarginalLogLikelihood
),
Surrogate( # ModelListGP
"ModelListGP": Surrogate(
botorch_model_class=SingleTaskGP,
mll_class=ExactMarginalLogLikelihood,
allow_batched_models=False,
),
]
}

# offset makes task feature point to valid outcome indices
Xs, Ys, Yvars, _, _, _, _ = get_torch_test_data(dtype=self.dtype, offset=-1)
ds1 = SupervisedDataset(
X=Xs[0],
Y=Ys[0],
Yvar=Yvars[0],
feature_names=self.feature_names,
outcome_names=self.outcomes[:1],
)
datasets = [ds1, ds1]

for i, surrogate in enumerate(surrogates):
for case, surrogate in cases.items():
# Reset mocks
mock_state_dict.reset_mock()
mock_MLL.reset_mock()
mock_fit_gpytorch.reset_mock()
mock_fit_nuts.reset_mock()

# Checking that model is None before `fit` (and `construct`) calls.
self.assertIsNone(surrogate._model)
# Should instantiate mll and `fit_gpytorch_mll` when `state_dict`
# is `None`.

is_mtgp = issubclass(
# pyre-ignore[6]: Incompatible parameter type: In call
# `issubclass`, for 1st positional argument, expected
# `Type[typing.Any]` but got `Optional[Type[Model]]`.
surrogate.surrogate_spec.model_configs[0].botorch_model_class,
MultiTaskGP,
)
is_mtgp = case in ("default", "bayesian_mtgp")
if case in ("bayesian_stgp", "bayesian_mtgp"):
fit_fn_name = "fit_fully_bayesian_model_nuts"
fit_fn = fit_fully_bayesian_model_nuts
else:
fit_fn_name = "fit_gpytorch_mll"
fit_fn = fit_gpytorch_mll

search_space_digest = (
self.multi_task_search_space_digest
if is_mtgp
Expand All @@ -1645,70 +1644,42 @@ def test_fit(
"output_tasks or target task value must be provided for"
" MultiTaskGP."
)
with self.assertRaisesRegex(
UserInputError,
msg,
):
with self.assertRaisesRegex(UserInputError, msg):
surrogate.fit(
datasets=[self.ds1, self.ds3],
datasets=datasets,
search_space_digest=search_space_digest,
)
# add target values
search_space_digest = dataclasses.replace(
search_space_digest, target_values={0: 2}
search_space_digest, target_values={0: 1}
)

with patch(f"{UTILS_PATH}.{fit_fn_name}", wraps=fit_fn) as mock_fit:
surrogate.fit(
datasets=datasets,
search_space_digest=search_space_digest,
)
surrogate.fit(
datasets=[self.ds1, self.ds3],
search_space_digest=search_space_digest,
)

mock_state_dict.assert_not_called()
if i == 0:
self.assertEqual(mock_MLL.call_count, 2)
self.assertEqual(mock_fit_gpytorch.call_count, 2)
self.assertTrue(isinstance(surrogate.model, ModelListGP))
elif i in [1, 2]:
self.assertEqual(mock_MLL.call_count, 0)
self.assertEqual(mock_fit_nuts.call_count, 2)
self.assertTrue(isinstance(surrogate.model, ModelListGP))
elif i == 3:
self.assertEqual(mock_MLL.call_count, 1)
self.assertEqual(mock_fit_gpytorch.call_count, 1)
self.assertTrue(isinstance(surrogate.model, SingleTaskGP))
elif i == 4:
self.assertEqual(mock_MLL.call_count, 2)
self.assertEqual(mock_fit_gpytorch.call_count, 2)
self.assertTrue(isinstance(surrogate.model, ModelListGP))
mock_MLL.reset_mock()
mock_fit_gpytorch.reset_mock()
mock_fit_nuts.reset_mock()
mock_fit.assert_called_once()
self.assertIsInstance(
surrogate.model, SingleTaskGP if case == "batch" else ModelListGP
)

# Should `load_state_dict` when `state_dict` is not `None`
# and `refit` is `False`.
state_dict = OrderedDict({"state_attribute": torch.ones(2)})
state_dict = surrogate.model.state_dict()
surrogate._submodels = {} # Prevent re-use of fitted model.
surrogate.fit(
datasets=[self.ds1, self.ds3],
datasets=datasets,
search_space_digest=search_space_digest,
refit=False,
# pyre-fixme: Incompatible parameter type [6]: In call
# `Surrogate.fit`, for argument `state_dict`, expected
# `Optional[OrderedDict[str, Tensor]]` but got `Dict[str,
# typing.Any]`
state_dict=state_dict,
)

if i == 1:
self.assertEqual(mock_state_dict_saas.call_count, 2)
mock_state_dict_saas.reset_mock()
elif i == 2:
self.assertEqual(mock_state_dict_saas_mtgp.call_count, 2)
mock_state_dict_saas_mtgp.reset_mock()
elif i == 3:
mock_state_dict.assert_called_once()
else:
self.assertEqual(mock_state_dict.call_count, 2)
mock_state_dict.reset_mock()
mock_MLL.assert_not_called()
mock_fit_gpytorch.assert_not_called()
mock_fit_nuts.assert_not_called()

# Fitting with PairwiseGP should be ok
fit_botorch_model(
model=PairwiseGP(
Expand Down

0 comments on commit 1d2e66a

Please sign in to comment.