diff --git a/examples/2d_flow_matching.ipynb b/examples/2d_flow_matching.ipynb index f7ff170..98e0057 100644 --- a/examples/2d_flow_matching.ipynb +++ b/examples/2d_flow_matching.ipynb @@ -354,7 +354,7 @@ "metadata": {}, "outputs": [], "source": [ - "from torch.distributions.multivariate_normal import MultivariateNormal" + "from torch.distributions import Independent, Normal" ] }, { @@ -379,8 +379,8 @@ "metadata": {}, "outputs": [], "source": [ - "# source distribution is a gaussian\n", - "gaussian_log_density = MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device)).log_prob\n", + "# source distribution is an isotropic gaussian\n", + "gaussian_log_density = Independent(Normal(torch.zeros(2, device=device), torch.ones(2, device=device)), 1).log_prob\n", "\n", "# compute log likelihood with unbiased hutchinson estimator, average over num_acc\n", "num_acc = 10\n",