diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py index 7c6b69f..52e1379 100644 --- a/flow_matching/solver/ode_solver.py +++ b/flow_matching/solver/ode_solver.py @@ -172,16 +172,15 @@ def dynamics_func(t, states): y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device)) ode_opts = {"step_size": step_size} if step_size is not None else {} - with torch.no_grad(): - sol, log_det = odeint( - dynamics_func, - y_init, - time_grid, - method=method, - options=ode_opts, - atol=atol, - rtol=rtol, - ) + sol, log_det = odeint( + dynamics_func, + y_init, + time_grid, + method=method, + options=ode_opts, + atol=atol, + rtol=rtol, + ) x_source = sol[-1] source_log_p = log_p0(x_source)