Skip to content

Commit

Permalink
Add a helper to pull out gen kwargs (#3246)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3246

This diff is purely to improve the code, I could see folks going either way on if this is preferable, so happy to abandon if we don't like it.

Why I think it's advatageous:
-  Removes some arg passing in already existing helpers by wrapping in gen_kwargs
- Removes that ugly dict creation in the function

Reviewed By: lena-kashtelyan

Differential Revision: D67318518

fbshipit-source-id: 642622af9f498276495ebaaa6c9cdd96986d4d99
  • Loading branch information
mgarrard authored and facebook-github-bot committed Jan 17, 2025
1 parent b4aa11b commit 99a2561
Showing 1 changed file with 67 additions and 26 deletions.
93 changes: 67 additions & 26 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,23 +439,23 @@ def _gen_with_multiple_nodes(
pending_observations = deepcopy(pending_observations) or {}
self.experiment = experiment
self._validate_arms_per_node(arms_per_node=arms_per_node)
pack_gs_gen_kwargs = self._initalize_gen_kwargs(
experiment=experiment,
grs_this_gen=grs,
data=data,
n=n,
fixed_features=fixed_features,
arms_per_node=arms_per_node,
pending_observations=pending_observations,
)
if self.optimization_complete:
raise GenerationStrategyCompleted(
f"Generation strategy {self} generated all the trials as "
"specified in its nodes."
)
# TODO: @mgarrard update this when gen methods are merged
gen_kwargs: dict[str, Any] = {}
gen_kwargs = {
"experiment": experiment,
"data": data,
"pending_observations": pending_observations,
"grs_this_gen": grs,
"n": n,
}

while continue_gen_for_trial:
gen_kwargs["grs_this_gen"] = grs
pack_gs_gen_kwargs["grs_this_gen"] = grs
should_transition, node_to_gen_from_name = (
self._curr.should_transition_to_next_node(
raise_data_required_error=False
Expand All @@ -469,17 +469,15 @@ def _gen_with_multiple_nodes(
node_to_gen_from._should_skip = False
arms_from_node = self._determine_arms_from_node(
node_to_gen_from=node_to_gen_from,
arms_per_node=arms_per_node,
n=n,
gen_kwargs=gen_kwargs,
gen_kwargs=pack_gs_gen_kwargs,
)
fixed_features_from_node = self._determine_fixed_features_from_node(
node_to_gen_from=node_to_gen_from,
gen_kwargs=gen_kwargs,
passed_fixed_features=fixed_features,
gen_kwargs=pack_gs_gen_kwargs,
)
sq_ft_from_node = self._determine_sq_features_from_node(
node_to_gen_from=node_to_gen_from, gen_kwargs=gen_kwargs
node_to_gen_from=node_to_gen_from, gen_kwargs=pack_gs_gen_kwargs
)
self._maybe_transition_to_next_node()
if node_to_gen_from._should_skip:
Expand Down Expand Up @@ -968,11 +966,44 @@ def _should_continue_gen_for_trial(self) -> bool:
for tc in self._curr.transition_edges[next_node]
)

def _initalize_gen_kwargs(
self,
experiment: Experiment,
grs_this_gen: list[GeneratorRun],
data: Data | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
n: int | None = None,
fixed_features: ObservationFeatures | None = None,
arms_per_node: dict[str, int] | None = None,
) -> dict[str, Any]:
"""Creates a dictionary mapping the name of all kwargs kwargs passed into
``gen_with_multiple_nodes`` to their values. This is used by
``NodeInputConstructors`` to dynamically determine node inputs during gen.
Args:
See ``gen_with_multiple_nodes`` documentation
+ grs_this_gen: A running list of generator runs produced during this
call to gen_with_multiple_nodes. Currently needed by some input
constructors.
Returns:
A dictionary mapping the name of all kwargs kwargs passed into
``gen_with_multiple_nodes`` to their values.
"""
return {
"experiment": experiment,
"grs_this_gen": grs_this_gen,
"data": data,
"n": n,
"fixed_features": fixed_features,
"arms_per_node": arms_per_node,
"pending_observations": pending_observations,
}

def _determine_fixed_features_from_node(
self,
node_to_gen_from: GenerationNode,
gen_kwargs: dict[str, Any],
passed_fixed_features: ObservationFeatures | None = None,
) -> ObservationFeatures | None:
"""Uses the ``InputConstructors`` on the node to determine the fixed features
to pass into the model. If fixed_features are provided, the will take
Expand All @@ -981,8 +1012,7 @@ def _determine_fixed_features_from_node(
Args:
node_to_gen_from: The node from which to generate from
gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call.
passed_fixed_features: The fixed features passed to the ``gen`` method if
gen call, including the fixed features passed to the ``gen`` method if
any.
Returns:
Expand All @@ -991,6 +1021,7 @@ def _determine_fixed_features_from_node(
"""
# passed_fixed_features represents the fixed features that were passed by the
# user to the gen method as overrides.
passed_fixed_features = gen_kwargs.get("fixed_features")
if passed_fixed_features is not None:
return passed_fixed_features

Expand All @@ -1014,8 +1045,18 @@ def _determine_sq_features_from_node(
node_to_gen_from: GenerationNode,
gen_kwargs: dict[str, Any],
) -> ObservationFeatures | None:
"""todo"""
# TODO: @mgarrard to merge the input constructor logic into a single method
"""Uses the ``InputConstructors`` on the node to determine the status quo
features to pass into the model.
Args:
node_to_gen_from: The node from which to generate from
gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call.
Returns:
An object of ObservationFeatures that represents the status quo features
to pass into the model.
"""
node_sq_features = None
if (
InputConstructorPurpose.STATUS_QUO_FEATURES
Expand All @@ -1036,7 +1077,6 @@ def _determine_arms_from_node(
node_to_gen_from: GenerationNode,
gen_kwargs: dict[str, Any],
n: int | None = None,
arms_per_node: dict[str, int] | None = None,
) -> int:
"""Calculates the number of arms to generate from the node that will be used
during generation.
Expand All @@ -1049,16 +1089,17 @@ def _determine_arms_from_node(
arms that can differ from `n`.
node_to_gen_from: The node from which to generate from
gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node.
gen call, including arms_per_node: an optional map from node name to
the number of arms to generate from that node. If not provided, will
default to the numberof arms specified in the node's
``InputConstructors`` or n if no``InputConstructors`` are defined on
the node.
Returns:
The number of arms to generate from the node that will be used during this
generation via ``_gen_multiple``.
"""
arms_per_node = gen_kwargs.get("arms_per_node")
if arms_per_node is not None:
# arms_per_node provides a way to manually override input
# constructors. This should be used with caution, and only
Expand Down

0 comments on commit 99a2561

Please sign in to comment.