
Commit

add inclusive_end
Signed-off-by: Danny <[email protected]>
nvdreidenbach committed Jan 8, 2025
1 parent c978f26 commit 87afa0a
Showing 3 changed files with 261 additions and 28 deletions.
42 changes: 37 additions & 5 deletions sub-packages/bionemo-moco/documentation.md
@@ -1717,6 +1717,32 @@ class ContinuousInferenceSchedule(InferenceSchedule)

A base class for continuous time inference schedules.

<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
direction: Union[TimeDirection, str] = TimeDirection.UNIFIED,
device: Union[str, torch.device] = "cpu")
```

Initialize the ContinuousInferenceSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends at 1.0 - 1/nsteps (default is False).
- `min_t` _Float_ - Minimum time value (defaults to 0).
- `padding` _Float_ - Padding time value (defaults to 0).
- `dilation` _Float_ - Dilation time value, i.e. the number of replicates (defaults to 0).
- `direction` _Union[TimeDirection, str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows flipping it to match the specified one (default is TimeDirection.UNIFIED).
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
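The `inclusive_end` behavior described above can be sketched in plain Python. This is an illustration of the documented semantics, not bionemo-moco's actual implementation; `linear_schedule` is a hypothetical helper:

```python
# Illustration of the documented inclusive_end behavior (hypothetical helper,
# not the library's code): nsteps values starting at 0.0, where the exclusive
# schedule stops at 1.0 - 1/nsteps and inclusive_end=True reaches 1.0.
def linear_schedule(nsteps: int, inclusive_end: bool = False) -> list[float]:
    if inclusive_end:
        # nsteps points from 0.0 up to and including 1.0
        return [i / (nsteps - 1) for i in range(nsteps)]
    # nsteps points from 0.0 up to 1.0 - 1/nsteps
    return [i / nsteps for i in range(nsteps)]

print(linear_schedule(5))                      # [0.0, 0.2, 0.4, 0.6, 0.8]
print(linear_schedule(5, inclusive_end=True))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```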

<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedulediscretize"></a>

#### discretize
@@ -1844,6 +1870,7 @@ A linear time schedule for continuous time inference.

```python
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
@@ -1856,6 +1883,7 @@ Initialize the LinearInferenceSchedule.
**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends at 1.0 - 1/nsteps (default is False).
- `min_t` _Float_ - Minimum time value (defaults to 0).
- `padding` _Float_ - Padding time value (defaults to 0).
- `dilation` _Float_ - Dilation time value, i.e. the number of replicates (defaults to 0).
@@ -1867,9 +1895,9 @@ Initialize the LinearInferenceSchedule.
#### generate\_schedule

```python
def generate_schedule(
nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None) -> Tensor
def generate_schedule(nsteps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
inclusive_end: bool = False) -> Tensor
```

Generate the linear time schedule as a tensor.
@@ -1882,7 +1910,6 @@ Generate the linear time schedule as a tensor.

**Returns**:

- `Tensor` - A tensor of time steps.
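With an exclusive end, the linear schedule yields uniform step widths. A small sketch of how the schedule relates to per-step dt values (plain Python with assumed semantics, not the library's `discretize` implementation):

```python
# Sketch (assumed semantics, not bionemo-moco's code): an exclusive linear
# schedule and the per-step widths an integrator would consume.
def linear_schedule(nsteps: int) -> list[float]:
    return [i / nsteps for i in range(nsteps)]

def step_widths(nsteps: int) -> list[float]:
    ts = linear_schedule(nsteps) + [1.0]        # append the terminal time
    return [b - a for a, b in zip(ts, ts[1:])]  # dt_i = t_{i+1} - t_i

print(step_widths(4))  # [0.25, 0.25, 0.25, 0.25]
```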

<a id="mocoschedulesinference_time_schedulesPowerInferenceSchedule"></a>
@@ -1901,6 +1928,7 @@ A power time schedule for inference, where time steps are generated by raising a

```python
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
@@ -1914,6 +1942,7 @@ Initialize the PowerInferenceSchedule.
**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends below 1.0 (default is False).
- `min_t` _Float_ - Minimum time value (defaults to 0).
- `padding` _Float_ - Padding time value (defaults to 0).
- `dilation` _Float_ - Dilation time value, i.e. the number of replicates (defaults to 0).
@@ -1939,6 +1968,7 @@ Generate the power time schedule as a tensor.
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").



**Returns**:

- `Tensor` - A tensor of time steps.
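A hedged sketch of the power-schedule idea: a uniform grid raised to an exponent. The exponent parameter name below is an assumption for illustration, not the library's API:

```python
# Hypothetical power schedule (not bionemo-moco's exact formula): raise a
# uniform grid to `exponent`; exponent > 1 clusters small steps near 0.
def power_schedule(nsteps: int, exponent: float = 2.0,
                   inclusive_end: bool = False) -> list[float]:
    denom = nsteps - 1 if inclusive_end else nsteps
    return [(i / denom) ** exponent for i in range(nsteps)]

ts = power_schedule(5, exponent=2.0)
print(ts)  # last value is (4/5)**2, about 0.64, staying below 1.0
```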
@@ -1960,6 +1990,7 @@ A log time schedule for inference, where time steps are generated by taking the

```python
def __init__(nsteps: int,
inclusive_end: bool = False,
min_t: Float = 0,
padding: Float = 0,
dilation: Float = 0,
@@ -1989,6 +2020,7 @@ tensor([0.0000, 0.0455, 0.0889, 0.1303, 0.1699, 0.2077, 0.2439, 0.2783, 0.3113,
**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `inclusive_end` _bool_ - If True, include the end value (1.0) in the schedule; otherwise the schedule ends below 1.0 (default is False).
- `min_t` _Float_ - Minimum time value (defaults to 0).
- `padding` _Float_ - Padding time value (defaults to 0).
- `dilation` _Float_ - Dilation time value, i.e. the number of replicates (defaults to 0).
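A hedged illustration of the log-schedule idea, where increments shrink as time grows. This sketches the general shape only and is not bionemo-moco's exact formula:

```python
import math

# Hedged illustration (not the library's exact formula): log-space a grid so
# step increments shrink as t grows. Values stay in [0, 1] because
# log1p(e - 1) == log(e) == 1.
def log_schedule(nsteps: int, inclusive_end: bool = False) -> list[float]:
    denom = nsteps - 1 if inclusive_end else nsteps
    return [math.log1p(i * (math.e - 1.0) / denom) for i in range(nsteps)]

ts = log_schedule(10)
# starts at 0.0, strictly increasing, ends below 1.0 when the end is exclusive
```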
@@ -2681,7 +2713,7 @@ Get x(t) with given time t from noise and data.

- `data` _Tensor_ - target
- `t` _Tensor_ - time
- `noise` _Tensor, optional_ - noise from prior(). Defaults to None.
- `noise` _Tensor, optional_ - noise from prior(). Defaults to None

<a id="mocointerpolantscontinuous_timecontinuousvdmVDMprocess_data_prediction"></a>

@@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -117,7 +117,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -224,6 +224,168 @@
"dts = schedule.discretize(device=DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0000, 0.0010, 0.0020, 0.0030, 0.0040, 0.0050, 0.0060, 0.0070, 0.0080,\n",
" 0.0090, 0.0100, 0.0110, 0.0120, 0.0130, 0.0140, 0.0150, 0.0160, 0.0170,\n",
"        ... (rows elided: values continue in uniform steps of 0.0010 up to 0.9980) ...\n",
" 0.9990], device='cuda:0')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ts"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800,\n",
" 0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700,\n",
" 0.1800, 0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600,\n",
" 0.2700, 0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500,\n",
" 0.3600, 0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400,\n",
" 0.4500, 0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300,\n",
" 0.5400, 0.5500, 0.5600, 0.5700, 0.5800, 0.5900, 0.6000, 0.6100, 0.6200,\n",
" 0.6300, 0.6400, 0.6500, 0.6600, 0.6700, 0.6800, 0.6900, 0.7000, 0.7100,\n",
" 0.7200, 0.7300, 0.7400, 0.7500, 0.7600, 0.7700, 0.7800, 0.7900, 0.8000,\n",
" 0.8100, 0.8200, 0.8300, 0.8400, 0.8500, 0.8600, 0.8700, 0.8800, 0.8900,\n",
" 0.9000, 0.9100, 0.9200, 0.9300, 0.9400, 0.9500, 0.9600, 0.9700, 0.9800,\n",
" 0.9900])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"LinearInferenceSchedule(nsteps = 100, min_t=0, inclusive_end=False).generate_schedule()"
]
},
{
"cell_type": "code",
"execution_count": 36,
