From 023e30cba495f419b422f50545a1010cecf348c3 Mon Sep 17 00:00:00 2001 From: Kaushik Ram Sadagopan Date: Sat, 9 Sep 2023 15:26:56 -0700 Subject: [PATCH 1/4] Load a base mbart model and implement its text tokenizer. --- src/fairseq2/models/mbart/__init__.py | 15 ++ src/fairseq2/models/mbart/builder.py | 290 +++++++++++++++++++++++++ src/fairseq2/models/mbart/loader.py | 150 +++++++++++++ src/fairseq2/models/mbart/tokenizer.py | 113 ++++++++++ 4 files changed, 568 insertions(+) create mode 100644 src/fairseq2/models/mbart/__init__.py create mode 100644 src/fairseq2/models/mbart/builder.py create mode 100644 src/fairseq2/models/mbart/loader.py create mode 100644 src/fairseq2/models/mbart/tokenizer.py diff --git a/src/fairseq2/models/mbart/__init__.py b/src/fairseq2/models/mbart/__init__.py new file mode 100644 index 000000000..a04c1984f --- /dev/null +++ b/src/fairseq2/models/mbart/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq2.models.mbart.builder import create_mbart_model as create_mbart_model +from fairseq2.models.mbart.builder import mbart_arch as mbart_arch +from fairseq2.models.mbart.builder import mbart_archs as mbart_archs +from fairseq2.models.mbart.builder import mBartBuilder as mBartBuilder +from fairseq2.models.mbart.builder import mBartConfig as mBartConfig +from fairseq2.models.mbart.loader import load_mbart_model as load_mbart_model +from fairseq2.models.mbart.loader import load_mbart_tokenizer as load_mbart_tokenizer +from fairseq2.models.mbart.loader import mBartLoader as mBartLoader +from fairseq2.models.mbart.tokenizer import mBartTokenizer as mBartTokenizer diff --git a/src/fairseq2/models/mbart/builder.py b/src/fairseq2/models/mbart/builder.py new file mode 100644 index 000000000..ee83ce8d9 --- /dev/null +++ b/src/fairseq2/models/mbart/builder.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import Literal, Optional + +from fairseq2.data import VocabularyInfo +from fairseq2.models.transformer import ( + TransformerEmbeddingFrontend, + TransformerFrontend, + TransformerModel, +) +from fairseq2.models.utils.arch_registry import ArchitectureRegistry +from fairseq2.nn.embedding import Embedding +from fairseq2.nn.position_encoder import ( + LearnedPositionEncoder, + SinusoidalPositionEncoder, +) +from fairseq2.nn.projection import TiedProjection +from fairseq2.nn.transformer import ( + FeedForwardNetwork, + MultiheadAttention, + StandardFeedForwardNetwork, + StandardMultiheadAttention, + StandardTransformerDecoder, + StandardTransformerDecoderLayer, + StandardTransformerEncoder, + StandardTransformerEncoderLayer, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, + TransformerNormOrder, + create_default_sdpa, +) +from fairseq2.typing import DataType, Device + + +@dataclass +class mBartConfig: + """Holds the configuration of an mBart model.""" + + model_dim: int + """The dimensionality of the model.""" + + max_seq_len: int + """The expected maximum sequence length.""" + + vocabulary_size: int + """The size of the vocabulary.""" + + pad_idx: Optional[int] + """The index of the pad symbol in the vocabulary.""" + + num_encoder_layers: int + """The number of Transformer encoder layers.""" + + num_decoder_layers: int + """The number of Transformer decoder layers.""" + + num_encoder_attn_heads: int + """The number of attention heads in Transformer encoder layers.""" + + num_decoder_attn_heads: int + """The number of attention heads in Transformer decoder layers.""" + + ffn_inner_dim: int + """The inner dimensionality of Transformer feed-forward networks.""" + + # Position Encoder + pos_encoder_type: Literal["sinusoidal", "learned"] + """The type of position encoder.""" + + frontend_layernorm: bool + + dropout_p: float + """The dropout probability in Transformer layers.""" + + def update_vocabulary(self, info: VocabularyInfo) -> None: + """Update vocabulary configuration from ``info``.""" + self.vocabulary_size, self.pad_idx = info.size, info.pad_idx + + +mbart_archs = ArchitectureRegistry[mBartConfig]("mbart") + + +mbart_arch = mbart_archs.marker + + +@mbart_arch("base") +def _base() -> mBartConfig: + return mBartConfig( + model_dim=1024, + max_seq_len=1026, + vocabulary_size=65539, + pad_idx=0, + num_encoder_layers=12, + num_decoder_layers=12, + num_encoder_attn_heads=16, + num_decoder_attn_heads=16, + ffn_inner_dim=4096, + pos_encoder_type="learned", + frontend_layernorm=True, + dropout_p=0.1, + ) + + +class mBartBuilder: + """Builds modules of an mBart model as described in + :cite:t:`https://arxiv.org/abs/2001.08210`. + + To tweak the architecture, you can derive from this class and override the + corresponding methods. + """ + + config: mBartConfig + device: Optional[Device] + dtype: Optional[DataType] + + def __init__( + self, + config: mBartConfig, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, + ) -> None: + """ + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. 
+ """ + self.config = config + self.device = device + self.dtype = dtype + + def build_model(self) -> TransformerModel: + """Build a model.""" + embed = self.build_embedding() + + frontend = self.build_frontend(embed) + + encoder = self.build_encoder() + decoder = self.build_decoder() + + final_proj = TiedProjection(embed.weight) + + return TransformerModel( + frontend, encoder, frontend, decoder, final_proj, self.config.pad_idx + ) + + def build_embedding(self) -> Embedding: + """Build an embedding table.""" + return Embedding( + num_embeddings=self.config.vocabulary_size, + embedding_dim=self.config.model_dim, + pad_idx=self.config.pad_idx, + scaled=True, + device=self.device, + dtype=self.dtype, + ) + + def build_frontend(self, embed: Embedding) -> TransformerFrontend: + """Build a Transformer encoder/decoder front-end.""" + if self.config.pos_encoder_type == "sinusoidal": + pos_encoder = SinusoidalPositionEncoder( + self.config.model_dim, + self.config.max_seq_len, + _legacy_pad_idx=self.config.pad_idx, + device=self.device, + dtype=self.dtype, + ) + else: + pos_encoder = LearnedPositionEncoder( + self.config.model_dim, + self.config.max_seq_len, + device=self.device, + dtype=self.dtype, + ) + + return TransformerEmbeddingFrontend( + embed, + pos_encoder, + layer_norm=self.config.frontend_layernorm, + dropout_p=self.config.dropout_p, + device=self.device, + dtype=self.dtype, + ) + + def build_encoder(self) -> TransformerEncoder: + """Build a Transformer encoder.""" + num_layers = self.config.num_encoder_layers + + layers = [self.build_encoder_layer() for _ in range(num_layers)] + + return StandardTransformerEncoder( + layers, + norm_order=TransformerNormOrder.PRE, + device=self.device, + dtype=self.dtype, + ) + + def build_decoder(self) -> TransformerDecoder: + """Build a Transformer decoder.""" + num_layers = self.config.num_decoder_layers + + layers = [self.build_decoder_layer() for _ in range(num_layers)] + + return StandardTransformerDecoder( + layers, + norm_order=TransformerNormOrder.PRE, + device=self.device, + dtype=self.dtype, + ) + + def build_encoder_layer(self) -> TransformerEncoderLayer: + """Build a Transformer encoder layer.""" + self_attn = self.build_attention(self.config.num_encoder_attn_heads) + + ffn = self.build_ffn() + + return StandardTransformerEncoderLayer( + self_attn, + ffn, + dropout_p=self.config.dropout_p, + norm_order=TransformerNormOrder.PRE, + device=self.device, + dtype=self.dtype, + ) + + def build_decoder_layer(self) -> TransformerDecoderLayer: + """Build a Transformer decoder layer.""" + self_attn = self.build_attention(self.config.num_decoder_attn_heads) + + encoder_decoder_attn = self.build_attention(self.config.num_decoder_attn_heads) + + ffn = self.build_ffn() + + return StandardTransformerDecoderLayer( + self_attn, + encoder_decoder_attn, + ffn, + dropout_p=self.config.dropout_p, + norm_order=TransformerNormOrder.PRE, + device=self.device, + dtype=self.dtype, + ) + + def build_attention(self, num_heads: int) -> MultiheadAttention: + """Build a Transformer multi-head attention layer.""" + sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p) + + return StandardMultiheadAttention( + self.config.model_dim, + num_heads, + sdpa=sdpa, + device=self.device, + dtype=self.dtype, + ) + + def build_ffn(self) -> FeedForwardNetwork: + """Build a Transformer feed-forward network.""" + return StandardFeedForwardNetwork( + self.config.model_dim, + self.config.ffn_inner_dim, + norm_order=TransformerNormOrder.PRE, + device=self.device, + 
dtype=self.dtype, + ) + + +def create_mbart_model( + config: mBartConfig, + device: Optional[Device] = None, + dtype: Optional[DataType] = None, +) -> TransformerModel: + """Create an mBart model. + + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + return mBartBuilder(config, device, dtype).build_model() diff --git a/src/fairseq2/models/mbart/loader.py b/src/fairseq2/models/mbart/loader.py new file mode 100644 index 000000000..84cb94d3b --- /dev/null +++ b/src/fairseq2/models/mbart/loader.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Mapping, Union, final + +import torch + +from fairseq2.assets import ( + AssetCard, + AssetDownloadManager, + AssetStore, + asset_store, + download_manager, +) +from fairseq2.models.mbart.builder import create_mbart_model, mbart_archs, mBartConfig +from fairseq2.models.mbart.tokenizer import mBartTokenizer +from fairseq2.models.transformer import TransformerModel +from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint +from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader +from fairseq2.typing import finaloverride + + +@final +class mBartLoader(ModelLoader[TransformerModel, mBartConfig]): + """Loads mBart models.""" + + @finaloverride + def _upgrade_checkpoint( + self, checkpoint: Mapping[str, Any], config: mBartConfig + ) -> Mapping[str, Any]: + state_dict = checkpoint["model"] + + # Check if we have a fairseq2 checkpoint. + if "decoder_frontend.embed_weight" in state_dict: + return checkpoint + + key_map = self._fairseq_key_map() + + # Convert to fairseq2. + checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map) + + state_dict = checkpoint["model"] + + embeds = state_dict["final_proj.weight"] + + # fairseq had a bug that accidentally introduced a dummy token in the + # embedding table of NLLB-100. We just discard it. + if embeds.size(0) == 256103: # means NLLB-100 + embeds = embeds[:-1] + + state_dict["final_proj.weight"] = embeds + + # fairseq checkpoints have duplicate embedding weights. Ensure that we + # use a single embedding table in fairseq2. + state_dict["encoder_frontend.embed.weight"] = embeds + state_dict["decoder_frontend.embed.weight"] = embeds + + # The embedding positions of the control symbols in fairseq's dict do + # not match the SentencePiece model of the tokenizer. 
+ with torch.inference_mode(): + # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS) + embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]] + + return checkpoint + + @staticmethod + def _fairseq_key_map() -> Dict[str, str]: + return { + # fmt: off + r"^encoder\.embed_tokens\.": r"encoder_frontend.embed.", + r"^encoder\.embed_positions\.": r"encoder_frontend.pos_encoder.", + r"^encoder\.layernorm_embedding\.": r"encoder_frontend.layer_norm.", + r"^decoder\.embed_tokens\.": r"decoder_frontend.embed.", + r"^decoder\.embed_positions\.": r"decoder_frontend.pos_encoder.", + r"^decoder\.layernorm_embedding\.": r"decoder_frontend.layer_norm.", + r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.", + r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder.layers.\1.self_attn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": r"decoder.layers.\1.encoder_decoder_attn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"decoder.layers.\1.encoder_decoder_attn.", + r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"decoder.layers.\1.encoder_decoder_attn_layer_norm.", + r"^encoder\.layers\.([0-9]+)\.fc1\.": r"encoder.layers.\1.ffn.inner_proj.", + r"^decoder\.layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.", + r"^encoder\.layers\.([0-9]+)\.fc2\.": r"encoder.layers.\1.ffn.output_proj.", + r"^decoder\.layers\.([0-9]+)\.fc2\.": r"decoder.layers.\1.ffn.output_proj.", + r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.ffn_layer_norm.", + r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.", + r"^decoder\.output_projection\.": r"final_proj.", + # fmt: on + } + + +load_mbart_model = mBartLoader( + asset_store, download_manager, create_mbart_model, mbart_archs +) + + +load_mbart_config = ModelConfigLoader[mBartConfig](asset_store, mbart_archs) + + +class mBartTokenizerLoader: + """Loads tokenizers of mBart models.""" + + def __init__( + self, asset_store: AssetStore, download_manager: AssetDownloadManager + ) -> None: + """ + :param asset_store: + The asset store to retrieve the model information. + :param download_manager: + The download manager to use. + """ + self.asset_store = asset_store + self.download_manager = download_manager + + def __call__( + self, + model_name_or_card: Union[str, AssetCard], + force: bool = False, + progress: bool = True, + ) -> mBartTokenizer: + """ + :param model_name_or_card: + The name or asset card of the model whose tokenizer to load. + :param force: + If ``True``, downloads the tokenizer even if it is already in cache. + :param progress: + If ``True``, displays a progress bar to stderr. + """ + if isinstance(model_name_or_card, AssetCard): + card: AssetCard = model_name_or_card + else: + card = self.asset_store.retrieve_card(model_name_or_card) + + uri = card.field("tokenizer").as_uri() + + pathname = self.download_manager.download_tokenizer( + uri, card.name, force=force, progress=progress + ) + + langs = card.field("langs").as_list(str) + + default_lang = card.field("default_lang").as_(str) + + return mBartTokenizer(pathname, langs, default_lang) + + +load_mbart_tokenizer = mBartTokenizerLoader(asset_store, download_manager) diff --git a/src/fairseq2/models/mbart/tokenizer.py b/src/fairseq2/models/mbart/tokenizer.py new file mode 100644 index 000000000..a63b027d7 --- /dev/null +++ b/src/fairseq2/models/mbart/tokenizer.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Sequence, Set, final
+
+from fairseq2.data.text import (
+    SentencePieceDecoder,
+    SentencePieceEncoder,
+    SentencePieceModel,
+    TextTokenDecoder,
+    TextTokenEncoder,
+    TextTokenizer,
+    vocabulary_from_sentencepiece,
+)
+from fairseq2.data.typing import PathLike
+from fairseq2.typing import Device, finaloverride
+
+
+@final
+class mBartTokenizer(TextTokenizer):
+    """Represents the tokenizer used by mBart models."""
+
+    model: SentencePieceModel
+    langs: Set[str]
+    default_lang: str
+
+    def __init__(
+        self, pathname: PathLike, langs: Sequence[str], default_lang: str
+    ) -> None:
+        """
+        :param pathname:
+            The pathname of the SentencePiece model file.
+        :param langs:
+            The list of supported languages.
+        :param default_lang:
+            The fall-back language if no language is specified.
+        """
+        # Each language is represented by a `[lang_XX]` control symbol.
+        control_symbols = [f"[{lang}_XX]" for lang in langs]
+
+        control_symbols.append("<mask>")
+
+        self.model = SentencePieceModel(pathname, control_symbols)
+
+        self.langs = set(langs)
+
+        self.default_lang = default_lang
+
+        vocab_info = vocabulary_from_sentencepiece(self.model)
+
+        super().__init__(vocab_info)
+
+    @finaloverride
+    def create_encoder(
+        self,
+        task: Optional[str] = None,
+        lang: Optional[str] = None,
+        mode: Optional[str] = None,
+        device: Optional[Device] = None,
+        pin_memory: bool = False,
+    ) -> TextTokenEncoder:
+        """Create a token encoder.
+
+        :param task:
+            Must be 'translation'. If ``None``, defaults to 'translation'.
+        :param lang:
+            A language from :attr:`langs`. If ``None``, defaults to
+            :attr:`default_lang`.
+        :param mode:
+            Must be 'source' or 'target'. Set to 'source' if ``lang`` is the
+            source language; set to 'target' if ``lang`` is the target language.
+            If ``None``, defaults to 'source'.
+        :param device:
+            The device on which to construct tensors.
+        :param pin_memory:
+            If ``True``, uses pinned memory while constructing tensors.
+        """
+        if task is not None and task != "translation":
+            raise ValueError(f"`task` must be 'translation', but is '{task}' instead.")
+
+        if lang is None:
+            lang = self.default_lang
+
+        if lang not in self.langs:
+            raise ValueError(
+                f"`lang` must be a supported language, but is '{lang}' instead."
+            )
+
+        if mode is None or mode == "source":
+            prefix_tokens = ["<s>"]
+            suffix_tokens = ["</s>", f"__{lang}__"]
+        elif mode == "target":
+            prefix_tokens = [f"__{lang}__", "<s>"]
+            suffix_tokens = ["</s>"]
+        else:
+            raise ValueError(
+                f"`mode` must be 'source' or 'target', but is '{mode}' instead."
+            )
+
+        return SentencePieceEncoder(
+            self.model,
+            prefix_tokens=prefix_tokens,
+            suffix_tokens=suffix_tokens,
+            device=device,
+            pin_memory=pin_memory,
+        )
+
+    @finaloverride
+    def create_decoder(self) -> TextTokenDecoder:
+        return SentencePieceDecoder(self.model)

From 0eb23284310b34562e30ef31b7327d925a80b7b7 Mon Sep 17 00:00:00 2001
From: Kaushik Ram Sadagopan
Date: Sun, 10 Sep 2023 19:47:22 -0700
Subject: [PATCH 2/4] Fixing bug in lang control symbol in prefix, suffix tokens.
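
The tokenizer registers its language control symbols as `[lang_XX]` in
`__init__`, but `create_encoder` emitted `__lang__` tokens, so the language
tag never matched a registered symbol. The prefix/suffix tokens now use the
bracketed form. A minimal sketch of the resulting behaviour, assuming a
hypothetical SentencePiece model path and the language code "en":

    # "mbart.model" and "en" are placeholders for illustration only.
    tokenizer = mBartTokenizer("mbart.model", langs=["en"], default_lang="en")

    # Source text is now suffixed with the registered "[en_XX]" symbol instead
    # of the unregistered "__en__" token; target text is prefixed with it.
    src_encoder = tokenizer.create_encoder(lang="en", mode="source")
    tgt_encoder = tokenizer.create_encoder(lang="en", mode="target")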
---
 src/fairseq2/models/mbart/builder.py   | 15 ++++++++++-----
 src/fairseq2/models/mbart/tokenizer.py |  4 ++--
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/fairseq2/models/mbart/builder.py b/src/fairseq2/models/mbart/builder.py
index ee83ce8d9..59c3028a9 100644
--- a/src/fairseq2/models/mbart/builder.py
+++ b/src/fairseq2/models/mbart/builder.py
@@ -75,10 +75,14 @@ class mBartConfig:
     """The type of position encoder."""
 
     frontend_layernorm: bool
+    """Whether to add the layernorm in the encoder, decoder frontend."""
 
     dropout_p: float
     """The dropout probability in Transformer layers."""
 
+    norm_order: TransformerNormOrder
+    """The Layer Normalization order."""
+
     def update_vocabulary(self, info: VocabularyInfo) -> None:
         """Update vocabulary configuration from ``info``."""
         self.vocabulary_size, self.pad_idx = info.size, info.pad_idx
@@ -105,6 +109,7 @@ def _base() -> mBartConfig:
         pos_encoder_type="learned",
         frontend_layernorm=True,
         dropout_p=0.1,
+        norm_order=TransformerNormOrder.POST,
     )
 
 
@@ -199,7 +204,7 @@ def build_encoder(self) -> TransformerEncoder:
 
         return StandardTransformerEncoder(
             layers,
-            norm_order=TransformerNormOrder.PRE,
+            norm_order=self.config.norm_order,
             device=self.device,
             dtype=self.dtype,
         )
@@ -212,7 +217,7 @@ def build_decoder(self) -> TransformerDecoder:
 
         return StandardTransformerDecoder(
             layers,
-            norm_order=TransformerNormOrder.PRE,
+            norm_order=self.config.norm_order,
             device=self.device,
             dtype=self.dtype,
         )
@@ -227,7 +232,7 @@ def build_encoder_layer(self) -> TransformerEncoderLayer:
             self_attn,
             ffn,
             dropout_p=self.config.dropout_p,
-            norm_order=TransformerNormOrder.PRE,
+            norm_order=self.config.norm_order,
             device=self.device,
             dtype=self.dtype,
         )
@@ -245,7 +250,7 @@ def build_decoder_layer(self) -> TransformerDecoderLayer:
             encoder_decoder_attn,
             ffn,
             dropout_p=self.config.dropout_p,
-            norm_order=TransformerNormOrder.PRE,
+            norm_order=self.config.norm_order,
             device=self.device,
             dtype=self.dtype,
         )
@@ -267,7 +272,7 @@ def build_ffn(self) -> FeedForwardNetwork:
         return StandardFeedForwardNetwork(
             self.config.model_dim,
             self.config.ffn_inner_dim,
-            norm_order=TransformerNormOrder.PRE,
+            norm_order=self.config.norm_order,
             device=self.device,
             dtype=self.dtype,
         )
diff --git a/src/fairseq2/models/mbart/tokenizer.py b/src/fairseq2/models/mbart/tokenizer.py
index a63b027d7..1a6e184f1 100644
--- a/src/fairseq2/models/mbart/tokenizer.py
+++ b/src/fairseq2/models/mbart/tokenizer.py
@@ -91,9 +91,9 @@ def create_encoder(
 
         if mode is None or mode == "source":
             prefix_tokens = ["<s>"]
-            suffix_tokens = ["</s>", f"__{lang}__"]
+            suffix_tokens = ["</s>", f"[{lang}_XX]"]
         elif mode == "target":
-            prefix_tokens = [f"__{lang}__", "<s>"]
+            prefix_tokens = [f"[{lang}_XX]", "<s>"]
             suffix_tokens = ["</s>"]
         else:
             raise ValueError(

From c52ce3ad2ec4cf6a98ba56f0a390235cef547bce Mon Sep 17 00:00:00 2001
From: Kaushik Ram Sadagopan
Date: Mon, 11 Sep 2023 11:12:52 -0700
Subject: [PATCH 3/4] Changing lang tag, removing norm_order, setting to pre-LN.
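
The `_XX` suffix is no longer appended in code; the language codes passed to
the tokenizer (for example via the `langs` field of the asset card) are used
verbatim, and the builder always uses pre-LN instead of a configurable
`norm_order`. A rough sketch of the tokenizer change, using a hypothetical
"en_XX" language code:

    # before this patch: langs=["en"]    -> control symbol "[en_XX]"
    # after this patch:  langs=["en_XX"] -> control symbol "[en_XX]"
    control_symbols = [f"[{lang}]" for lang in langs]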
--- src/fairseq2/models/mbart/builder.py | 22 +++++++++------------- src/fairseq2/models/mbart/loader.py | 15 --------------- src/fairseq2/models/mbart/tokenizer.py | 8 ++++---- 3 files changed, 13 insertions(+), 32 deletions(-) diff --git a/src/fairseq2/models/mbart/builder.py b/src/fairseq2/models/mbart/builder.py index 59c3028a9..30b13c89c 100644 --- a/src/fairseq2/models/mbart/builder.py +++ b/src/fairseq2/models/mbart/builder.py @@ -74,15 +74,12 @@ class mBartConfig: pos_encoder_type: Literal["sinusoidal", "learned"] """The type of position encoder.""" - frontend_layernorm: bool - """Whether to add the layernorm in the encoder, decoder frontend.""" + layer_norm_embed: bool + """Adds a layernorm to the embedding in the Transformer encoder.""" dropout_p: float """The dropout probability in Transformer layers.""" - norm_order: TransformerNormOrder - """The Layer Normalization order.""" - def update_vocabulary(self, info: VocabularyInfo) -> None: """Update vocabulary configuration from ``info``.""" self.vocabulary_size, self.pad_idx = info.size, info.pad_idx @@ -107,9 +104,8 @@ def _base() -> mBartConfig: num_decoder_attn_heads=16, ffn_inner_dim=4096, pos_encoder_type="learned", - frontend_layernorm=True, + layer_norm_embed=True, dropout_p=0.1, - norm_order=TransformerNormOrder.POST, ) @@ -190,7 +186,7 @@ def build_frontend(self, embed: Embedding) -> TransformerFrontend: return TransformerEmbeddingFrontend( embed, pos_encoder, - layer_norm=self.config.frontend_layernorm, + layer_norm=self.config.layer_norm_embed, dropout_p=self.config.dropout_p, device=self.device, dtype=self.dtype, @@ -204,7 +200,7 @@ def build_encoder(self) -> TransformerEncoder: return StandardTransformerEncoder( layers, - norm_order=self.config.norm_order, + norm_order=TransformerNormOrder.PRE, device=self.device, dtype=self.dtype, ) @@ -217,7 +213,7 @@ def build_decoder(self) -> TransformerDecoder: return StandardTransformerDecoder( layers, - norm_order=self.config.norm_order, + norm_order=TransformerNormOrder.PRE, device=self.device, dtype=self.dtype, ) @@ -232,7 +228,7 @@ def build_encoder_layer(self) -> TransformerEncoderLayer: self_attn, ffn, dropout_p=self.config.dropout_p, - norm_order=self.config.norm_order, + norm_order=TransformerNormOrder.PRE, device=self.device, dtype=self.dtype, ) @@ -250,7 +246,7 @@ def build_decoder_layer(self) -> TransformerDecoderLayer: encoder_decoder_attn, ffn, dropout_p=self.config.dropout_p, - norm_order=self.config.norm_order, + norm_order=TransformerNormOrder.PRE, device=self.device, dtype=self.dtype, ) @@ -272,7 +268,7 @@ def build_ffn(self) -> FeedForwardNetwork: return StandardFeedForwardNetwork( self.config.model_dim, self.config.ffn_inner_dim, - norm_order=self.config.norm_order, + norm_order=TransformerNormOrder.PRE, device=self.device, dtype=self.dtype, ) diff --git a/src/fairseq2/models/mbart/loader.py b/src/fairseq2/models/mbart/loader.py index 84cb94d3b..47653c065 100644 --- a/src/fairseq2/models/mbart/loader.py +++ b/src/fairseq2/models/mbart/loader.py @@ -6,8 +6,6 @@ from typing import Any, Dict, Mapping, Union, final -import torch - from fairseq2.assets import ( AssetCard, AssetDownloadManager, @@ -46,24 +44,11 @@ def _upgrade_checkpoint( embeds = state_dict["final_proj.weight"] - # fairseq had a bug that accidentally introduced a dummy token in the - # embedding table of NLLB-100. We just discard it. 
-        if embeds.size(0) == 256103:  # means NLLB-100
-            embeds = embeds[:-1]
-
-        state_dict["final_proj.weight"] = embeds
-
         # fairseq checkpoints have duplicate embedding weights. Ensure that we
         # use a single embedding table in fairseq2.
         state_dict["encoder_frontend.embed.weight"] = embeds
         state_dict["decoder_frontend.embed.weight"] = embeds
 
-        # The embedding positions of the control symbols in fairseq's dict do
-        # not match the SentencePiece model of the tokenizer.
-        with torch.inference_mode():
-            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
-            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
-
         return checkpoint
 
     @staticmethod
diff --git a/src/fairseq2/models/mbart/tokenizer.py b/src/fairseq2/models/mbart/tokenizer.py
index 1a6e184f1..ac6087e18 100644
--- a/src/fairseq2/models/mbart/tokenizer.py
+++ b/src/fairseq2/models/mbart/tokenizer.py
@@ -38,8 +38,8 @@ def __init__(
         :param default_lang:
             The fall-back language if no language is specified.
         """
-        # Each language is represented by a `[lang_XX]` control symbol.
-        control_symbols = [f"[{lang}_XX]" for lang in langs]
+        # Each language is represented by a `[lang]` control symbol.
+        control_symbols = [f"[{lang}]" for lang in langs]
 
         control_symbols.append("<mask>")
 
@@ -91,9 +91,9 @@ def create_encoder(
 
         if mode is None or mode == "source":
             prefix_tokens = ["<s>"]
-            suffix_tokens = ["</s>", f"[{lang}_XX]"]
+            suffix_tokens = ["</s>", f"[{lang}]"]
         elif mode == "target":
-            prefix_tokens = [f"[{lang}_XX]", "<s>"]
+            prefix_tokens = [f"[{lang}]", "<s>"]
             suffix_tokens = ["</s>"]
         else:
             raise ValueError(

From c7c7e91372572693fdcc97ab1a149e574a173d67 Mon Sep 17 00:00:00 2001
From: Kaushik Ram Sadagopan
Date: Fri, 15 Sep 2023 11:00:40 -0700
Subject: [PATCH 4/4] Embedding special token reordering to align with fairseq.
---
 src/fairseq2/models/mbart/loader.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/fairseq2/models/mbart/loader.py b/src/fairseq2/models/mbart/loader.py
index 47653c065..efa85e4c1 100644
--- a/src/fairseq2/models/mbart/loader.py
+++ b/src/fairseq2/models/mbart/loader.py
@@ -20,6 +20,8 @@
 from fairseq2.models.utils.model_loader import ModelConfigLoader, ModelLoader
 from fairseq2.typing import finaloverride
 
+import torch
+
 
 @final
 class mBartLoader(ModelLoader[TransformerModel, mBartConfig]):
@@ -49,6 +51,12 @@ def _upgrade_checkpoint(
         state_dict["encoder_frontend.embed.weight"] = embeds
         state_dict["decoder_frontend.embed.weight"] = embeds
 
+        # The embedding positions of the control symbols in fairseq's dict do
+        # not match the SentencePiece model of the tokenizer.
+        with torch.inference_mode():
+            # (BOS, PAD, EOS, UNK) -> (PAD, UNK, BOS, EOS)
+            embeds[[0, 1, 2, 3]] = embeds[[1, 3, 0, 2]]
+
         return checkpoint
 
     @staticmethod
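
For reference, a minimal usage sketch of the APIs added in this series. The
asset card name "mbart_base" and the "en_XX" language code are placeholders;
an actual card has to point at the checkpoint and SentencePiece model and
list the supported languages in its "langs" and "default_lang" fields:

    from fairseq2.models.mbart import load_mbart_model, load_mbart_tokenizer

    # Downloads (or reads from cache) the checkpoint and builds the model.
    model = load_mbart_model("mbart_base")  # hypothetical card name

    # Loads the SentencePiece tokenizer described by the same asset card.
    tokenizer = load_mbart_tokenizer("mbart_base")

    # Encode a source sentence with the language control symbol appended.
    encoder = tokenizer.create_encoder(lang="en_XX", mode="source")
    token_indices = encoder("Hello, world!")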