add generic torchmetric ppl logging on esm2 #557

Closed · wants to merge 36 commits

Commits (36)
d6b50ba
transfer torchmetrics changes from pstjohn repo
sichu2023 Dec 12, 2024
d9dd17d
add pp last stage in logging
sichu2023 Dec 12, 2024
a8243ef
add tp-aware update method
sichu2023 Dec 12, 2024
3347076
drop comment on tp-aware normalization
sichu2023 Dec 12, 2024
6e65e75
update comment
sichu2023 Dec 12, 2024
26a1054
fix cp error
sichu2023 Dec 12, 2024
1a106af
drop process_group
sichu2023 Dec 12, 2024
9a760c4
add MegatronPerplexityMetric testing
sichu2023 Dec 16, 2024
29654cd
fix metric device
sichu2023 Dec 17, 2024
125bfe6
fix MegatronPerplexityMetric.update
sichu2023 Dec 17, 2024
0120fb1
clean up test_megatron_perplexity_metric_with_single_microbatch_golde…
sichu2023 Dec 17, 2024
54a1e7c
fix get_random_microbatch
sichu2023 Dec 17, 2024
c811f6c
add variable length microbatch test
sichu2023 Dec 17, 2024
52d11db
ruff
sichu2023 Dec 17, 2024
9d232b2
add back self.log_{train,val}_ppl
sichu2023 Dec 18, 2024
e259783
add back return in {train,validation}_step
sichu2023 Dec 18, 2024
10a085b
add argparse
sichu2023 Dec 18, 2024
635848c
disable async ckpt save
sichu2023 Dec 19, 2024
290fd16
drop pp support
sichu2023 Dec 19, 2024
d1f056c
Revert "update ddp config"
sichu2023 Dec 19, 2024
bdadd28
disable training ppl logging by default
sichu2023 Dec 19, 2024
fb9c270
remove ppl callback
sichu2023 Dec 19, 2024
49238c4
move pp check to train_esm2.py
sichu2023 Dec 19, 2024
048dc52
ruff
sichu2023 Dec 19, 2024
575bac4
switch back to torchmetrics.text.Perplexity
sichu2023 Dec 24, 2024
58e72c3
clean up comments
sichu2023 Dec 27, 2024
952b47b
switch to update in validation epoch logging
sichu2023 Dec 28, 2024
89c2042
squash commits from debug branch
sichu2023 Jan 16, 2025
022cb01
move torchmetric instantiation into init
sichu2023 Jan 16, 2025
20bd098
restrict compute to logging device
sichu2023 Jan 16, 2025
1c7f9a9
Revert "restrict compute to logging device"
sichu2023 Jan 16, 2025
4a3ff2c
skip assigning Metric.process_group
sichu2023 Jan 16, 2025
1012355
add --log-train-ppl warning
sichu2023 Jan 16, 2025
4c28e11
move Perplexity outside of model init
sichu2023 Jan 16, 2025
48fee54
match with torchmetric tutorial
sichu2023 Jan 16, 2025
738224f
Revert "move Perplexity outside of model init"
sichu2023 Jan 16, 2025
train_esm2.py
@@ -18,8 +18,8 @@
from typing import List, Optional, Sequence, get_args

from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning import resume
@@ -31,7 +31,6 @@
from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.llm.lightning import PerplexityLoggingCallback
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
@@ -84,6 +83,8 @@ def main(
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
metric_to_monitor_for_checkpoints: str = "val_loss",
log_train_ppl: bool = False,
log_val_ppl: bool = True,
save_top_k: int = 2,
nsys_profiling: bool = False,
nsys_start_step: int = 0,
@@ -145,6 +146,8 @@
save_best_checkpoint (bool): whether to save the best checkpoint
save_last_checkpoint (bool): whether to save the last checkpoint
metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
log_train_ppl (bool): log training perplexity
log_val_ppl (bool): log validation perplexity
save_top_k (int): number of top checkpoints to save
nsys_profiling (bool): whether to enable nsys profiling
nsys_start_step (int): start step for nsys profiling
@@ -186,7 +189,6 @@ def main(
use_distributed_optimizer=True,
),
find_unused_parameters=True,
gradient_as_bucket_view=True,
ckpt_include_optimizer=True,
ckpt_async_save=True,
ckpt_parallel_load=True,
@@ -211,7 +213,6 @@
)

callbacks = [
PerplexityLoggingCallback(log_train=False, log_val=True),
RichModelSummary(max_depth=4),
LearningRateMonitor(),
nl_callbacks.PreemptionCallback(),
@@ -243,6 +244,7 @@ def main(
autocast_enabled=False,
),
enable_checkpointing=create_checkpoint_callback,
num_sanity_val_steps=0,
)

tokenizer = get_tokenizer()
@@ -281,6 +283,9 @@ def main(
if scheduler_num_steps is None:
scheduler_num_steps = num_steps

if (log_train_ppl or log_val_ppl) and pipeline_model_parallel_size > 1:
raise NotImplementedError("Perplexity logging does not support pipeline parallelism yet.")

model = biobert_lightning_module(
esm2_config,
tokenizer=tokenizer,
@@ -301,6 +306,9 @@ def main(
anneal_percentage=0.10,
),
),
# perplexity logging
log_train_ppl=log_train_ppl,
log_val_ppl=log_val_ppl,
)

# Configure our custom Checkpointer
@@ -384,6 +392,8 @@ def train_esm2_entrypoint():
save_best_checkpoint=args.save_best_checkpoint,
save_last_checkpoint=args.save_last_checkpoint,
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
log_train_ppl=args.log_train_ppl,
log_val_ppl=args.log_val_ppl,
save_top_k=args.save_top_k,
nsys_profiling=args.nsys_profiling,
nsys_start_step=args.nsys_start_step,
@@ -637,6 +647,25 @@ def get_parser():
default="val_loss",
help="The metric to monitor for checkpointing.",
)
parser.add_argument(
"--log-train-ppl",
action="store_true",
default=False,
help="Log perplexity during training. Requires synchronization every training step and hurts performance. Enable only when necessary.",
)
parser.add_argument(
"--log-val-ppl",
action="store_true",
default=False,
help="Log perplexity during validation.",
)
parser.add_argument(
"--no-log-val-ppl",
action="store_false",
dest="log_val_ppl",
default=True,
help="Disable logging perplexity during validation.",
)
parser.add_argument(
"--save-top-k",
type=int,
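The three new flags above follow a standard argparse opt-in / opt-out pattern: training perplexity is off unless requested, while validation perplexity stays on unless explicitly disabled. A self-contained sketch of that pattern (plain argparse with the default attached to the first flag for clarity; this is illustrative only, not the project's actual parser):

```python
# Illustrative sketch of the --log-train-ppl / --log-val-ppl / --no-log-val-ppl trio.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--log-train-ppl", action="store_true", default=False,
                    help="Log training perplexity (forces a sync every training step).")
parser.add_argument("--log-val-ppl", dest="log_val_ppl", action="store_true", default=True,
                    help="Log validation perplexity (on by default in this sketch).")
parser.add_argument("--no-log-val-ppl", dest="log_val_ppl", action="store_false",
                    help="Disable validation perplexity logging.")

print(parser.parse_args([]))                    # Namespace(log_train_ppl=False, log_val_ppl=True)
print(parser.parse_args(["--no-log-val-ppl"]))  # Namespace(log_train_ppl=False, log_val_ppl=False)
print(parser.parse_args(["--log-train-ppl"]))   # Namespace(log_train_ppl=True, log_val_ppl=True)
```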
70 changes: 68 additions & 2 deletions sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -17,6 +17,8 @@

import lightning.pytorch as pl
import torch.distributed
import torchmetrics.text
from torchmetrics.functional.text.perplexity import _perplexity_update
from megatron.core import parallel_state
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from nemo.lightning import io as nlio
@@ -210,6 +212,33 @@ def predict_loss_reduction(self) -> PassthroughLossReduction:
"""


class MegatronPerplexityMetric(torchmetrics.text.Perplexity):
def __init__(self, *args, **kwargs):
if parallel_state.get_context_parallel_world_size() > 1:
raise NotImplementedError(f"{self.__class__} does not support context parallelism yet.")

self.cross_entropy_loss_fusion = kwargs.pop("cross_entropy_loss_fusion", False)
super().__init__(*args, **kwargs)

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
"""Update state with predictions and targets under tensor parallelism."""
unreduced_token_loss = unreduced_token_loss_fn( # TP-aware log prob function
preds.clone().transpose(0, 1).contiguous(),
target.clone(),
cross_entropy_loss_fusion=self.cross_entropy_loss_fusion,
) # (b, s)

if self.ignore_index is not None:
mask = target.ne(self.ignore_index)
target = target.where(target != self.ignore_index, torch.tensor(0, device=target.device))
else:
mask = torch.ones_like(target, dtype=torch.bool)
unreduced_token_loss = unreduced_token_loss[mask]

self.total_log_probs += unreduced_token_loss.sum()
self.count += mask.sum()


class BionemoLightningModule(
Generic[MegatronModelType, MegatronLossType],
pl.LightningModule,
@@ -227,6 +256,8 @@ def __init__(
# TODO: Add transformer_layer_spec when we update mcore
optimizer: MegatronOptimizerModule,
model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
log_train_ppl: bool = False,
log_val_ppl: bool = False,
**model_construct_args,
) -> None:
"""Constructor.
@@ -242,6 +273,8 @@ def __init__(
model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s
`configure_model` method.
model_transform: Optional. The model transform function.
log_train_ppl (bool): Log training perplexity.
log_val_ppl (bool): Log validation perplexity.
**model_construct_args: Optional. Arguments necessary for the supplied model configuration's
`configure_model` method, which will make an instance of the model.
"""
@@ -258,6 +291,10 @@ def __init__(
self._forward_step = forward_step
self.model_transform = model_transform

# torchmetrics must init here for fiddle serialization
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_train_ppl else None
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_val_ppl else None

def configure_model(self) -> None:
"""Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

@@ -273,9 +310,13 @@ def configure_model(self) -> None:
else self.config.configure_model()
)
self.module = model

if self.module is None:
raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

def is_on_logging_device(self):
return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0

def forward(self, *args, **kwargs) -> DataT:
"""Call the forward method of the underlying model, and return whatever it outputs."""
# safe to do because configure_model is idempotent
@@ -304,11 +345,26 @@ def forward_step(self, batch) -> Tensor:

def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1).clone().detach() # [s, b, v] -> [b, s, v]

if self.train_ppl is not None:
if self.is_on_logging_device():
self.train_ppl(logits, batch["labels"])

self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False, prog_bar=True)

return outputs

def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
logits = outputs["token_logits"].transpose(0, 1).clone().detach() # [s, b, v] -> [b, s, v]

if self.valid_ppl is not None and self.is_on_logging_device():
self.valid_ppl.update(logits, batch["labels"])

return outputs

def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""Alias for forward_step."""
@@ -326,6 +382,16 @@ def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
def test_loss_reduction(self) -> MegatronLossType: # noqa: D102
return self.loss_reduction_class(validation_step=True)

def on_validation_epoch_end(self): # noqa: D102
if self.valid_ppl is None:
return

if self.trainer.sanity_checking:
self.valid_ppl.reset() # clean up sanity runs
return

self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True, prog_bar=True)


def default_megatron_optimizer() -> MegatronOptimizerModule:
"""Default distributed optimizer uses Adam with a 1e-4 learning rate."""
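Because MegatronPerplexityMetric only overrides update() with the tensor-parallel-aware token loss, its accumulation and compute() semantics are those of the stock torchmetrics.text.Perplexity it subclasses: a running sum of per-token cross-entropy plus a token count, with compute() returning exp(total_log_probs / count). A minimal single-device sketch, with no Megatron parallelism and shapes assumed to match the module's (batch, seq, vocab) logits:

```python
# Minimal sketch on one device; the Megatron subclass accumulates the same state here.
import torch
from torchmetrics.text import Perplexity

metric = Perplexity(ignore_index=-100)       # same ignore_index the module uses for MLM labels
logits = torch.randn(2, 8, 32)               # (batch, seq, vocab)
labels = torch.randint(0, 32, (2, 8))
labels[:, ::2] = -100                        # ignored positions, e.g. tokens that were never masked
metric.update(logits, labels)                # adds summed token NLL and token count to the state
metric.update(torch.randn(2, 8, 32), torch.randint(0, 32, (2, 8)))  # another microbatch
ppl = metric.compute()                       # exp(mean token NLL) over all scored tokens
metric.reset()                               # clear state, e.g. after a validation epoch
```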
75 changes: 74 additions & 1 deletion sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py
@@ -24,7 +24,7 @@
from torchmetrics.text import Perplexity

from bionemo.llm import lightning as bnptl
from bionemo.llm.lightning import PerplexityLoggingCallback, batch_collator, get_dtype_device
from bionemo.llm.lightning import MegatronPerplexityMetric, PerplexityLoggingCallback, batch_collator, get_dtype_device
from bionemo.testing import megatron_parallel_state_utils
from bionemo.testing.lightning import get_random_microbatch

@@ -186,6 +186,79 @@ def test_mixin_strategy_contract_get_loss_reduction():
assert isinstance(strategy_reduction_function(mixin), bnptl.PassthroughLossReduction)


def test_megatron_perplexity_metric_with_single_microbatch_golden_value_without_parallelism(seed: int = 42):
"""Test PerplexityLoggingCallback with a single microbatch without parallelism"""
with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed):
# setup test input
microbatch_size, max_sequence_length, vocab_size = 1, 1024, 2
microbatch_outputs = [get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed)]

# setup metric
megatron_ppl_metric = MegatronPerplexityMetric(ignore_index=-100).to(torch.cuda.current_device())
metric = Perplexity(ignore_index=-100).to(torch.cuda.current_device())

# compute values
for microbatch_output in microbatch_outputs:
megatron_ppl_metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
ppl_value = megatron_ppl_metric.compute()
ppl_golden_value = metric.compute()

torch.testing.assert_close(
ppl_value,
ppl_golden_value,
)


def test_megatron_perplexity_metric_with_with_variable_length_microbatches_golden_value_without_parallelism(
seed: int = 42,
):
"""Test PerplexityLoggingCallback with a single microbatch without parallelism"""
with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed):
# setup test input
microbatch_size, max_sequence_length, vocab_size = 2, 1024, 2
microbatch_outputs = [
get_random_microbatch(microbatch_size, max_sequence_length // 2, vocab_size, seed),
get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed),
]

# setup metric
megatron_ppl_metric = MegatronPerplexityMetric(ignore_index=-100).to(torch.cuda.current_device())
metric = Perplexity(ignore_index=-100).to(torch.cuda.current_device())

# compute values
for microbatch_output in microbatch_outputs:
megatron_ppl_metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
metric.update(
microbatch_output["forward_out"]["token_logits"]
.transpose(0, 1)
.contiguous(), # (s, b, v) -> (b, s, v)
microbatch_output["batch"]["labels"],
)
ppl_value = megatron_ppl_metric.compute()
ppl_golden_value = metric.compute()

torch.testing.assert_close(
ppl_value,
ppl_golden_value,
)


def test_perplexity_logging_callback_with_single_microbatch_golden_value_without_parallelism(seed: int = 42):
"""Test PerplexityLoggingCallback with a single microbatch without parallelism"""
with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed):
10 changes: 7 additions & 3 deletions sub-packages/bionemo-testing/src/bionemo/testing/lightning.py
@@ -20,7 +20,11 @@


def get_random_microbatch(
microbatch_size: int, max_sequence_length: int, vocab_size: int, seed: int
microbatch_size: int,
max_sequence_length: int,
vocab_size: int,
seed: int,
mask_index: int = -100,
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Generate random microbatches for testing.

@@ -35,7 +39,7 @@ def get_random_microbatch(
device=torch.cuda.current_device(),
) # [b s]
loss_mask = torch.randint(
low=1,
low=0,
high=1 + 1,
size=(microbatch_size, max_sequence_length),
dtype=torch.long,
@@ -45,7 +49,7 @@ def get_random_microbatch(
token_logits = torch.rand(
max_sequence_length, microbatch_size, vocab_size, device=torch.cuda.current_device(), generator=generator
) # [s b v]
labels[loss_mask == 0] = -100 # propagate masking to labels
labels[loss_mask == 0] = mask_index # propagate masking to labels
microbatch_output = {
"batch": {"labels": labels, "loss_mask": loss_mask},
"forward_out": {"token_logits": token_logits},
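To connect the updated test helper to the new metric tests above, a short usage sketch (CUDA assumed, since the helper allocates tensors on the current CUDA device; the sizes are arbitrary):

```python
# Positions where loss_mask == 0 get labels set to mask_index (-100),
# which Perplexity(ignore_index=-100) then excludes from its token count.
from bionemo.testing.lightning import get_random_microbatch

mb = get_random_microbatch(microbatch_size=2, max_sequence_length=8, vocab_size=4, seed=0)
logits = mb["forward_out"]["token_logits"].transpose(0, 1).contiguous()  # (s, b, v) -> (b, s, v)
labels = mb["batch"]["labels"]
assert (labels[mb["batch"]["loss_mask"] == 0] == -100).all()
```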