Skip to content

Commit

Permalink
deploy: d23367d
Browse files Browse the repository at this point in the history
  • Loading branch information
mhavasi committed Dec 11, 2024
1 parent 159eb56 commit 3c87c7f
Show file tree
Hide file tree
Showing 43 changed files with 490 additions and 489 deletions.
31 changes: 15 additions & 16 deletions _modules/flow_matching/path/affine.html
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,9 @@ <h1>Source code for flow_matching.path.affine</h1><div class="highlight"><pre>
<span class="sd"> | return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.</span>

<span class="sd"> Args:</span>
<span class="sd"> x_0 (Tensor): source data point, shape (Batch, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (Batch, ...).</span>
<span class="sd"> t (Tensor, optional): times in [0,1], shape (Batch).</span>
<span class="sd"> x_0 (Tensor): source data point, shape (batch_size, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (batch_size, ...).</span>
<span class="sd"> t (Tensor): times in [0,1], shape (batch_size).</span>

<span class="sd"> Returns:</span>
<span class="sd"> PathSample: a conditional sample at :math:`X_t \sim p_t`.</span>
Expand All @@ -404,19 +404,18 @@ <h1>Source code for flow_matching.path.affine</h1><div class="highlight"><pre>

<span class="n">scheduler_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">alpha_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">alpha_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">sigma_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">sigma_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">d_alpha_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">d_alpha_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">d_sigma_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">d_sigma_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">alpha_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">alpha_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">sigma_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">sigma_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">d_alpha_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">d_alpha_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>
<span class="n">d_sigma_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span>
<span class="n">input_tensor</span><span class="o">=</span><span class="n">scheduler_output</span><span class="o">.</span><span class="n">d_sigma_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span>
<span class="p">)</span>

<span class="c1"># construct xt ~ p_t(x|x1).</span>
<span class="n">x_t</span> <span class="o">=</span> <span class="n">sigma_t</span> <span class="o">*</span> <span class="n">x_0</span> <span class="o">+</span> <span class="n">alpha_t</span> <span class="o">*</span> <span class="n">x_1</span>
Expand Down
10 changes: 4 additions & 6 deletions _modules/flow_matching/path/geodesic.html
Original file line number Diff line number Diff line change
Expand Up @@ -406,17 +406,15 @@ <h1>Source code for flow_matching.path.geodesic</h1><div class="highlight"><pre>
<span class="sd"> | return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.</span>

<span class="sd"> Args:</span>
<span class="sd"> x_0 (Tensor): source data point, shape (Batch, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (Batch, ...).</span>
<span class="sd"> t (Tensor, optional): times in [0,1], shape (Batch).</span>
<span class="sd"> x_0 (Tensor): source data point, shape (batch_size, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (batch_size, ...).</span>
<span class="sd"> t (Tensor): times in [0,1], shape (batch_size).</span>

<span class="sd"> Returns:</span>
<span class="sd"> PathSample: A conditional sample at :math:`X_t \sim p_t`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">assert_sample_shape</span><span class="p">(</span><span class="n">x_0</span><span class="o">=</span><span class="n">x_0</span><span class="p">,</span> <span class="n">x_1</span><span class="o">=</span><span class="n">x_1</span><span class="p">,</span> <span class="n">t</span><span class="o">=</span><span class="n">t</span><span class="p">)</span>

<span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">ndim</span> <span class="o">&lt;=</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">=</span><span class="n">t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">=</span><span class="n">t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">:</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>

<span class="k">def</span> <span class="nf">cond_u</span><span class="p">(</span><span class="n">x_0</span><span class="p">,</span> <span class="n">x_1</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
<span class="n">path</span> <span class="o">=</span> <span class="n">geodesic</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">manifold</span><span class="p">,</span> <span class="n">x_0</span><span class="p">,</span> <span class="n">x_1</span><span class="p">)</span>
Expand Down
9 changes: 4 additions & 5 deletions _modules/flow_matching/path/mixture.html
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,9 @@ <h1>Source code for flow_matching.path.mixture</h1><div class="highlight"><pre>
<span class="sd"> | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.</span>
<span class="sd"> | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.</span>
<span class="sd"> Args:</span>
<span class="sd"> x_0 (Tensor): source data point, shape (Batch, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (Batch, ...).</span>
<span class="sd"> t (Tensor): times in [0,1], shape (Batch).</span>
<span class="sd"> x_0 (Tensor): source data point, shape (batch_size, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (batch_size, ...).</span>
<span class="sd"> t (Tensor): times in [0,1], shape (batch_size).</span>

<span class="sd"> Returns:</span>
<span class="sd"> DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`.</span>
Expand All @@ -413,8 +413,7 @@ <h1>Source code for flow_matching.path.mixture</h1><div class="highlight"><pre>

<span class="n">sigma_t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">scheduler</span><span class="p">(</span><span class="n">t</span><span class="p">)</span><span class="o">.</span><span class="n">sigma_t</span>

<span class="k">if</span> <span class="n">t</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
<span class="n">sigma_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">=</span><span class="n">sigma_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span><span class="p">)</span>
<span class="n">sigma_t</span> <span class="o">=</span> <span class="n">expand_tensor_like</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">=</span><span class="n">sigma_t</span><span class="p">,</span> <span class="n">expand_to</span><span class="o">=</span><span class="n">x_1</span><span class="p">)</span>

<span class="n">source_indices</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="n">x_1</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x_1</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">&lt;</span> <span class="n">sigma_t</span>
<span class="n">x_t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">condition</span><span class="o">=</span><span class="n">source_indices</span><span class="p">,</span> <span class="nb">input</span><span class="o">=</span><span class="n">x_0</span><span class="p">,</span> <span class="n">other</span><span class="o">=</span><span class="n">x_1</span><span class="p">)</span>
Expand Down
9 changes: 6 additions & 3 deletions _modules/flow_matching/path/path.html
Original file line number Diff line number Diff line change
Expand Up @@ -376,16 +376,19 @@ <h1>Source code for flow_matching.path.path</h1><div class="highlight"><pre>
<span class="sd"> | returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``.</span>

<span class="sd"> Args:</span>
<span class="sd"> x_0 (Tensor): source data point, shape (Batch, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (Batch, ...).</span>
<span class="sd"> t (Tensor, optional): times in [0,1], shape (Batch).</span>
<span class="sd"> x_0 (Tensor): source data point, shape (batch_size, ...).</span>
<span class="sd"> x_1 (Tensor): target data point, shape (batch_size, ...).</span>
<span class="sd"> t (Tensor): times in [0,1], shape (batch_size).</span>

<span class="sd"> Returns:</span>
<span class="sd"> PathSample: a conditional sample.</span>
<span class="sd"> &quot;&quot;&quot;</span></div>


<span class="k">def</span> <span class="nf">assert_sample_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_0</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">x_1</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">t</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">1</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;The time vector t must have shape [batch_size]. Got </span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">.&quot;</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">x_0</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">x_1</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="p">),</span> <span class="sa">f</span><span class="s2">&quot;Time t dimension must match the batch size [</span><span class="si">{</span><span class="n">x_1</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">]. Got </span><span class="si">{</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span></div>
Expand Down
Loading

0 comments on commit 3c87c7f

Please sign in to comment.