diff --git a/src/fairseq2/models/mbart/__init__.py b/src/fairseq2/models/mbart/__init__.py
new file mode 100644
index 000000000..a04c1984f
--- /dev/null
+++ b/src/fairseq2/models/mbart/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq2.models.mbart.builder import create_mbart_model as create_mbart_model
+from fairseq2.models.mbart.builder import mbart_arch as mbart_arch
+from fairseq2.models.mbart.builder import mbart_archs as mbart_archs
+from fairseq2.models.mbart.builder import mBartBuilder as mBartBuilder
+from fairseq2.models.mbart.builder import mBartConfig as mBartConfig
+from fairseq2.models.mbart.loader import load_mbart_model as load_mbart_model
+from fairseq2.models.mbart.loader import load_mbart_tokenizer as load_mbart_tokenizer
+from fairseq2.models.mbart.loader import mBartLoader as mBartLoader
+from fairseq2.models.mbart.tokenizer import mBartTokenizer as mBartTokenizer
diff --git a/src/fairseq2/models/mbart/builder.py b/src/fairseq2/models/mbart/builder.py
new file mode 100644
index 000000000..30b13c89c
--- /dev/null
+++ b/src/fairseq2/models/mbart/builder.py
@@ -0,0 +1,291 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Literal, Optional
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.transformer import (
+    TransformerEmbeddingFrontend,
+    TransformerFrontend,
+    TransformerModel,
+)
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.position_encoder import (
+    LearnedPositionEncoder,
+    SinusoidalPositionEncoder,
+)
+from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
+    MultiheadAttention,
+    StandardFeedForwardNetwork,
+    StandardMultiheadAttention,
+    StandardTransformerDecoder,
+    StandardTransformerDecoderLayer,
+    StandardTransformerEncoder,
+    StandardTransformerEncoderLayer,
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+
+
+@dataclass
+class mBartConfig:
+    """Holds the configuration of an mBart model."""
+
+    model_dim: int
+    """The dimensionality of the model."""
+
+    max_seq_len: int
+    """The expected maximum sequence length."""
+
+    vocabulary_size: int
+    """The size of the vocabulary."""
+
+    pad_idx: Optional[int]
+    """The index of the pad symbol in the vocabulary."""
+
+    num_encoder_layers: int
+    """The number of Transformer encoder layers."""
+
+    num_decoder_layers: int
+    """The number of Transformer decoder layers."""
+
+    num_encoder_attn_heads: int
+    """The number of attention heads in Transformer encoder layers."""
+
+    num_decoder_attn_heads: int
+    """The number of attention heads in Transformer decoder layers."""
+
+    ffn_inner_dim: int
+    """The inner dimensionality of Transformer feed-forward networks."""
+
+    # Position Encoder
+    pos_encoder_type: Literal["sinusoidal", "learned"]
+    """The type of position encoder."""
+
+    layer_norm_embed: bool
+    """Adds a layernorm to the embedding in the Transformer encoder."""
+
+    dropout_p: float
"""The dropout probability in Transformer layers.""" + + def update_vocabulary(self, info: VocabularyInfo) -> None: + """Update vocabulary configuration from ``info``.""" + self.vocabulary_size, self.pad_idx = info.size, info.pad_idx + + +mbart_archs = ArchitectureRegistry[mBartConfig]("mbart") + + +mbart_arch = mbart_archs.marker + + +@mbart_arch("base") +def _base() -> mBartConfig: + return mBartConfig( + model_dim=1024, + max_seq_len=1026, + vocabulary_size=65539, + pad_idx=0, + num_encoder_layers=12, + num_decoder_layers=12, + num_encoder_attn_heads=16, + num_decoder_attn_heads=16, + ffn_inner_dim=4096, + pos_encoder_type="learned", + layer_norm_embed=True, + dropout_p=0.1, + ) + + +class mBartBuilder: + """Builds modules of an mBart model as described in + :cite:t:`https://arxiv.org/abs/2001.08210`. + + To tweak the architecture, you can derive from this class and override the + corresponding methods. + """ + + config: mBartConfig + device: Optional[Device] + dtype: Optional[DataType] + + def __init__( + self, + config: mBartConfig, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + self.config = config + self.device = device + self.dtype = dtype + + def build_model(self) -> TransformerModel: + """Build a model.""" + embed = self.build_embedding() + + frontend = self.build_frontend(embed) + + encoder = self.build_encoder() + decoder = self.build_decoder() + + final_proj = TiedProjection(embed.weight) + + return TransformerModel( + frontend, encoder, frontend, decoder, final_proj, self.config.pad_idx + ) + + def build_embedding(self) -> Embedding: + """Build an embedding table.""" + return Embedding( + num_embeddings=self.config.vocabulary_size, + embedding_dim=self.config.model_dim, + pad_idx=self.config.pad_idx, + scaled=True, + device=self.device, + dtype=self.dtype, + ) + + def build_frontend(self, embed: Embedding) -> TransformerFrontend: + """Build a Transformer encoder/decoder front-end.""" + if self.config.pos_encoder_type == "sinusoidal": + pos_encoder = SinusoidalPositionEncoder( + self.config.model_dim, + self.config.max_seq_len, + _legacy_pad_idx=self.config.pad_idx, + device=self.device, + dtype=self.dtype, + ) + else: + pos_encoder = LearnedPositionEncoder( + self.config.model_dim, + self.config.max_seq_len, + device=self.device, + dtype=self.dtype, + ) + + return TransformerEmbeddingFrontend( + embed, + pos_encoder, + layer_norm=self.config.layer_norm_embed, + dropout_p=self.config.dropout_p, + device=self.device, + dtype=self.dtype, + ) + + def build_encoder(self) -> TransformerEncoder: + """Build a Transformer encoder.""" + num_layers = self.config.num_encoder_layers + + layers = [self.build_encoder_layer() for _ in range(num_layers)] + + return StandardTransformerEncoder( + layers, + norm_order=TransformerNormOrder.PRE, + device=self.device, + dtype=self.dtype, + ) + + def build_decoder(self) -> TransformerDecoder: + """Build a Transformer decoder.""" + num_layers = self.config.num_decoder_layers + + layers = [self.build_decoder_layer() for _ in range(num_layers)] + + return StandardTransformerDecoder( + layers, + norm_order=TransformerNormOrder.PRE, + device=self.device, + dtype=self.dtype, + ) + + def build_encoder_layer(self) -> TransformerEncoderLayer: + """Build a Transformer encoder layer.""" + self_attn = 
+        self_attn = self.build_attention(self.config.num_encoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerEncoderLayer(
+            self_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_decoder_layer(self) -> TransformerDecoderLayer:
+        """Build a Transformer decoder layer."""
+        self_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads)
+
+        ffn = self.build_ffn()
+
+        return StandardTransformerDecoderLayer(
+            self_attn,
+            encoder_decoder_attn,
+            ffn,
+            dropout_p=self.config.dropout_p,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_ffn(self) -> FeedForwardNetwork:
+        """Build a Transformer feed-forward network."""
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_mbart_model(
+    config: mBartConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> TransformerModel:
+    """Create an mBart model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+    return mBartBuilder(config, device, dtype).build_model()
diff --git a/src/fairseq2/models/mbart/loader.py b/src/fairseq2/models/mbart/loader.py
new file mode 100644
index 000000000..efa85e4c1
--- /dev/null
+++ b/src/fairseq2/models/mbart/loader.py
@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Mapping, Union, final
+
+import torch
+
+from fairseq2.assets import (
+    AssetCard,
+    AssetDownloadManager,
+    AssetStore,
+    asset_store,
+    download_manager,
+)
+from fairseq2.models.mbart.builder import create_mbart_model, mbart_archs, mBartConfig
+from fairseq2.models.mbart.tokenizer import mBartTokenizer
+from fairseq2.models.transformer import TransformerModel
+from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
+from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
+from fairseq2.typing import finaloverride
+
+
+@final
+class mBartLoader(ModelLoader[TransformerModel, mBartConfig]):
+    """Loads mBart models."""
+
+    @finaloverride
+    def _upgrade_checkpoint(
+        self, checkpoint: Mapping[str, Any], config: mBartConfig
+    ) -> Mapping[str, Any]:
+        state_dict = checkpoint["model"]
+
+        # Check if we have a fairseq2 checkpoint.
+        if "decoder_frontend.embed.weight" in state_dict:
+            return checkpoint
+
+        key_map = self._fairseq_key_map()
+
+        # Convert to fairseq2.
+        checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map)
+
+        state_dict = checkpoint["model"]
+
+        embeds = state_dict["final_proj.weight"]
+
+        # fairseq checkpoints have duplicate embedding weights. Ensure that we
+        # use a single embedding table in fairseq2.
+ state_dict["encoder_frontend.embed.weight"] = embeds + state_dict["decoder_frontend.embed.weight"] = embeds + + # The embedding positions of the control symbols in fairseq's dict do + # not match the SentencePiece model of the tokenizer. + with torch.inference_mode(): + # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS) + embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]] + + return checkpoint + + @staticmethod + def _fairseq_key_map() -> Dict[str, str]: + return { + # fmt: off + r"^encoder\.embed_tokens\.": r"encoder_frontend.embed.", + r"^encoder\.embed_positions\.": r"encoder_frontend.pos_encoder.", + r"^encoder\.layernorm_embedding\.": r"encoder_frontend.layer_norm.", + r"^decoder\.embed_tokens\.": r"decoder_frontend.embed.", + r"^decoder\.embed_positions\.": r"decoder_frontend.pos_encoder.", + r"^decoder\.layernorm_embedding\.": r"decoder_frontend.layer_norm.", + r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.", + r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder.layers.\1.self_attn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"decoder.layers.\1.encoder_decoder_attn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"decoder.layers.\1.encoder_decoder_attn.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"decoder.layers.\1.encoder_decoder_attn_layer_norm.", + r"^encoder\.layers\.([0-9]+)\.fc1\.": r"encoder.layers.\1.ffn.inner_proj.", + r"^decoder\.layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.", + r"^encoder\.layers\.([0-9]+)\.fc2\.": r"encoder.layers.\1.ffn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.fc2\.": r"decoder.layers.\1.ffn.output_proj.", + r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.ffn_layer_norm.", + r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.", + r"^decoder\.output_projection\.": r"final_proj.", + # fmt: on + } + + +load_mbart_model = mBartLoader( + asset_store, download_manager, create_mbart_model, mbart_archs +) + + +load_mbart_config = ModelConfigLoader[mBartConfig](asset_store, mbart_archs) + + +class mBartTokenizerLoader: + """Loads tokenizers of mBart models.""" + + def __init__( + self, asset_store: AssetStore, download_manager: AssetDownloadManager + ) -> None: + """ + :param asset_store: + The asset store to retrieve the model information. + :param download_manager: + The download manager to use. + """ + self.asset_store = asset_store + self.download_manager = download_manager + + def __call__( + self, + model_name_or_card: Union[str, AssetCard], + force: bool = False, + progress: bool = True, + ) -> mBartTokenizer: + """ + :param model_name_or_card: + The name or asset card of the model whose tokenizer to load. + :param force: + If ``True``, downloads the tokenizer even if it is already in cache. + :param progress: + If ``True``, displays a progress bar to stderr. 
+ """ + if isinstance(model_name_or_card, AssetCard): + card: AssetCard = model_name_or_card + else: + card = self.asset_store.retrieve_card(model_name_or_card) + + uri = card.field("tokenizer").as_uri() + + pathname = self.download_manager.download_tokenizer( + uri, card.name, force=force, progress=progress + ) + + langs = card.field("langs").as_list(str) + + default_lang = card.field("default_lang").as_(str) + + return mBartTokenizer(pathname, langs, default_lang) + + +load_mbart_tokenizer = mBartTokenizerLoader(asset_store, download_manager) diff --git a/src/fairseq2/models/mbart/tokenizer.py b/src/fairseq2/models/mbart/tokenizer.py new file mode 100644 index 000000000..ac6087e18 --- /dev/null +++ b/src/fairseq2/models/mbart/tokenizer.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Sequence, Set, final + +from fairseq2.data.text import ( + SentencePieceDecoder, + SentencePieceEncoder, + SentencePieceModel, + TextTokenDecoder, + TextTokenEncoder, + TextTokenizer, + vocabulary_from_sentencepiece, +) +from fairseq2.data.typing import PathLike +from fairseq2.typing import Device, finaloverride + + +@final +class mBartTokenizer(TextTokenizer): + """Represents the tokenizer used by mBart models.""" + + model: SentencePieceModel + langs: Set[str] + default_lang: str + + def __init__( + self, pathname: PathLike, langs: Sequence[str], default_lang: str + ) -> None: + """ + :param pathname: + The pathname of the SentencePiece model file. + :param langs: + The list of supported languages. + :param default_lang: + The fall-back language if no language is specified. + """ + # Each language is represented by a `[lang]` control symbol. + control_symbols = [f"[{lang}]" for lang in langs] + + control_symbols.append("") + + self.model = SentencePieceModel(pathname, control_symbols) + + self.langs = set(langs) + + self.default_lang = default_lang + + vocab_info = vocabulary_from_sentencepiece(self.model) + + super().__init__(vocab_info) + + @finaloverride + def create_encoder( + self, + task: Optional[str] = None, + lang: Optional[str] = None, + mode: Optional[str] = None, + device: Optional[Device] = None, + pin_memory: bool = False, + ) -> TextTokenEncoder: + """Create a token encoder. + + :param task: + Must be 'translation'. If ``None``, defaults to 'translation'. + :param lang: + A language from :attr:`langs`. If ``None``, defaults to + :attr:`default_lang`. + :param mode: + Must be 'source' or 'target'. Set to 'source' if ``lang`` is the + source language; set to 'target' if ``lang`` is the target language. + If ``None``, defaults to 'source'. + :param device: + The device on which to construct tensors. + :param pin_memory: + If ``True``, uses pinned memory while constructing tensors. + """ + if task is not None and task != "translation": + raise ValueError(f"`task` must be 'translation', but is '{task}' instead.") + + if lang is None: + lang = self.default_lang + + if lang not in self.langs: + raise ValueError( + f"`lang` must be a supported language, but is '{lang}' instead." + ) + + if mode is None or mode == "source": + prefix_tokens = [""] + suffix_tokens = ["", f"[{lang}]"] + elif mode == "target": + prefix_tokens = [f"[{lang}]", ""] + suffix_tokens = [""] + else: + raise ValueError( + f"`mode` must be 'source' or 'target', but is '{mode}' instead." 
+            )
+
+        return SentencePieceEncoder(
+            self.model,
+            prefix_tokens=prefix_tokens,
+            suffix_tokens=suffix_tokens,
+            device=device,
+            pin_memory=pin_memory,
+        )
+
+    @finaloverride
+    def create_decoder(self) -> TextTokenDecoder:
+        return SentencePieceDecoder(self.model)
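
Usage sketch (not part of the patch): a minimal example of the API this diff adds, assuming an asset card for an mBart checkpoint is registered. The card name "mbart_base" and the language code "en_XX" are placeholders; the real values come from the asset card and its "langs" / "default_lang" fields.

    from fairseq2.models.mbart import load_mbart_model, load_mbart_tokenizer

    # Both loaders resolve an asset card, download the checkpoint or the
    # SentencePiece model if needed, and construct the fairseq2 objects.
    model = load_mbart_model("mbart_base")          # placeholder card name
    tokenizer = load_mbart_tokenizer("mbart_base")  # placeholder card name

    # Source-side encoder: wraps the text with the control symbols for the
    # given language and mode; the decoder maps indices back to text.
    encoder = tokenizer.create_encoder(task="translation", lang="en_XX", mode="source")
    decoder = tokenizer.create_decoder()

    indices = encoder("Hello, world!")  # tensor of token indices
    text = decoder(indices)             # round-trip back to text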