Skip to content

Commit

Permalink
use most recent trial in TransformToNewSQ by default
Browse files Browse the repository at this point in the history
Summary: see title. This ensures that status_quo_data_by_trial contains the target trial index by default.

Differential Revision: D67875128
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 6, 2025
1 parent 97e0b24 commit 4b88d90
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
24 changes: 8 additions & 16 deletions ax/modelbridge/transforms/tests/test_transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ def setUp(self) -> None:
t.mark_completed()
self.data = self.exp.fetch_data()

self._refresh_modelbridge()

def _refresh_modelbridge(self) -> None:
self.modelbridge = ModelBridge(
search_space=self.exp.search_space,
model=Model(),
experiment=self.exp,
data=self.data,
data=self.exp.lookup_data(),
status_quo_name="status_quo",
)

Expand Down Expand Up @@ -139,14 +142,16 @@ def test_single_trial_is_not_transformed(self) -> None:
obs2 = tf.transform_observations(obs)
self.assertEqual(obs, obs2)

def test_taget_trial_index(self) -> None:
def test_target_trial_index(self) -> None:
sobol = get_sobol(search_space=self.exp.search_space)
self.exp.new_batch_trial(generator_run=sobol.gen(2))
self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True)
t = self.exp.trials[1]
t = checked_cast(BatchTrial, t)
t.mark_running(no_runner_required=True)
self.exp.attach_data(get_branin_data_batch(batch=checked_cast(BatchTrial, t)))

self._refresh_modelbridge()

observations = observations_from_data(
experiment=self.exp,
data=self.exp.lookup_data(),
Expand All @@ -157,17 +162,4 @@ def test_taget_trial_index(self) -> None:
observations=observations,
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, 1)

with mock.patch(
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
return_value=10,
):
t = TransformToNewSQ(
search_space=self.exp.search_space,
observations=observations,
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, 10)
5 changes: 1 addition & 4 deletions ax/modelbridge/transforms/transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.search_space import SearchSpace
from ax.core.utils import get_target_trial_index
from ax.modelbridge.transforms.relativize import BaseRelativize, get_metric_index
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
Expand Down Expand Up @@ -71,9 +70,7 @@ def __init__(
and modelbridge is not None
and modelbridge._experiment is not None
):
target_trial_index = get_target_trial_index(
experiment=modelbridge._experiment
)
target_trial_index = max(self.status_quo_data_by_trial.keys())

if target_trial_index is not None:
self.default_trial_idx: int = checked_cast(int, target_trial_index)
Expand Down

0 comments on commit 4b88d90

Please sign in to comment.