fix perplexity logging
Signed-off-by: sichu <[email protected]>
sichu2023 committed Jan 17, 2025
1 parent 7f9dd97 commit e7191c8
Showing 4 changed files with 81 additions and 268 deletions.
33 changes: 31 additions & 2 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py
@@ -31,7 +31,6 @@
from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.llm.lightning import PerplexityLoggingCallback
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
@@ -84,6 +83,8 @@ def main(
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
metric_to_monitor_for_checkpoints: str = "val_loss",
log_train_ppl: bool = False,
log_val_ppl: bool = True,
save_top_k: int = 2,
nsys_profiling: bool = False,
nsys_start_step: int = 0,
@@ -145,6 +146,8 @@ def main(
save_best_checkpoint (bool): whether to save the best checkpoint
save_last_checkpoint (bool): whether to save the last checkpoint
metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
log_train_ppl (bool): log training perplexity
log_val_ppl (bool): log validation perplexity
save_top_k (int): number of top checkpoints to save
nsys_profiling (bool): whether to enable nsys profiling
nsys_start_step (int): start step for nsys profiling
@@ -211,7 +214,6 @@ def main(
)

callbacks = [
PerplexityLoggingCallback(log_train=False, log_val=True),
RichModelSummary(max_depth=4),
LearningRateMonitor(),
nl_callbacks.PreemptionCallback(),
@@ -281,6 +283,9 @@ def main(
if scheduler_num_steps is None:
scheduler_num_steps = num_steps

if (log_train_ppl or log_val_ppl) and pipeline_model_parallel_size > 1:
raise NotImplementedError("Perplexity logging does not support pipeline parallelism yet.")

model = biobert_lightning_module(
esm2_config,
tokenizer=tokenizer,
@@ -301,6 +306,9 @@ def main(
anneal_percentage=0.10,
),
),
# perplexity logging
log_train_ppl=log_train_ppl,
log_val_ppl=log_val_ppl,
)

# Configure our custom Checkpointer
@@ -384,6 +392,8 @@ def train_esm2_entrypoint():
save_best_checkpoint=args.save_best_checkpoint,
save_last_checkpoint=args.save_last_checkpoint,
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
log_train_ppl=args.log_train_ppl,
log_val_ppl=args.log_val_ppl,
save_top_k=args.save_top_k,
nsys_profiling=args.nsys_profiling,
nsys_start_step=args.nsys_start_step,
@@ -637,6 +647,25 @@ def get_parser():
default="val_loss",
help="The metric to monitor for checkpointing.",
)
parser.add_argument(
"--log-train-ppl",
action="store_true",
default=False,
help="Log perplexity during training. Requires synchronization every training step and hurts performance. Enable only when necessary.",
)
parser.add_argument(
"--log-val-ppl",
action="store_true",
default=False,
help="Log perplexity during validation.",
)
parser.add_argument(
"--no-log-val-ppl",
action="store_false",
dest="log_val_ppl",
default=True,
help="Disable logging perplexity during validation.",
)
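The three flags above leave training perplexity opt-in, while validation perplexity can be switched off with --no-log-val-ppl. As a point of comparison, the following is a minimal, self-contained sketch of the same paired-flag idea, not the repository's parser; it uses argparse.BooleanOptionalAction (Python 3.9+), which generates the --no- variant automatically and keeps a single explicit default.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--log-train-ppl",
    action="store_true",
    default=False,
    help="Opt in to training perplexity; it costs a synchronization every step.",
)
parser.add_argument(
    "--log-val-ppl",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Validation perplexity; pass --no-log-val-ppl to disable.",
)

print(parser.parse_args([]))                    # Namespace(log_train_ppl=False, log_val_ppl=True)
print(parser.parse_args(["--no-log-val-ppl"]))  # log_val_ppl=False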
parser.add_argument(
"--save-top-k",
type=int,
155 changes: 42 additions & 113 deletions sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -13,35 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Callable, Generic, Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union

import lightning.pytorch as pl
import torch.distributed
import torchmetrics.text
from megatron.core import parallel_state
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from nemo.lightning import io as nlio
from nemo.lightning.megatron_parallel import (
CallbackMethods,
DataT,
MegatronLossReduction,
MegatronStep,
ReductionT,
)
from nemo.lightning.pytorch.optim import MegatronOptimizerModule
from torch import Tensor
from typing_extensions import override

from bionemo.core.model.config import BionemoTrainableModelConfig
from bionemo.llm.api import MegatronLossType, MegatronModelType
from bionemo.llm.model.loss import unreduced_token_loss_fn


__all__: Sequence[str] = (
"get_dtype_device",
"batch_collator",
"PassthroughLossReduction",
"LightningPassthroughPredictionMixin",
"PerplexityLoggingCallback",
"BionemoLightningModule",
"default_megatron_optimizer",
)
@@ -227,6 +223,8 @@ def __init__(
# TODO: Add transformer_layer_spec when we update mcore
optimizer: MegatronOptimizerModule,
model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
log_train_ppl: bool = False,
log_val_ppl: bool = False,
**model_construct_args,
) -> None:
"""Constructor.
@@ -242,6 +240,8 @@
model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s
`configure_model` method.
model_transform: Optional. The model transform function.
log_train_ppl (bool): Log training perplexity.
log_val_ppl (bool): Log validation perplexity.
**model_construct_args: Optional. Arguments necessary for the supplied model configuration's
`configure_model` method, which will make an instance of the model.
"""
@@ -258,6 +258,10 @@ def __init__(
self._forward_step = forward_step
self.model_transform = model_transform

# torchmetrics must init here for fiddle serialization
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_train_ppl else None
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_val_ppl else None
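Both metrics are built with ignore_index=-100, so label positions that are not scored (padding, or tokens not selected for the masked-LM loss) drop out of the running average. A small standalone sketch of that behavior with made-up shapes, not repository code:

import torch
import torchmetrics.text

ppl = torchmetrics.text.Perplexity(ignore_index=-100)

logits = torch.randn(2, 8, 32)          # [batch, seq, vocab]
labels = torch.randint(0, 32, (2, 8))   # [batch, seq]
labels[:, 4:] = -100                    # ignored positions contribute nothing

ppl.update(logits, labels)
print(ppl.compute())                    # exp(mean NLL over the non-ignored positions)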

def configure_model(self) -> None:
"""Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.
@@ -273,9 +277,14 @@ def configure_model(self) -> None:
else self.config.configure_model()
)
self.module = model

if self.module is None:
raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

def is_on_logging_device(self):
"""Return True if last stage of pipeline parallel and first tensor parallel rank."""
return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0

def forward(self, *args, **kwargs) -> DataT:
"""Call the forward method of the underlying model, and return whatever it outputs."""
# safe to do because configure_model is idempotent
@@ -304,11 +313,26 @@ def forward_step(self, batch) -> Tensor:

def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
logits = outputs["token_logits"].detach().transpose(0, 1).clone() # [s, b, v] -> [b, s, v]

if self.train_ppl is not None:
if self.is_on_logging_device():
self.train_ppl(logits, batch["labels"])

self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False, prog_bar=True)

return outputs

def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
logits = outputs["token_logits"].detach().transpose(0, 1).clone() # [s, b, v] -> [b, s, v]

if self.valid_ppl is not None and self.is_on_logging_device():
self.valid_ppl.update(logits, batch["labels"])

return outputs
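Both steps transpose token_logits before updating the metric: the Megatron-style forward output is sequence-first, [s, b, v], while torchmetrics.text.Perplexity expects batch-first [b, s, v]. A tiny shape-only sketch with assumed sizes:

import torch

token_logits = torch.randn(128, 4, 32)                        # [s, b, v] as returned by the forward pass
batch_first = token_logits.detach().transpose(0, 1).clone()   # detached copy in [b, s, v]
assert batch_first.shape == (4, 128, 32)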

def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""Alias for forward_step."""
@@ -326,114 +350,19 @@ def validation_loss_reduction(self) -> MegatronLossType: # noqa: D102
def test_loss_reduction(self) -> MegatronLossType: # noqa: D102
return self.loss_reduction_class(validation_step=True)

def on_validation_epoch_end(self): # noqa: D102
if self.valid_ppl is None:
return

if self.trainer.sanity_checking:
self.valid_ppl.reset() # clean up sanity runs
return

self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True, prog_bar=True)


def default_megatron_optimizer() -> MegatronOptimizerModule:
"""Default distributed optimizer uses Adam with a 1e-4 learning rate."""
return MegatronOptimizerModule(
config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
)


class PerplexityLoggingCallback(pl.Callback, CallbackMethods):
"""Megatron Callback to log perplexity in validation and optionally training.
NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.
"""

def __init__(self, log_train: bool = False, log_val: bool = True):
"""Initialize PerplexityLoggingCallback.
Args:
log_train: whether to log train perplexity. Defaults to False.
log_val: whether to log validation perplexity. Defaults to True.
"""
super().__init__()
self.log_train = log_train
self.log_val = log_val

def _pad_to_max_length(
self,
microbatch_outputs: List[Dict[str, Dict[str, Tensor]]],
key1: str,
key2: str,
pad_value: int = 0,
seq_dim: int = 1,
batch_dim: int = 0,
) -> Tensor:
"""Pad tensors to max length in microbatch_outputs."""
assert seq_dim != batch_dim, "Forgot to set one of seq_dim, batch_dim, they are equal!"
max_sequence_length: int = max(output[key1][key2].shape[seq_dim] for output in microbatch_outputs)

tensors: List[Tensor] = []
for microbatch_output in microbatch_outputs:
tensor = microbatch_output[key1][key2]
assert (
tensor.dim() >= 2
), f"Tensor in microbatch_outputs must have at least 2 dimensions, but got {tensor.dim()} dimensions"
pad_size = [(0, 0)] * tensor.dim()
pad_size[seq_dim] = (0, max_sequence_length - tensor.shape[seq_dim])
# Flatten pad size list for F.pad
pad_size_flat = [item for sublist in reversed(pad_size) for item in sublist]
tensors.append(
torch.nn.functional.pad( # padding reverse in order
tensor,
pad_size_flat,
mode="constant",
value=pad_value,
)
)

return torch.cat(tensors, dim=batch_dim) # concat on batch dim

@override
def on_megatron_reduce_microbatches_end(
self,
step: MegatronStep,
microbatch_outputs: List[Any],
loss_reduction: MegatronLossReduction,
reduced: Tensor | dict[str, Tensor],
) -> None:
"""Log after MegatronReductionLoss.reduce is called.
Expected microbatch_outputs to be a list of dicts with the following keys:
- batch: dict of tensors with the following keys:
- labels: [b s]
- loss_mask: [b s]; 1 means included 0 means ignored
- forward_out: dict of tensors with the following keys:
- token_logits: [b s vocab]
"""
if step.trainer.sanity_checking: # skip sanity check
return

if step.trainer.training and not self.log_train:
return

if not parallel_state.is_pipeline_last_stage():
return

assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
assert (
len(microbatch_outputs) == step.num_microbatches
), "microbatch_outputs length does not match num_microbatches"
labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
token_logits = self._pad_to_max_length(
microbatch_outputs, "forward_out", "token_logits", seq_dim=0, batch_dim=1
)

unreduced_token_loss = unreduced_token_loss_fn(
token_logits.clone(), # [s,b] as expected unreduced_token_loss_fn has inplace operation on token_logits
labels.clone(), # [b,s] as expected
) # [b s] is the return

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
else:
raise NotImplementedError("Context parallel perplexity logging is not supported yet")

if self.log_val and not step.trainer.training:
step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
elif self.log_train and step.trainer.training:
step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)