fix perplexity logging #622

Open · wants to merge 7 commits into main
@@ -31,7 +31,6 @@
from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.llm.lightning import PerplexityLoggingCallback
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler
@@ -84,6 +83,8 @@ def main(
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
metric_to_monitor_for_checkpoints: str = "val_loss",
log_train_ppl: bool = False,
log_val_ppl: bool = True,
save_top_k: int = 2,
nsys_profiling: bool = False,
nsys_start_step: int = 0,
@@ -145,6 +146,8 @@ def main(
save_best_checkpoint (bool): whether to save the best checkpoint
save_last_checkpoint (bool): whether to save the last checkpoint
metric_to_monitor_for_checkpoints (str): metric to monitor for checkpoints
log_train_ppl (bool): log training perplexity
log_val_ppl (bool): log validation perplexity
save_top_k (int): number of top checkpoints to save
nsys_profiling (bool): whether to enable nsys profiling
nsys_start_step (int): start step for nsys profiling
@@ -211,7 +214,6 @@ def main(
)

callbacks = [
PerplexityLoggingCallback(log_train=False, log_val=True),
RichModelSummary(max_depth=4),
LearningRateMonitor(),
nl_callbacks.PreemptionCallback(),
@@ -301,6 +303,9 @@ def main(
anneal_percentage=0.10,
),
),
# perplexity logging
log_train_ppl=log_train_ppl,
log_val_ppl=log_val_ppl,
)

# Configure our custom Checkpointer
@@ -384,6 +389,8 @@ def train_esm2_entrypoint():
save_best_checkpoint=args.save_best_checkpoint,
save_last_checkpoint=args.save_last_checkpoint,
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
log_train_ppl=args.log_train_ppl,
log_val_ppl=args.log_val_ppl,
save_top_k=args.save_top_k,
nsys_profiling=args.nsys_profiling,
nsys_start_step=args.nsys_start_step,
@@ -637,6 +644,25 @@ def get_parser():
default="val_loss",
help="The metric to monitor for checkpointing.",
)
parser.add_argument(
"--log-train-ppl",
action="store_true",
default=False,
help="Log perplexity during training. Requires synchronization every training step and hurts performance. Enable only when necessary.",
)
parser.add_argument(
"--log-val-ppl",
action="store_true",
default=True,
help="Log perplexity during validation.",
)
parser.add_argument(
"--no-log-val-ppl",
action="store_false",
dest="log_val_ppl",
default=True,
help="Disable logging perplexity during validation.",
)
parser.add_argument(
"--save-top-k",
type=int,
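The flags above use the standard argparse on/off pair: --log-val-ppl and --no-log-val-ppl both write to the same dest, and argparse applies the default of the first-registered action, which is why --log-val-ppl carries default=True. A minimal, self-contained sketch of that idiom (flag names taken from the parser above; everything else is illustrative):

# Standalone sketch of the paired on/off flags used above. Both options share
# dest="log_val_ppl"; argparse applies the default of the first-registered action,
# so validation perplexity logging stays on unless --no-log-val-ppl is passed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--log-val-ppl", action="store_true", default=True, dest="log_val_ppl")
parser.add_argument("--no-log-val-ppl", action="store_false", dest="log_val_ppl")

print(parser.parse_args([]).log_val_ppl)  # True (default)
print(parser.parse_args(["--no-log-val-ppl"]).log_val_ppl)  # False (explicitly disabled)
print(parser.parse_args(["--log-val-ppl"]).log_val_ppl)  # True (explicitly enabled)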
14 changes: 1 addition & 13 deletions sub-packages/bionemo-geneformer/src/bionemo/geneformer/api.py
@@ -35,18 +35,6 @@
GeneformerModel = MegatronBioBertModel


class BERTMLMLossWithReductionNoForward(BERTMLMLossWithReduction):
def __init__(
self,
validation_step: bool = False,
val_drop_last: bool = True,
send_train_output: bool = False,
send_val_output: bool = False,
) -> None:
"""Same as BERTMLMLossWithReduction but set send_val_output=False by default since we do not use perplexity."""
super().__init__(validation_step, val_drop_last, send_train_output, send_val_output)


@dataclass
class GeneformerConfig(BioBertConfig[GeneformerModel, MegatronLossType], iom.IOMixinWithGettersSetters):
"""A geneformer config.
@@ -88,4 +76,4 @@ class GeneformerConfig(BioBertConfig[GeneformerModel, MegatronLossType], iom.IOM

enable_autocast: bool = False
model_cls: Type[GeneformerModel] = GeneformerModel
loss_reduction_class: Type[MegatronLossType] = BERTMLMLossWithReductionNoForward
loss_reduction_class: Type[MegatronLossType] = BERTMLMLossWithReduction
Expand Up @@ -144,10 +144,7 @@ def forward(
loss_for_microbatch = loss_for_microbatch + rmse_loss # add in the RMSE loss after reducing the logit loss
# average the losses across the data parallel group, but also return the unreduced loss
reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
else:
return loss_for_microbatch * cp_size, {"avg": reduced_loss}
return loss_for_microbatch * cp_size, {"avg": reduced_loss}


class MegatronRegressionMLPHead(MegatronModule):
155 changes: 42 additions & 113 deletions sub-packages/bionemo-llm/src/bionemo/llm/lightning.py
@@ -13,35 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Callable, Generic, Iterable, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union

import lightning.pytorch as pl
import torch.distributed
import torchmetrics.text
from megatron.core import parallel_state
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from nemo.lightning import io as nlio
from nemo.lightning.megatron_parallel import (
CallbackMethods,
DataT,
MegatronLossReduction,
MegatronStep,
ReductionT,
)
from nemo.lightning.pytorch.optim import MegatronOptimizerModule
from torch import Tensor
from typing_extensions import override

from bionemo.core.model.config import BionemoTrainableModelConfig
from bionemo.llm.api import MegatronLossType, MegatronModelType
from bionemo.llm.model.loss import unreduced_token_loss_fn


__all__: Sequence[str] = (
"get_dtype_device",
"batch_collator",
"PassthroughLossReduction",
"LightningPassthroughPredictionMixin",
"PerplexityLoggingCallback",
"BionemoLightningModule",
"default_megatron_optimizer",
)
@@ -227,6 +223,8 @@ def __init__(
# TODO: Add transformer_layer_spec when we update mcore
optimizer: MegatronOptimizerModule,
model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
log_train_ppl: bool = False,
log_val_ppl: bool = False,
**model_construct_args,
) -> None:
"""Constructor.
@@ -242,6 +240,8 @@
model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s
`configure_model` method.
model_transform: Optional. The model transform function.
log_train_ppl (bool): Log training perplexity.
log_val_ppl (bool): Log validation perplexity.
**model_construct_args: Optional. Arguments necessary for the supplied model configuration's
`configure_model` method, which will make an instance of the model.
"""
@@ -258,6 +258,10 @@
self._forward_step = forward_step
self.model_transform = model_transform

# torchmetrics must init here for fiddle serialization
self.train_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_train_ppl else None
self.valid_ppl = torchmetrics.text.Perplexity(ignore_index=-100) if log_val_ppl else None

def configure_model(self) -> None:
"""Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.

@@ -273,9 +277,14 @@ def configure_model(self) -> None:
else self.config.configure_model()
)
self.module = model

if self.module is None:
raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")

def is_on_logging_device(self):
"""Return True if last stage of pipeline parallel and first tensor parallel rank."""
return parallel_state.is_pipeline_last_stage() and parallel_state.get_tensor_model_parallel_rank() == 0

def forward(self, *args, **kwargs) -> DataT:
"""Call the forward method of the underlying model, and return whatever it outputs."""
# safe to do because configure_model is idempotent
@@ -304,11 +313,26 @@ def forward_step(self, batch) -> Tensor:

def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
logits = outputs["token_logits"].detach().transpose(0, 1).clone() # [s, b, v] -> [b, s, v]

if self.train_ppl is not None:
if self.is_on_logging_device():
self.train_ppl(logits, batch["labels"])

self.log("train_ppl", self.train_ppl, on_step=True, on_epoch=False, prog_bar=True)

return outputs

def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""In mcore the loss-function is part of the forward-pass when labels are provided."""
return self.forward_step(batch)
outputs = self.forward_step(batch)
logits = outputs["token_logits"].detach().transpose(0, 1).clone() # [s, b, v] -> [b, s, v]

if self.valid_ppl is not None and self.is_on_logging_device():
self.valid_ppl.update(logits, batch["labels"])

return outputs

def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
"""Alias for forward_step."""
@@ -326,114 +350,19 @@ def validation_loss_reduction(self) -> MegatronLossType: # noqa: D102
def test_loss_reduction(self) -> MegatronLossType: # noqa: D102
return self.loss_reduction_class(validation_step=True)

def on_validation_epoch_end(self): # noqa: D102
if self.valid_ppl is None:
return

if self.trainer.sanity_checking:
self.valid_ppl.reset() # clean up sanity runs
return

self.log("valid_ppl", self.valid_ppl, on_step=False, on_epoch=True, prog_bar=True)


def default_megatron_optimizer() -> MegatronOptimizerModule:
"""Default distributed optimizer uses Adam with a 1e-4 learning rate."""
return MegatronOptimizerModule(
config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
)


class PerplexityLoggingCallback(pl.Callback, CallbackMethods):
"""Megatron Callback to log perplexity in validation and optionally training.

NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.
"""

def __init__(self, log_train: bool = False, log_val: bool = True):
"""Initialize PerplexityLoggingCallback.

Args:
log_train: whether to log train perplexity. Defaults to False.
log_val: whether to log validation perplexity. Defaults to True.
"""
super().__init__()
self.log_train = log_train
self.log_val = log_val

def _pad_to_max_length(
self,
microbatch_outputs: List[Dict[str, Dict[str, Tensor]]],
key1: str,
key2: str,
pad_value: int = 0,
seq_dim: int = 1,
batch_dim: int = 0,
) -> Tensor:
"""Pad tensors to max length in microbatch_outputs."""
assert seq_dim != batch_dim, "Forgot to set one of seq_dim, batch_dim, they are equal!"
max_sequence_length: int = max(output[key1][key2].shape[seq_dim] for output in microbatch_outputs)

tensors: List[Tensor] = []
for microbatch_output in microbatch_outputs:
tensor = microbatch_output[key1][key2]
assert (
tensor.dim() >= 2
), f"Tensor in microbatch_outputs must have at least 2 dimensions, but got {tensor.dim()} dimensions"
pad_size = [(0, 0)] * tensor.dim()
pad_size[seq_dim] = (0, max_sequence_length - tensor.shape[seq_dim])
# Flatten pad size list for F.pad
pad_size_flat = [item for sublist in reversed(pad_size) for item in sublist]
tensors.append(
torch.nn.functional.pad( # padding reverse in order
tensor,
pad_size_flat,
mode="constant",
value=pad_value,
)
)

return torch.cat(tensors, dim=batch_dim) # concat on batch dim

@override
def on_megatron_reduce_microbatches_end(
self,
step: MegatronStep,
microbatch_outputs: List[Any],
loss_reduction: MegatronLossReduction,
reduced: Tensor | dict[str, Tensor],
) -> None:
"""Log after MegatronReductionLoss.reduce is called.

Expected microbatch_outputs to be a list of dicts with the following keys:
- batch: dict of tensors with the following keys:
- labels: [b s]
- loss_mask: [b s]; 1 means included 0 means ignored
- forward_out: dict of tensors with the following keys:
- token_logits: [b s vocab]
"""
if step.trainer.sanity_checking: # skip sanity check
return

if step.trainer.training and not self.log_train:
return

if not parallel_state.is_pipeline_last_stage():
return

assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
assert (
len(microbatch_outputs) == step.num_microbatches
), "microbatch_outputs length does not match num_microbatches"
labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
token_logits = self._pad_to_max_length(
microbatch_outputs, "forward_out", "token_logits", seq_dim=0, batch_dim=1
)

unreduced_token_loss = unreduced_token_loss_fn(
token_logits.clone(), # [s,b] as expected unreduced_token_loss_fn has inplace operation on token_logits
labels.clone(), # [b,s] as expected
) # [b s] is the return

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
else:
raise NotImplementedError("Context parallel perplexity logging is not supported yet")

if self.log_val and not step.trainer.training:
step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
elif self.log_train and step.trainer.training:
step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)