For SVI, the learning rate is extremely influential; see e.g. this discussion post: https://forum.pyro.ai/t/does-svi-converges-towards-the-right-solution-4-parameters-mvn/3677/4
The guidance there is to play around with the learning rate until you get convergence, but this is both expensive and awkward to do programmatically (e.g. when fitting many models for cross-validation).
Optax contains a learning-rate scheduler for exactly this, reduce_on_plateau, which works really well, but it isn't currently easy to use in NumPyro because its update function takes the current loss as an extra argument.
Here is some code that does it, based on slight modifications to optax_to_numpyro and _NumPyroOptim:

from collections.abc import Callable
from typing import Any

import jax.numpy as jnp
import optax
from jax import lax
from jax.flatten_util import ravel_pytree
from jax.typing import ArrayLike
from numpyro.optim import _IterOptState, _NumPyroOptim, _Params, _value_and_grad
from optax.contrib import reduce_on_plateau


class _NumPyroOptimValueArg(_NumPyroOptim):
    def update(self, g: _Params, state: _IterOptState, value) -> _IterOptState:
        """
        Gradient update for the optimizer.

        :param g: gradient information for parameters.
        :param state: current optimizer state.
        :param value: current value of the objective function, forwarded to the
            underlying Optax update.
        :return: new optimizer state after the update.
        """
        i, opt_state = state
        opt_state = self.update_fn(i, g, opt_state, value=value)
        return i + 1, opt_state

    def eval_and_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        """
        Performs an optimization step for the objective function `fn`. For most
        optimizers, the update is performed based on the gradient of the
        objective function w.r.t. the current state. However, for some
        optimizers such as :class:`Minimize`, the update is performed by
        reevaluating the function multiple times to get optimal parameters.

        :param fn: an objective function returning a pair where the first item
            is a scalar loss to be differentiated and the second item is an
            auxiliary output.
        :param state: current optimizer state.
        :param forward_mode_differentiation: boolean flag indicating whether to
            use forward mode differentiation.
        :return: a pair of the output of the objective function and the new
            optimizer state.
        """
        params: _Params = self.get_params(state)
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )
        return (out, aux), self.update(grads, state, value=out)

    def eval_and_stable_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        """
        Like :meth:`eval_and_update` but when the value of the objective
        function or the gradients are not finite, we will not update the input
        `state` and will set the objective output to `nan`.

        :param fn: objective function.
        :param state: current optimizer state.
        :param forward_mode_differentiation: boolean flag indicating whether to
            use forward mode differentiation.
        :return: a pair of the output of the objective function and the new
            optimizer state.
        """
        params: _Params = self.get_params(state)
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )
        out, state = lax.cond(
            jnp.isfinite(out) & jnp.isfinite(ravel_pytree(grads)[0]).all(),
            lambda _: (out, self.update(grads, state, value=out)),
            lambda _: (jnp.nan, state),
            None,
        )
        return (out, aux), state


def optax_to_numpyro_value_arg(transformation) -> _NumPyroOptimValueArg:
    """
    This function produces a ``numpyro.optim._NumPyroOptim`` instance from an
    ``optax.GradientTransformation`` so that it can be used with
    ``numpyro.infer.svi.SVI``. It is a lightweight wrapper that recreates the
    ``(init_fn, update_fn, get_params_fn)`` interface defined by
    :mod:`jax.example_libraries.optimizers`.

    :param transformation: An ``optax.GradientTransformation`` instance to wrap.
    :return: An instance of ``numpyro.optim._NumPyroOptim`` wrapping the
        supplied Optax optimizer.
    """

    def init_fn(params: _Params) -> tuple[_Params, Any]:
        opt_state = transformation.init(params)
        return params, opt_state

    def update_fn(
        step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any], value
    ) -> tuple[_Params, Any]:
        params, opt_state = state
        updates, opt_state = transformation.update(
            grads, opt_state, params, value=value
        )
        updated_params = optax.apply_updates(params, updates)
        return updated_params, opt_state

    def get_params_fn(state: tuple[_Params, Any]) -> _Params:
        params, _ = state
        return params

    return _NumPyroOptimValueArg(
        lambda x, y, z: (x, y, z), init_fn, update_fn, get_params_fn
    )
Then you can run e.g. the SVI example from the docs with: