Skip to content

Commit

Permalink
use most recent trial if no SQ data for target trial in TransformToNe…
Browse files Browse the repository at this point in the history
…wSQ (#3225)

Summary:
Pull Request resolved: #3225

see title. This ensures that status_quo_data_by_trial contains the target trial index by default.

Reviewed By: danielcohenlive

Differential Revision: D67875128

fbshipit-source-id: 89a48f2812b9ae84a7ac3496de9f89961adce178
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 21, 2025
1 parent b2d01c1 commit e29e8ba
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
25 changes: 21 additions & 4 deletions ax/modelbridge/transforms/tests/test_transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,14 @@ def setUp(self) -> None:
t.mark_completed()
self.data = self.exp.fetch_data()

self._refresh_modelbridge()

def _refresh_modelbridge(self) -> None:
self.modelbridge = ModelBridge(
search_space=self.exp.search_space,
model=Model(),
experiment=self.exp,
data=self.data,
data=self.exp.lookup_data(),
status_quo_name="status_quo",
)

Expand Down Expand Up @@ -141,16 +144,18 @@ def test_single_trial_is_not_transformed(self) -> None:
obs2 = tf.transform_observations(obs)
self.assertEqual(obs, obs2)

def test_taget_trial_index(self) -> None:
def test_target_trial_index(self) -> None:
sobol = get_sobol(search_space=self.exp.search_space)
self.exp.new_batch_trial(generator_run=sobol.gen(2))
self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True)
t = self.exp.trials[1]
t = assert_is_instance(t, BatchTrial)
t.mark_running(no_runner_required=True)
self.exp.attach_data(
get_branin_data_batch(batch=assert_is_instance(t, BatchTrial))
)

self._refresh_modelbridge()

observations = observations_from_data(
experiment=self.exp,
data=self.exp.lookup_data(),
Expand All @@ -164,6 +169,18 @@ def test_taget_trial_index(self) -> None:

self.assertEqual(t.default_trial_idx, 1)

with mock.patch(
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
return_value=0,
):
t = TransformToNewSQ(
search_space=self.exp.search_space,
observations=observations,
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, 0)
# test falling back to latest trial with SQ data
with mock.patch(
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
return_value=10,
Expand All @@ -174,4 +191,4 @@ def test_taget_trial_index(self) -> None:
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, 10)
self.assertEqual(t.default_trial_idx, 1)
10 changes: 10 additions & 0 deletions ax/modelbridge/transforms/transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from collections.abc import Callable
from logging import Logger

from math import sqrt
from typing import TYPE_CHECKING
Expand All @@ -22,12 +23,14 @@
from ax.core.utils import get_target_trial_index
from ax.modelbridge.transforms.relativize import BaseRelativize, get_metric_index
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.stats.statstools import relativize, unrelativize
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401
logger: Logger = get_logger(__name__)


class TransformToNewSQ(BaseRelativize):
Expand Down Expand Up @@ -73,6 +76,13 @@ def __init__(
target_trial_index = get_target_trial_index(
experiment=modelbridge._experiment
)
trials_indices_with_sq_data = self.status_quo_data_by_trial.keys()
if target_trial_index not in trials_indices_with_sq_data:
target_trial_index = max(trials_indices_with_sq_data)
logger.info(
"No SQ data for target trial. Failing back to "
f"{target_trial_index}."
)

if target_trial_index is not None:
self.default_trial_idx: int = assert_is_instance(
Expand Down

0 comments on commit e29e8ba

Please sign in to comment.