typing around ThompsonSampler and its subclasses (#3252)
Summary:
Pull Request resolved: #3252

- Added type annotations in ThompsonSampler and EBAshr.
- Added warnings in ThompsonSampler and EBAshr for the case where the objective is a multi-objective.

Both ThompsonSampler and EBAshr currently handle a multi-objective in the same way: the objective weights, which are +/-1 for the metrics in the multi-objective, are multiplied with the samples in the case of TS and with the shrunken estimates in the case of EBAshr. Because the metrics can be scaled very differently, a sum of metric values weighted by +/-1 is not a meaningful quantity.
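
A toy numerical illustration (not part of this commit; the numbers are invented) of how metric scale swamps a +/-1 scalarization:

import numpy as np

objective_weights = np.array([1.0, -1.0])  # maximize metric A, minimize metric B
samples = np.array(
    [
        [0.9, 1200.0],  # arm 1: metric A is O(1), metric B is O(1000)
        [0.1, 1100.0],  # arm 2: much worse on A, slightly better on B
    ]
)
scores = samples @ objective_weights
print(scores)  # [-1199.1 -1099.9]: arm 2 wins because metric B's scale dominates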

Reviewed By: saitcakmak

Differential Revision: D68297261

fbshipit-source-id: 4271cae072b6f6dfd774274ce3d69679b7403ce0
Jelena Markovic-Voronov authored and facebook-github-bot committed Jan 18, 2025
1 parent 99a2561 commit c3b8285
Showing 4 changed files with 59 additions and 28 deletions.
40 changes: 26 additions & 14 deletions ax/models/discrete/thompson.py
@@ -8,6 +8,7 @@

import hashlib
import json
import warnings
from collections.abc import Iterable, Mapping, Sequence

import numpy as np
@@ -18,7 +19,8 @@
from ax.models.discrete_base import DiscreteModel
from ax.models.types import TConfig
from ax.utils.common.docutils import copy_doc
from pyre_extensions import none_throws

from pyre_extensions import assert_is_instance, none_throws


class ThompsonSampler(DiscreteModel):
@@ -49,12 +51,11 @@ def __init__(
self.uniform_weights = uniform_weights

self.X: Sequence[Sequence[TParamValue]] | None = None
# pyre-fixme[4]: Attribute must be annotated.
self.Ys = None
# pyre-fixme[4]: Attribute must be annotated.
self.Yvars = None
# pyre-fixme[4]: Attribute must be annotated.
self.X_to_Ys_and_Yvars = None
self.Ys: Sequence[Sequence[float]] | None = None
self.Yvars: Sequence[Sequence[float]] | None = None
self.X_to_Ys_and_Yvars: (
list[dict[TParamValueList, tuple[float, float]]] | None
) = None

@copy_doc(DiscreteModel.fit)
def fit(
@@ -70,7 +71,9 @@ def fit(
Ys=Ys, Yvars=Yvars, outcome_names=outcome_names
)
self.X_to_Ys_and_Yvars = self._fit_X_to_Ys_and_Yvars(
X=none_throws(self.X), Ys=self.Ys, Yvars=self.Yvars
X=none_throws(self.X),
Ys=none_throws(self.Ys),
Yvars=none_throws(self.Yvars),
)

@copy_doc(DiscreteModel.gen)
@@ -87,6 +90,13 @@ def gen(
if objective_weights is None:
raise ValueError("ThompsonSampler requires objective weights.")

if np.sum(abs(objective_weights) > 0) > 1:
warnings.warn(
"In case of multi-objective adding metric values together might"
" not lead to a meaningful result.",
stacklevel=2,
)

arms = none_throws(self.X)
k = len(arms)

@@ -135,19 +145,21 @@ def predict(
self, X: Sequence[Sequence[TParamValue]]
) -> tuple[npt.NDArray, npt.NDArray]:
n = len(X) # number of parameterizations at which to make predictions
m = len(self.Ys) # number of outcomes
m = len(none_throws(self.Ys)) # number of outcomes
f = np.zeros((n, m)) # array of outcome predictions
cov = np.zeros((n, m, m)) # array of predictive covariances
predictX = [self._hash_TParamValueList(x) for x in X]
for i, X_to_Y_and_Yvar in enumerate(self.X_to_Ys_and_Yvars):
for i, X_to_Y_and_Yvar in enumerate(none_throws(self.X_to_Ys_and_Yvars)):
# iterate through outcomes
for j, x in enumerate(predictX):
# iterate through parameterizations at which to make predictions
if x not in X_to_Y_and_Yvar:
raise ValueError(
"ThompsonSampler does not support out-of-sample prediction."
)
f[j, i], cov[j, i, i] = X_to_Y_and_Yvar[x]
f[j, i], cov[j, i, i] = X_to_Y_and_Yvar[
assert_is_instance(x, TParamValue)
]
return f, cov

def _generate_weights(
@@ -187,10 +199,10 @@ def _generate_weights(
def _generate_samples_per_metric(self, num_samples: int) -> npt.NDArray:
k = len(none_throws(self.X))
samples_per_metric = np.zeros(
(k, num_samples, len(self.Ys))
(k, num_samples, len(none_throws(self.Ys)))
) # k x num_samples x m
for i, Y in enumerate(self.Ys): # (k x 1)
Yvar = self.Yvars[i] # (k x 1)
for i, Y in enumerate(none_throws(self.Ys)): # (k x 1)
Yvar = none_throws(self.Yvars)[i] # (k x 1)
cov = np.diag(Yvar) # (k x k)
samples = np.random.multivariate_normal(
Y, cov, num_samples
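As a quick aside, the new check in gen() above counts metrics with a nonzero weight; a standalone NumPy illustration of that predicate (illustrative only, not part of the commit):

import numpy as np

for w in (np.array([1.0, 0.0]), np.array([1.0, -1.0])):
    n_active = np.sum(abs(w) > 0)  # number of metrics with nonzero weight
    print(w, "warns" if n_active > 1 else "no warning")
# [1. 0.] no warning
# [ 1. -1.] warns
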
4 changes: 4 additions & 0 deletions ax/models/tests/test_eb_thompson.py
@@ -34,6 +34,10 @@ def test_EmpiricalBayesThompsonSamplerFit(self) -> None:
outcome_names=self.outcome_names,
)
self.assertEqual(generator.X, self.Xs[0])
print(generator.Ys)
print(
np.array([[1.3, 2.1, 2.9, 3.7], [0.25, 0.25, 0.25, 0.25]]),
)
self.assertTrue(
np.allclose(
np.array(generator.Ys),
24 changes: 24 additions & 0 deletions ax/models/tests/test_thompson.py
@@ -7,6 +7,8 @@
# pyre-strict


from warnings import catch_warnings

import numpy as np
from ax.exceptions.model import ModelError
from ax.models.discrete.thompson import ThompsonSampler
@@ -199,3 +201,25 @@ def test_ThompsonSamplerPredict(self) -> None:

with self.assertRaises(ValueError):
generator.predict([[1, 2]])

def test_ThompsonSamplerMultiObjectiveWarning(self) -> None:
generator = ThompsonSampler(min_weight=0.0)
generator.fit(
Xs=self.multiple_metrics_Xs,
Ys=self.multiple_metrics_Ys,
Yvars=self.multiple_metrics_Yvars,
parameter_values=self.parameter_values,
outcome_names=self.outcome_names,
)
with catch_warnings(record=True) as warning_list:
arms, weights, _ = generator.gen(
n=4,
parameter_values=self.parameter_values,
objective_weights=np.array([1, -1]),
outcome_constraints=None,
)
self.assertEqual(
"In case of multi-objective adding metric values together might"
" not lead to a meaningful result.",
str(warning_list[0].message),
)
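
Since the new warning is issued with a plain warnings.warn (a UserWarning by default), a caller that intentionally optimizes a +/-1 scalarization can suppress it. A minimal sketch, with data shapes assumed from the tests above rather than taken from this commit:

import warnings

import numpy as np
from ax.models.discrete.thompson import ThompsonSampler

parameter_values = [[1, 2, 3], [1, 2]]  # allowed values per parameter (assumed)
X = [[1, 1], [2, 2], [3, 1]]  # arms; the same X list is reused per metric below
sampler = ThompsonSampler(min_weight=0.0)
sampler.fit(
    Xs=[X, X],
    Ys=[[1.0, 2.0, 3.0], [0.5, 1.0, 1.5]],  # one list of means per metric
    Yvars=[[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]],
    parameter_values=parameter_values,
    outcome_names=["m1", "m2"],
)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the multi-objective warning
    arms, weights, _ = sampler.gen(
        n=2,
        parameter_values=parameter_values,
        objective_weights=np.array([1, -1]),
        outcome_constraints=None,
    )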
19 changes: 5 additions & 14 deletions ax/utils/stats/statstools.py
@@ -137,22 +137,13 @@ def positive_part_james_stein(
sigma2_i = np.power(sems, 2)
ybar = np.mean(y_i)
s2 = np.var(y_i - ybar, ddof=3) # sample variance normalized by K-3
if s2 == 0:
phi_i = 1
else:
phi_i = np.minimum(1, sigma2_i / s2)
# pyre-fixme[6]: For 1st argument expected `int` but got `floating[typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `bool` but got `ndarray[typing.Any,
# dtype[typing.Any]]`.
mu_hat_i = y_i + phi_i * (ybar - y_i)
phi_i = np.ones_like(sigma2_i) if s2 == 0 else np.minimum(1, sigma2_i / s2)
mu_hat_i = y_i + phi_i * np.subtract(ybar, y_i)

sigma_hat_i = np.sqrt(
# pyre-fixme[58]: `-` is not supported for operand types `int` and
# `Union[np.ndarray[typing.Any, np.dtype[typing.Any]], int]`.
(1 - phi_i) * sigma2_i
np.subtract(1.0, phi_i) * sigma2_i
+ phi_i * sigma2_i / K
# pyre-fixme[58]: `*` is not supported for operand types `int` and
# `Union[np.ndarray[typing.Any, np.dtype[typing.Any]], int]`.
+ 2 * phi_i**2 * (y_i - ybar) ** 2 / (K - 3)
+ np.multiply(2, phi_i**2) * (y_i - ybar) ** 2 / (K - 3)
)
return mu_hat_i, sigma_hat_i

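The cleaned-up block above is the positive-part James-Stein estimator; a self-contained sketch of the same math (illustrative only; it mirrors the diff rather than importing the real function from ax.utils.stats.statstools):

import numpy as np

def positive_part_james_stein_sketch(means, sems):
    y_i = np.asarray(means, dtype=float)
    sigma2_i = np.asarray(sems, dtype=float) ** 2
    K = len(y_i)
    ybar = np.mean(y_i)
    s2 = np.var(y_i - ybar, ddof=3)  # sample variance normalized by K - 3
    # Shrink fully toward the grand mean when there is no between-arm
    # variance; otherwise cap the shrinkage factor at 1.
    phi_i = np.ones_like(sigma2_i) if s2 == 0 else np.minimum(1, sigma2_i / s2)
    mu_hat_i = y_i + phi_i * (ybar - y_i)
    sigma_hat_i = np.sqrt(
        (1 - phi_i) * sigma2_i
        + phi_i * sigma2_i / K
        + 2 * phi_i**2 * (y_i - ybar) ** 2 / (K - 3)
    )
    return mu_hat_i, sigma_hat_i

mu, sigma = positive_part_james_stein_sketch([0.5, 1.0, 1.5, 3.0, 5.0], [0.2] * 5)
print(mu)  # estimates pulled slightly toward the grand mean of 2.2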
