Skip to content

Commit

Permalink
Assert t shape in affine path (#6)
Browse files Browse the repository at this point in the history
* ensure one dimensional t

* grammar nit

* precommit

* added t shape assert to path class. docs nit: batch_size instead of Batch.

* docs phrasing nit

---------

Co-authored-by: Marton Havasi <[email protected]>
  • Loading branch information
mhavasi and Marton Havasi authored Dec 11, 2024
1 parent 9d2160b commit d23367d
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 47 deletions.
31 changes: 15 additions & 16 deletions flow_matching/path/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
| return :math:`X_0, X_1, X_t = \alpha_t X_1 + \sigma_t X_0`, and the conditional velocity at :math:`X_t, \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0`.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor, optional): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
PathSample: a conditional sample at :math:`X_t \sim p_t`.
Expand All @@ -72,19 +72,18 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:

scheduler_output = self.scheduler(t)

if t.ndim == 1:
alpha_t = expand_tensor_like(
input_tensor=scheduler_output.alpha_t, expand_to=x_1
)
sigma_t = expand_tensor_like(
input_tensor=scheduler_output.sigma_t, expand_to=x_1
)
d_alpha_t = expand_tensor_like(
input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
)
d_sigma_t = expand_tensor_like(
input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
)
alpha_t = expand_tensor_like(
input_tensor=scheduler_output.alpha_t, expand_to=x_1
)
sigma_t = expand_tensor_like(
input_tensor=scheduler_output.sigma_t, expand_to=x_1
)
d_alpha_t = expand_tensor_like(
input_tensor=scheduler_output.d_alpha_t, expand_to=x_1
)
d_sigma_t = expand_tensor_like(
input_tensor=scheduler_output.d_sigma_t, expand_to=x_1
)

# construct xt ~ p_t(x|x1).
x_t = sigma_t * x_0 + alpha_t * x_1
Expand Down
10 changes: 4 additions & 6 deletions flow_matching/path/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,15 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
| return :math:`X_0, X_1, X_t = \exp_{X_1}(\kappa_t \log_{X_1}(X_0))`, and the conditional velocity at :math:`X_t, \dot{X}_t`.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor, optional): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
PathSample: A conditional sample at :math:`X_t \sim p_t`.
"""
self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

if t.ndim <= 1:
t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()
t = expand_tensor_like(input_tensor=t, expand_to=x_1[..., 0:1]).clone()

def cond_u(x_0, x_1, t):
path = geodesic(self.manifold, x_0, x_1)
Expand Down
9 changes: 4 additions & 5 deletions flow_matching/path/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
| given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler :math:`(\alpha_t,\sigma_t)`.
| return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`.
Expand All @@ -81,8 +81,7 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:

sigma_t = self.scheduler(t).sigma_t

if t.ndim == 1:
sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)
sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
x_t = torch.where(condition=source_indices, input=x_0, other=x_1)
Expand Down
9 changes: 6 additions & 3 deletions flow_matching/path/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,18 @@ def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> PathSample:
| returns :math:`X_0, X_1, X_t \sim p_t(X_t)`, and a conditional target :math:`Y`, all objects are under ``PathSample``.
Args:
x_0 (Tensor): source data point, shape (Batch, ...).
x_1 (Tensor): target data point, shape (Batch, ...).
t (Tensor, optional): times in [0,1], shape (Batch).
x_0 (Tensor): source data point, shape (batch_size, ...).
x_1 (Tensor): target data point, shape (batch_size, ...).
t (Tensor): times in [0,1], shape (batch_size).
Returns:
PathSample: a conditional sample.
"""

def assert_sample_shape(self, x_0: Tensor, x_1: Tensor, t: Tensor):
assert (
t.ndim == 1
), f"The time vector t must have shape [batch_size]. Got {t.shape}."
assert (
t.shape[0] == x_0.shape[0] == x_1.shape[0]
), f"Time t dimension must match the batch size [{x_1.shape[0]}]. Got {t.shape}"
22 changes: 11 additions & 11 deletions flow_matching/path/path_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ class PathSample:
x_1 (Tensor): the target sample :math:`X_1`.
x_0 (Tensor): the source sample :math:`X_0`.
t (Tensor): the time sample :math:`t`.
x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (Batch, ...).
dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (Batch, ...).
x_t (Tensor): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
dx_t (Tensor): conditional target :math:`\frac{\partial X}{\partial t}`, shape: (batch_size, ...).
"""

x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."})
t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."})
x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
x_t: Tensor = field(
metadata={"help": "samples x_t ~ p_t(X_t), shape (Batch, ...)."}
metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
)
dx_t: Tensor = field(
metadata={"help": "conditional target dX_t, shape: (Batch, ...)."}
metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."}
)


Expand All @@ -45,9 +45,9 @@ class DiscretePathSample:
x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
"""

x_1: Tensor = field(metadata={"help": "target samples X_1 (Batch, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (Batch, ...)."})
t: Tensor = field(metadata={"help": "time samples t (Batch, ...)."})
x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
x_t: Tensor = field(
metadata={"help": "samples X_t ~ p_t(X_t), shape (Batch, ...)."}
metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
)
2 changes: 1 addition & 1 deletion flow_matching/path/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __call__(self, t: Tensor) -> SchedulerOutput:
"""Scheduler for convex paths.
Args:
t (Tensor, optional): times in [0,1], shape (...).
t (Tensor): times in [0,1], shape (...).
Returns:
SchedulerOutput: :math:`\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t`
Expand Down
4 changes: 2 additions & 2 deletions flow_matching/utils/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
| returns the model output for input x at time t, with extra information `extra`.
Args:
x (Tensor): input data to the model (Batch, ...).
t (Tensor): time (Batch).
x (Tensor): input data to the model (batch_size, ...).
t (Tensor): time (batch_size).
**extras: additional information forwarded to the model, e.g., text condition.
Returns:
Expand Down
9 changes: 6 additions & 3 deletions flow_matching/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
expand `input_tensor` to have the same shape as `expand_to` along all remaining dimensions.
Args:
input_tensor (Tensor): (B,).
expand_to (Tensor): (B, ...).
input_tensor (Tensor): (batch_size,).
expand_to (Tensor): (batch_size, ...).
Returns:
Tensor: (B, ...).
Tensor: (batch_size, ...).
"""
assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
assert (
input_tensor.shape[0] == expand_to.shape[0]
), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

dim_diff = expand_to.ndim - input_tensor.ndim

Expand Down

0 comments on commit d23367d

Please sign in to comment.