[FR] Support for optax.contrib.reduce_on_plateau #1955

Open

zmbc opened this issue Jan 22, 2025 · 0 comments

zmbc commented Jan 22, 2025

For SVI, the learning rate is extremely influential; see e.g. this discussion post: https://forum.pyro.ai/t/does-svi-converges-towards-the-right-solution-4-parameters-mvn/3677/4

The guidance there is to play around with the learning rate until you get convergence, but this is both expensive and annoying to do programmatically (e.g. when fitting many models for cross-validation).

Optax contains a learning rate scheduler for exactly this, optax.contrib.reduce_on_plateau, which works really well, but it isn't currently easy to use in NumPyro because its update function takes the current loss as an extra argument.
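
To make the mismatch concrete, here is a minimal sketch of how the scheduler is driven in plain Optax (the parameters, gradients, and loss below are toy values for illustration): its update needs the current loss via the value keyword, which the stock numpyro.optim.optax_to_numpyro wrapper never passes.

import jax.numpy as jnp
import optax
from optax.contrib import reduce_on_plateau

# Toy pytrees standing in for real model parameters, gradients, and loss.
params = {"w": jnp.zeros(3)}
grads = {"w": jnp.ones(3)}
loss = jnp.array(1.0)

tx = optax.chain(optax.adam(0.01), reduce_on_plateau(patience=200))
opt_state = tx.init(params)
# `value=loss` drives the plateau logic; NumPyro's stock wrapper calls
# update(grads, opt_state, params) without it, so the scheduler never sees the loss.
updates, opt_state = tx.update(grads, opt_state, params, value=loss)
params = optax.apply_updates(params, updates)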

Here is some code that does it, based on slight modifications to optax_to_numpyro and _NumPyroOptim:

from collections.abc import Callable
from typing import Any

from jax import lax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.typing import ArrayLike

import optax
from optax.contrib import reduce_on_plateau

from numpyro.optim import _IterOptState, _NumPyroOptim, _Params, _value_and_grad

class _NumPyroOptimValueArg(_NumPyroOptim):
    def update(self, g: _Params, state: _IterOptState, value) -> _IterOptState:
        """
        Gradient update for the optimizer.

        :param g: gradient information for parameters.
        :param state: current optimizer state.
        :param value: current value of the objective function (the loss), forwarded
            to the underlying Optax transformation.
        :return: new optimizer state after the update.
        """
        i, opt_state = state
        opt_state = self.update_fn(i, g, opt_state, value=value)
        return i + 1, opt_state

    def eval_and_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        """
        Performs an optimization step for the objective function `fn`.
        For most optimizers, the update is performed based on the gradient
        of the objective function w.r.t. the current state. However, for
        some optimizers such as :class:`Minimize`, the update is performed
        by reevaluating the function multiple times to get optimal
        parameters.

        :param fn: an objective function returning a pair where the first item
            is a scalar loss function to be differentiated and the second item
            is an auxiliary output.
        :param state: current optimizer state.
        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params: _Params = self.get_params(state)
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )
        return (out, aux), self.update(grads, state, value=out)

    def eval_and_stable_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        """
        Like :meth:`eval_and_update` but when the value of the objective function
        or the gradients are not finite, we will not update the input `state`
        and will set the objective output to `nan`.

        :param fn: objective function.
        :param state: current optimizer state.
        :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
        :return: a pair of the output of objective function and the new optimizer state.
        """
        params: _Params = self.get_params(state)
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )
        out, state = lax.cond(
            jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
            lambda _: (out, self.update(grads, state, value=out)),
            lambda _: (jnp.nan, state),
            None,
        )
        return (out, aux), state

def optax_to_numpyro_value_arg(transformation) -> _NumPyroOptimValueArg:
    """
    Like ``numpyro.optim.optax_to_numpyro``, but produces a ``_NumPyroOptimValueArg``
    instance from an ``optax.GradientTransformationExtraArgs`` (e.g. a chain containing
    ``reduce_on_plateau``) so that it can be used with ``numpyro.infer.svi.SVI``. It is
    a lightweight wrapper that recreates the ``(init_fn, update_fn, get_params_fn)``
    interface defined by :mod:`jax.example_libraries.optimizers`, additionally
    forwarding the current loss to the transformation's ``update`` via the ``value``
    keyword.

    :param transformation: An ``optax.GradientTransformationExtraArgs`` instance to wrap.
    :return: An instance of ``_NumPyroOptimValueArg`` wrapping the supplied
        Optax optimizer.
    """
    import optax

    def init_fn(params: _Params) -> tuple[_Params, Any]:
        opt_state = transformation.init(params)
        return params, opt_state

    def update_fn(
        step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any], value
    ) -> tuple[_Params, Any]:
        params, opt_state = state
        updates, opt_state = transformation.update(grads, opt_state, params, value=value)
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state: tuple[_Params, Any]) -> _Params:
        params, _ = state
        return params

    # _NumPyroOptim expects a factory returning (init_fn, update_fn, get_params_fn),
    # so pass the three functions through an identity factory.
    return _NumPyroOptimValueArg(lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn)

Then you can run e.g. the SVI example from the docs with:

# (assumes the imports and wrapper definitions above are already in scope)
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import Predictive, SVI, Trace_ELBO

def model(data):
    f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
    with numpyro.plate("N", data.shape[0] if data is not None else 10):
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

def guide(data):
    alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
    beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
                           constraint=constraints.positive)
    numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
optimizer = optax_to_numpyro_value_arg(optax.chain(
    optax.adam(0.01),
    reduce_on_plateau(
        cooldown=100, accumulation_size=100, patience=200,
    ),
))
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)
params = svi_result.params
inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])
# use guide to make predictive
predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), data=None)
# get posterior samples
predictive = Predictive(guide, params=params, num_samples=1000)
posterior_samples = predictive(random.PRNGKey(1), data=None)
# use posterior samples to make predictive
predictive = Predictive(model, posterior_samples, params=params, num_samples=1000)
samples = predictive(random.PRNGKey(1), data=None)
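
If it helps with reviewing this, a quick (purely illustrative) way to see whether training settled is to compare the start and end of the loss trajectory; svi_result.losses holds the per-step ELBO loss:

# Illustrative check: svi_result.losses is the per-step ELBO loss trajectory,
# so comparing its head and tail gives a rough sense of whether training settled.
print("mean loss, first 100 steps:", float(svi_result.losses[:100].mean()))
print("mean loss, last 100 steps: ", float(svi_result.losses[-100:].mean()))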