diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py index 928ff81fa8..105d4bae9f 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py @@ -18,8 +18,8 @@ from typing import List, Optional, Sequence, get_args from lightning.pytorch.callbacks import LearningRateMonitor, RichModelSummary -from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig +from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig from nemo import lightning as nl from nemo.collections import llm from nemo.lightning import resume @@ -31,7 +31,6 @@ from bionemo.esm2.data.datamodule import ESMDataModule from bionemo.esm2.data.dataset import RandomMaskStrategy from bionemo.esm2.data.tokenizer import get_tokenizer -from bionemo.llm.lightning import PerplexityLoggingCallback from bionemo.llm.model.biobert.lightning import biobert_lightning_module from bionemo.llm.model.biobert.model import BiobertSpecOption from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler @@ -84,6 +83,8 @@ def main( save_best_checkpoint: bool = True, save_last_checkpoint: bool = True, metric_to_monitor_for_checkpoints: str = "val_loss", + log_train_ppl: bool = False, + log_val_ppl: bool = True, save_top_k: int = 2, nsys_profiling: bool = False, nsys_start_step: int = 0, @@ -145,6 +146,8 @@ def main( save_best_checkpoint (bool): whether to save the best checkpoint save_last_checkpoint (bool): whether to save the last checkpoint metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints + log_train_ppl (bool): log training perplexity + log_val_ppl (bool): log validation perplexity save_top_k (int): number of top checkpoints to save nsys_profiling (bool): whether to enable nsys profiling nsys_start_step (int): start step for nsys profiling @@ -186,7 +189,6 @@ def main( use_distributed_optimizer=True, ), find_unused_parameters=True, - gradient_as_bucket_view=True, ckpt_include_optimizer=True, ckpt_async_save=True, ckpt_parallel_load=True, @@ -211,7 +213,6 @@ def main( ) callbacks = [ - PerplexityLoggingCallback(log_train=False, log_val=True), RichModelSummary(max_depth=4), LearningRateMonitor(), nl_callbacks.PreemptionCallback(), @@ -243,6 +244,7 @@ def main( autocast_enabled=False, ), enable_checkpointing=create_checkpoint_callback, + num_sanity_val_steps=0, ) tokenizer = get_tokenizer() @@ -281,6 +283,9 @@ def main( if scheduler_num_steps is None: scheduler_num_steps = num_steps + if (log_train_ppl or log_val_ppl) and pipeline_model_parallel_size > 1: + raise NotImplementedError("Perplexity logging does not support pipeline parallelism yet.") + model = biobert_lightning_module( esm2_config, tokenizer=tokenizer, @@ -301,6 +306,9 @@ def main( anneal_percentage=0.10, ), ), + # perplexity logging + log_train_ppl=log_train_ppl, + log_val_ppl=log_val_ppl, ) # Configure our custom Checkpointer @@ -384,6 +392,8 @@ def train_esm2_entrypoint(): save_best_checkpoint=args.save_best_checkpoint, save_last_checkpoint=args.save_last_checkpoint, metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints, + log_train_ppl=args.log_train_ppl, + log_val_ppl=args.log_val_ppl, save_top_k=args.save_top_k, nsys_profiling=args.nsys_profiling, nsys_start_step=args.nsys_start_step, @@ -637,6 +647,25 @@ def get_parser(): default="val_loss", help="The metric to monitor for checkpointing.", ) + parser.add_argument( + "--log-train-ppl", + action="store_true", + default=False, + help="Log perplexity during training. Requires synchronization every training step and hurts performance. Enable only when necessary.", + ) + parser.add_argument( + "--log-val-ppl", + action="store_true", + default=False, + help="Log perplexity during validation.", + ) + parser.add_argument( + "--no-log-val-ppl", + action="store_false", + dest="log_val_ppl", + default=True, + help="Disable logging perplexity during validation.", + ) parser.add_argument( "--save-top-k", type=int, diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py index baf5e72bde..e4b88b9b00 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py @@ -17,6 +17,8 @@ import lightning.pytorch as pl import torch.distributed +import torchmetrics.text +from torchmetrics.functional.text.perplexity import _perplexity_update from megatron.core import parallel_state from megatron.core.optimizer.optimizer_config import OptimizerConfig from nemo.lightning import io as nlio @@ -210,6 +212,33 @@ def predict_loss_reduction(self) -> PassthroughLossReduction: """ +class MegatronPerplexityMetric(torchmetrics.text.Perplexity): + def __init__(self, *args, **kwargs): + if parallel_state.get_context_parallel_world_size() > 1: + raise NotImplementedError(f"{self.__class__} does not support context parallelism yet.") + + self.cross_entropy_loss_fusion = kwargs.pop("cross_entropy_loss_fusion", False) + super().__init__(*args, **kwargs) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with predictions and targets under tensor parallelism.""" + unreduced_token_loss = unreduced_token_loss_fn( # TP-aware log prob function + preds.clone().transpose(0, 1).contiguous(), + target.clone(), + cross_entropy_loss_fusion=self.cross_entropy_loss_fusion, + ) # (b, s) + + if self.ignore_index is not None: + mask = target.ne(self.ignore_index) + target = target.where(target != self.ignore_index, torch.tensor(0, device=target.device)) + else: + mask = torch.ones_like(target, dtype=torch.bool) + unreduced_token_loss = unreduced_token_loss[mask] + + self.total_log_probs += unreduced_token_loss.sum() + self.count += mask.sum() + + class BionemoLightningModule( Generic[MegatronModelType, MegatronLossType], pl.LightningModule, @@ -227,6 +256,8 @@ def __init__( # TODO: Add transformer_layer_spec when we update mcore optimizer: MegatronOptimizerModule, model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None, + log_train_ppl: bool = False, + log_val_ppl: bool = False, **model_construct_args, ) -> None: """Constructor. @@ -242,6 +273,8 @@ def __init__( model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s `configure_model` method. model_transform: Optional. The model transform function. + log_train_ppl (bool): Log training perplexity. + log_val_ppl (bool): Log validation perplexity. **model_construct_args: Optional. Arguments necessary for the supplied model configuration's `configure_model` method, which will make an instance of the model. """ @@ -258,6 +291,10 @@ def __init__( self._forward_step = forward_step self.model_transform = model_transform + # torchmetrics must init here for fiddle serialization + self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_train_ppl else None + self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_val_ppl else None + def configure_model(self) -> None: """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute. @@ -273,9 +310,13 @@ def configure_model(self) -> None: else self.config.configure_model() ) self.module = model + if self.module is None: raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.") + def is_on_logging_device(self): + return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0 + def forward(self, *args, **kwargs) -> DataT: """Call the forward method of the underlying model, and return whatever it outputs.""" # safe to do because configure_model is idempotent @@ -304,11 +345,26 @@ def forward_step(self, batch) -> Tensor: def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor: """In mcore the loss-function is part of the forward-pass when labels are provided.""" - return self.forward_step(batch) + outputs = self.forward_step(batch) + logits = outputs["token_logits"].transpose(0, 1).clone().detach() # [s, b, v] -> [b, s, v] + + if self.train_ppl is not None: + if self.is_on_logging_device(): + self.train_ppl(logits, batch["labels"]) + + self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False, prog_bar=True) + + return outputs def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor: """In mcore the loss-function is part of the forward-pass when labels are provided.""" - return self.forward_step(batch) + outputs = self.forward_step(batch) + logits = outputs["token_logits"].transpose(0, 1).clone().detach() # [s, b, v] -> [b, s, v] + + if self.valid_ppl is not None and self.is_on_logging_device(): + self.valid_ppl.update(logits, batch["labels"]) + + return outputs def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor: """Alias for forward_step.""" @@ -326,6 +382,16 @@ def validation_loss_reduction(self) -> MegatronLossType: # noqa: D102 def test_loss_reduction(self) -> MegatronLossType: # noqa: D102 return self.loss_reduction_class(validation_step=True) + def on_validation_epoch_end(self): # noqa: D102 + if self.valid_ppl is None: + return + + if self.trainer.sanity_checking: + self.valid_ppl.reset() # clean up sanity runs + return + + self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True, prog_bar=True) + def default_megatron_optimizer() -> MegatronOptimizerModule: """Default distributed optimizer uses Adam with a 1e-4 learning rate.""" diff --git a/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py b/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py index 802cbeb94e..9ea52a1cc1 100644 --- a/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py +++ b/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py @@ -24,7 +24,7 @@ from torchmetrics.text import Perplexity from bionemo.llm import lightning as bnptl -from bionemo.llm.lightning import PerplexityLoggingCallback, batch_collator, get_dtype_device +from bionemo.llm.lightning import MegatronPerplexityMetric, PerplexityLoggingCallback, batch_collator, get_dtype_device from bionemo.testing import megatron_parallel_state_utils from bionemo.testing.lightning import get_random_microbatch @@ -186,6 +186,79 @@ def test_mixin_strategy_contract_get_loss_reduction(): assert isinstance(strategy_reduction_function(mixin), bnptl.PassthroughLossReduction) +def test_megatron_perplexity_metric_with_single_microbatch_golden_value_without_parallelism(seed: int = 42): + """Test PerplexityLoggingCallback with a single microbatch without parallelism""" + with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed): + # setup test input + microbatch_size, max_sequence_length, vocab_size = 1, 1024, 2 + microbatch_outputs = [get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed)] + + # setup metric + megatron_ppl_metric = MegatronPerplexityMetric(ignore_index=-100).to(torch.cuda.current_device()) + metric = Perplexity(ignore_index=-100).to(torch.cuda.current_device()) + + # compute values + for microbatch_output in microbatch_outputs: + megatron_ppl_metric.update( + microbatch_output["forward_out"]["token_logits"] + .transpose(0, 1) + .contiguous(), # (s, b, v) -> (b, s, v) + microbatch_output["batch"]["labels"], + ) + metric.update( + microbatch_output["forward_out"]["token_logits"] + .transpose(0, 1) + .contiguous(), # (s, b, v) -> (b, s, v) + microbatch_output["batch"]["labels"], + ) + ppl_value = megatron_ppl_metric.compute() + ppl_golden_value = metric.compute() + + torch.testing.assert_close( + ppl_value, + ppl_golden_value, + ) + + +def test_megatron_perplexity_metric_with_with_variable_length_microbatches_golden_value_without_parallelism( + seed: int = 42, +): + """Test PerplexityLoggingCallback with a single microbatch without parallelism""" + with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed): + # setup test input + microbatch_size, max_sequence_length, vocab_size = 2, 1024, 2 + microbatch_outputs = [ + get_random_microbatch(microbatch_size, max_sequence_length // 2, vocab_size, seed), + get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed), + ] + + # setup metric + megatron_ppl_metric = MegatronPerplexityMetric(ignore_index=-100).to(torch.cuda.current_device()) + metric = Perplexity(ignore_index=-100).to(torch.cuda.current_device()) + + # compute values + for microbatch_output in microbatch_outputs: + megatron_ppl_metric.update( + microbatch_output["forward_out"]["token_logits"] + .transpose(0, 1) + .contiguous(), # (s, b, v) -> (b, s, v) + microbatch_output["batch"]["labels"], + ) + metric.update( + microbatch_output["forward_out"]["token_logits"] + .transpose(0, 1) + .contiguous(), # (s, b, v) -> (b, s, v) + microbatch_output["batch"]["labels"], + ) + ppl_value = megatron_ppl_metric.compute() + ppl_golden_value = metric.compute() + + torch.testing.assert_close( + ppl_value, + ppl_golden_value, + ) + + def test_perplexity_logging_callback_with_single_microbatch_golden_value_without_parallelism(seed: int = 42): """Test PerplexityLoggingCallback with a single microbatch without parallelism""" with megatron_parallel_state_utils.distributed_model_parallel_state(seed=seed): diff --git a/sub-packages/bionemo-testing/src/bionemo/testing/lightning.py b/sub-packages/bionemo-testing/src/bionemo/testing/lightning.py index 67acf2b945..010098c4e4 100644 --- a/sub-packages/bionemo-testing/src/bionemo/testing/lightning.py +++ b/sub-packages/bionemo-testing/src/bionemo/testing/lightning.py @@ -20,7 +20,11 @@ def get_random_microbatch( - microbatch_size: int, max_sequence_length: int, vocab_size: int, seed: int + microbatch_size: int, + max_sequence_length: int, + vocab_size: int, + seed: int, + mask_index: int = -100, ) -> Dict[str, Dict[str, torch.Tensor]]: """Generate random microbatches for testing. @@ -35,7 +39,7 @@ def get_random_microbatch( device=torch.cuda.current_device(), ) # [b s] loss_mask = torch.randint( - low=1, + low=0, high=1 + 1, size=(microbatch_size, max_sequence_length), dtype=torch.long, @@ -45,7 +49,7 @@ def get_random_microbatch( token_logits = torch.rand( max_sequence_length, microbatch_size, vocab_size, device=torch.cuda.current_device(), generator=generator ) # [s b v] - labels[loss_mask == 0] = -100 # propagate masking to labels + labels[loss_mask == 0] = mask_index # propagate masking to labels microbatch_output = { "batch": {"labels": labels, "loss_mask": loss_mask}, "forward_out": {"token_logits": token_logits},