From c20c287370b54d4f10a15c8f803ccb4c62b931e6 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Wed, 7 Aug 2024 01:35:24 +0000 Subject: [PATCH 1/8] refactor the hg interface to support multiple models through presets --- src/fairseq2/recipes/hg/__init__.py | 10 +++++----- src/fairseq2/recipes/hg/asr_eval.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/fairseq2/recipes/hg/__init__.py b/src/fairseq2/recipes/hg/__init__.py index 36d4ba38a..f44159cac 100644 --- a/src/fairseq2/recipes/hg/__init__.py +++ b/src/fairseq2/recipes/hg/__init__.py @@ -33,17 +33,17 @@ def _setup_hg_cli(cli: Cli) -> None: from fairseq2.recipes.hg.asr_eval import ( asr_eval_presets, - load_wav2vec2_asr_evaluator, + load_asr_evaluator, ) handler = RecipeCommandHandler( - load_wav2vec2_asr_evaluator, + load_asr_evaluator, preset_configs=asr_eval_presets, - default_preset="librispeech_asr", + default_preset="wav2vec2_librispeech_asr", ) group.add_command( - "wav2vec2_asr", + "asr", handler, - help="evaluate a wav2vec 2.0 ASR model on a downstream benchmark (default: librispeech_asr)", + help="evaluate an ASR model (default: wav2vec2) on a downstream benchmark (default: librispeech_asr)", ) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 17b1f3a02..2b592eedb 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -16,6 +16,7 @@ load_dataset, ) +from fairseq2.assets.metadata_provider import AssetNotFoundError from fairseq2.config_registry import ConfigRegistry from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import load_text_tokenizer @@ -28,6 +29,7 @@ from fairseq2.nn.padding import get_seqs_and_padding_mask from fairseq2.recipes.hg.dataset import Example, create_hf_reader from fairseq2.recipes.hg.evaluator import HFEvaluator +from fairseq2.recipes.utils.asset import retrieve_asset_card from fairseq2.recipes.utils.setup import setup_root_gang from fairseq2.typing import META, DataType from fairseq2.utils.profiler import Stopwatch @@ -87,8 +89,8 @@ class AsrEvalConfig: asr_eval_preset = asr_eval_presets.decorator -@asr_eval_preset("librispeech_asr") -def _librispeech_asr_config() -> AsrEvalConfig: +@asr_eval_preset("wav2vec2_librispeech_asr") +def _wav2vec2_librispeech_asr_config() -> AsrEvalConfig: return AsrEvalConfig( dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", @@ -96,6 +98,13 @@ def _librispeech_asr_config() -> AsrEvalConfig: # converter=librispeech_asr_to_batch, ) +@asr_eval_preset("whisper_librispeech_asr") +def _whisper_librispeech_asr_config() -> AsrEvalConfig: + return AsrEvalConfig( + dataset_name="librispeech_asr", + model_name="whisper", + split="test.other", + ) def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: """ From 1d6ebb91a4b82b893289e185f5cd2f9e8387a6b8 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:21:42 +0000 Subject: [PATCH 2/8] define a defaut preset only --- src/fairseq2/recipes/hg/__init__.py | 2 +- src/fairseq2/recipes/hg/asr_eval.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/fairseq2/recipes/hg/__init__.py b/src/fairseq2/recipes/hg/__init__.py index f44159cac..6435cbd89 100644 --- a/src/fairseq2/recipes/hg/__init__.py +++ b/src/fairseq2/recipes/hg/__init__.py @@ -39,7 +39,7 @@ def _setup_hg_cli(cli: Cli) -> None: handler = RecipeCommandHandler( load_asr_evaluator, preset_configs=asr_eval_presets, - default_preset="wav2vec2_librispeech_asr", + default_preset="default_asr", ) group.add_command( diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 2b592eedb..f37787df2 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -89,8 +89,8 @@ class AsrEvalConfig: asr_eval_preset = asr_eval_presets.decorator -@asr_eval_preset("wav2vec2_librispeech_asr") -def _wav2vec2_librispeech_asr_config() -> AsrEvalConfig: +@asr_eval_preset("default_asr") +def _default_asr_config() -> AsrEvalConfig: return AsrEvalConfig( dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", @@ -98,14 +98,6 @@ def _wav2vec2_librispeech_asr_config() -> AsrEvalConfig: # converter=librispeech_asr_to_batch, ) -@asr_eval_preset("whisper_librispeech_asr") -def _whisper_librispeech_asr_config() -> AsrEvalConfig: - return AsrEvalConfig( - dataset_name="librispeech_asr", - model_name="whisper", - split="test.other", - ) - def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: """ Converts a collated batch of examples into a Seq2SeqBatch. From 8076bf0678facf9223034fd315b44d8bd180fbae Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:18:49 +0000 Subject: [PATCH 3/8] Refactor ASR evaluation code for improved extensibility and whisper integration --- src/fairseq2/recipes/hg/__init__.py | 5 +- src/fairseq2/recipes/hg/asr_eval.py | 181 ++++++++++++++++++++-------- 2 files changed, 135 insertions(+), 51 deletions(-) diff --git a/src/fairseq2/recipes/hg/__init__.py b/src/fairseq2/recipes/hg/__init__.py index 6435cbd89..dfc7847d6 100644 --- a/src/fairseq2/recipes/hg/__init__.py +++ b/src/fairseq2/recipes/hg/__init__.py @@ -31,10 +31,7 @@ def _setup_hg_cli(cli: Cli) -> None: group = cli.add_group("hg", help="Hugging Face recipes") - from fairseq2.recipes.hg.asr_eval import ( - asr_eval_presets, - load_asr_evaluator, - ) + from fairseq2.recipes.hg.asr_eval import asr_eval_presets, load_asr_evaluator handler = RecipeCommandHandler( load_asr_evaluator, diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index f37787df2..b36476354 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -5,16 +5,18 @@ # LICENSE file in the root directory of this source tree. import itertools +from collections.abc import Callable from dataclasses import dataclass -from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any, Optional, cast import torch from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] Dataset, load_dataset, + load_dataset_builder, ) +from transformers import WhisperForConditionalGeneration, WhisperProcessor # type: ignore[attr-defined,import-untyped,import-not-found] from fairseq2.assets.metadata_provider import AssetNotFoundError from fairseq2.config_registry import ConfigRegistry @@ -31,7 +33,7 @@ from fairseq2.recipes.hg.evaluator import HFEvaluator from fairseq2.recipes.utils.asset import retrieve_asset_card from fairseq2.recipes.utils.setup import setup_root_gang -from fairseq2.typing import META, DataType +from fairseq2.typing import META, DataType, Device from fairseq2.utils.profiler import Stopwatch log = get_log_writer(__name__) @@ -49,9 +51,6 @@ class AsrEvalConfig: model_name: str """The name of the model to evaluate.""" - # converter: Callable[[Example], Seq2SeqBatch] - # """The converter function to convert collated data into Seq2SeqBatch""" - tokenizer_name: str = "librispeech_asr" """The tokenizer to use.""" @@ -95,10 +94,10 @@ def _default_asr_config() -> AsrEvalConfig: dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", split="test.other", - # converter=librispeech_asr_to_batch, ) -def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: + +def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch: """ Converts a collated batch of examples into a Seq2SeqBatch. @@ -111,11 +110,19 @@ def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: source_data = cast(SequenceData, examples["audio"]) target_data = cast(SequenceData, examples["text"]) - source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) - target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) + if model_type == "wav2vec2": + source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + elif model_type == "whisper": + source_seqs = cast(source_data) + source_padding_mask = None + else: + raise ValueError(f"Unknown model type: {model_type}") + + target_seqs = target_data + target_padding_mask = None return Seq2SeqBatch( - source_seqs, + source_seqs.to(device), source_padding_mask, target_seqs, target_padding_mask, @@ -123,42 +130,28 @@ def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: ) -@lru_cache(maxsize=None) -def get_cached_tokenizer(tokenizer_name: str) -> TextTokenizer: - return load_text_tokenizer(tokenizer_name) - - -def _preprocess_example( - example: Example, tokenizer_name: str, device: torch.device -) -> Example: +def extract_features(example: Example) -> Example: """ Preprocesses an individual example by converting the audio array to a PyTorch tensor and encoding the text. Args: example (dict): A dictionary containing "audio" and "text" keys. - tokenizer_name (str): The name of the tokenizer to use. device (torch.device): The device to store the tensors. Returns: dict: A dictionary with "audio" and "text" as PyTorch tensors. """ - tokenizer = get_cached_tokenizer(tokenizer_name) - encoder = tokenizer.create_encoder(device=device) - audio_tensor = ( - torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device) - ) - text_tensor = encoder(example["text"].lower()).to(device) - return {"audio": audio_tensor, "text": text_tensor} + return {"audio": example["audio"]["array"], "text": example["text"].lower()} -def seq2seq_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: +def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch( batch.target_seqs, batch.target_padding_mask ) -def postprocesser( +def evaluator_postprocesser( outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer ) -> tuple[list[str], list[str]]: decoder = tokenizer.create_decoder() @@ -166,11 +159,43 @@ def postprocesser( hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) predictions = [decoder(item) for item in hypotheses] - references = [decoder(item) for item in targets.seqs.to(torch.int32)] + references = targets.seqs return predictions, references +def prepare_dataset( + config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None +) -> Dataset: + iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) + ds = Dataset.from_generator( + lambda: itertools.islice(iterable_ds, 0, config.max_samples), + features=iterable_ds.features, + ) + ds = ds.map(lambda x: extract_features(x)) + + if processor is not None: + ds = ds.map(processor) + + format = { + "type": "torch", + "format_kwargs": {"dtype": config.dtype}, + } + ds.set_format(**format, columns=["audio", "text"]) + + return ds + + +def load_asr_evaluator( + config: AsrEvalConfig, output_dir: Path +) -> HFEvaluator[Seq2SeqBatch]: + try: + retrieve_asset_card(config.model_name) + return load_wav2vec2_asr_evaluator(config, output_dir) + except AssetNotFoundError: + return load_hg_asr_evaluator(config, output_dir) + + def load_wav2vec2_asr_evaluator( config: AsrEvalConfig, output_dir: Path ) -> HFEvaluator[Seq2SeqBatch]: @@ -188,12 +213,7 @@ def load_wav2vec2_asr_evaluator( if not isinstance(config, AsrEvalConfig): raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") - iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) - # Load a subset of the dataset if max_samples is set - ds = Dataset.from_generator( - lambda: itertools.islice(iterable_ds, 0, config.max_samples), - features=iterable_ds.features, - ) + ds = prepare_dataset(config) gang = setup_root_gang(log) @@ -202,19 +222,12 @@ def load_wav2vec2_asr_evaluator( else: init_device = META - ds = ds.map(lambda x: _preprocess_example(x, config.tokenizer_name, init_device)) - format = { - "type": "torch", - "format_kwargs": {"dtype": torch.float16, "device": init_device}, - } - ds.set_format(**format, columns=["audio", "text"]) - - tokenizer = get_cached_tokenizer(config.tokenizer_name) + tokenizer = load_text_tokenizer(config.tokenizer_name) pipeline_reader = create_hf_reader( dataset=ds, gang=gang, - converter=_librispeech_asr_to_batch, + converter=lambda x: to_batch(x, "wav2vec2", init_device), batching=StaticBatching(config.max_num_elements), num_prefetch=config.num_prefetch, pad_value=tokenizer.vocab_info.pad_idx, @@ -233,6 +246,80 @@ def load_wav2vec2_asr_evaluator( gang=gang, data_reader=pipeline_reader, wall_watch=wall_watch, - preprocessor=seq2seq_preprocessor, - postprocessor=lambda x, y: postprocesser(x, y, tokenizer), + preprocessor=evaluator_preprocessor, + postprocessor=lambda x, y: evaluator_postprocesser(x, y, tokenizer), + ) + + +class HGModelWrapper: + def __init__(self, model): + self.model = model + + def __call__(self, batch: SequenceBatch): + return self.model.generate(batch.seqs) + + +def load_hg_asr_evaluator( + config: AsrEvalConfig, output_dir: Path +) -> HFEvaluator[Seq2SeqBatch]: + """ + Load the evaluator used for downstream evaluation of the whisper model + in a downstream dataset and report BLEU scores + + Args: + config (HFEvalConfig): The configuration for the evaluation. + output_dir (Path): The output directory to store the evaluation results. + + Returns: + HFEvaluator: Evaluation process results. + """ + if not isinstance(config, AsrEvalConfig): + raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") + + gang = setup_root_gang(log) + + if gang.rank == 0: + init_device = gang.device + else: + init_device = META + + processor = WhisperProcessor.from_pretrained(config.model_name) + model = WhisperForConditionalGeneration.from_pretrained(config.model_name).to( + init_device + ) + + ds_builder = load_dataset_builder(config.dataset_name) + ds = prepare_dataset( + config, + lambda x: { + "audio": processor( + x["audio"], + sampling_rate=ds_builder.info.features["audio"].sampling_rate, + return_tensors="pt", + ).input_features.squeeze(0) + }, + ) + + pipeline_reader = create_hf_reader( + dataset=ds, + gang=gang, + converter=lambda x: to_batch(x, "whisper", init_device), + batching=StaticBatching(config.max_num_elements), + num_prefetch=config.num_prefetch, + max_seq_len=config.max_audio_len, + ) + + wall_watch = Stopwatch(start=True, device=init_device) + + return HFEvaluator[Seq2SeqBatch]( + model=HGModelWrapper(model), + metrics=["bleu"], + gang=gang, + data_reader=pipeline_reader, + wall_watch=wall_watch, + preprocessor=evaluator_preprocessor, + postprocessor=lambda x, y: ( + processor.batch_decode(x, skip_special_tokens=True), + y.seqs, + ), ) From a64576487a355d91b429a4eaef74e11f87e01a8f Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:57:04 +0000 Subject: [PATCH 4/8] fix a typo --- src/fairseq2/recipes/hg/asr_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index b36476354..3271e4ce3 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -113,7 +113,7 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch if model_type == "wav2vec2": source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) elif model_type == "whisper": - source_seqs = cast(source_data) + source_seqs = source_data source_padding_mask = None else: raise ValueError(f"Unknown model type: {model_type}") From ce844e514fa1d3f2134e9c14dc69455815adaaf1 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Fri, 9 Aug 2024 19:06:46 +0000 Subject: [PATCH 5/8] correctly manage and free cuda memory --- src/fairseq2/recipes/hg/asr_eval.py | 6 ++++-- src/fairseq2/recipes/hg/evaluator.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 3271e4ce3..a685a1b5c 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -112,8 +112,10 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch if model_type == "wav2vec2": source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + source_seqs = source_seqs.to(device) + source_padding_mask = source_padding_mask.to(device) elif model_type == "whisper": - source_seqs = source_data + source_seqs = source_data.to(device) source_padding_mask = None else: raise ValueError(f"Unknown model type: {model_type}") @@ -122,7 +124,7 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch target_padding_mask = None return Seq2SeqBatch( - source_seqs.to(device), + source_seqs, source_padding_mask, target_seqs, target_padding_mask, diff --git a/src/fairseq2/recipes/hg/evaluator.py b/src/fairseq2/recipes/hg/evaluator.py index b30b0886d..d84f207b3 100644 --- a/src/fairseq2/recipes/hg/evaluator.py +++ b/src/fairseq2/recipes/hg/evaluator.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any, Generic, TypeVar, final +import torch + from fairseq2.datasets import DataReader from fairseq2.gang import FakeGang, Gang from fairseq2.logging import get_log_writer @@ -172,6 +174,11 @@ def _do_run(self) -> None: predictions=predictions, references=references ) + del inputs + del targets + del outputs + torch.cuda.empty_cache() + self._root_gang.barrier() self._elapsed_time = watch.get_elapsed_time() From 28657f4275d94a59a0995c6e994ee8dc902eaa32 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Fri, 9 Aug 2024 23:24:39 +0000 Subject: [PATCH 6/8] Reorder functions and lint --- src/fairseq2/recipes/hg/asr_eval.py | 72 ++++++++++++++-------------- src/fairseq2/recipes/hg/evaluator.py | 2 +- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index a685a1b5c..3bd61fa96 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -16,7 +16,10 @@ load_dataset, load_dataset_builder, ) -from transformers import WhisperForConditionalGeneration, WhisperProcessor # type: ignore[attr-defined,import-untyped,import-not-found] +from transformers import ( # type: ignore[attr-defined,import-untyped,import-not-found] + WhisperForConditionalGeneration, + WhisperProcessor, +) from fairseq2.assets.metadata_provider import AssetNotFoundError from fairseq2.config_registry import ConfigRegistry @@ -84,7 +87,6 @@ class AsrEvalConfig: asr_eval_presets = ConfigRegistry[AsrEvalConfig]() - asr_eval_preset = asr_eval_presets.decorator @@ -97,6 +99,21 @@ def _default_asr_config() -> AsrEvalConfig: ) +def extract_features(example: Example) -> Example: + """ + Preprocesses an individual example by converting the audio array to a PyTorch tensor + and encoding the text. + + Args: + example (dict): A dictionary containing "audio" and "text" keys. + device (torch.device): The device to store the tensors. + + Returns: + dict: A dictionary with "audio" and "text" as PyTorch tensors. + """ + return {"audio": example["audio"]["array"], "text": example["text"].lower()} + + def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch: """ Converts a collated batch of examples into a Seq2SeqBatch. @@ -132,19 +149,26 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch ) -def extract_features(example: Example) -> Example: - """ - Preprocesses an individual example by converting the audio array to a PyTorch tensor - and encoding the text. +def prepare_dataset( + config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None +) -> Dataset: + iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) + ds = Dataset.from_generator( + lambda: itertools.islice(iterable_ds, 0, config.max_samples), + features=iterable_ds.features, + ) + ds = ds.map(lambda x: extract_features(x)) - Args: - example (dict): A dictionary containing "audio" and "text" keys. - device (torch.device): The device to store the tensors. + if processor is not None: + ds = ds.map(processor) - Returns: - dict: A dictionary with "audio" and "text" as PyTorch tensors. - """ - return {"audio": example["audio"]["array"], "text": example["text"].lower()} + format = { + "type": "torch", + "format_kwargs": {"dtype": config.dtype}, + } + ds.set_format(**format, columns=["audio", "text"]) + + return ds def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: @@ -166,28 +190,6 @@ def evaluator_postprocesser( return predictions, references -def prepare_dataset( - config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None -) -> Dataset: - iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) - ds = Dataset.from_generator( - lambda: itertools.islice(iterable_ds, 0, config.max_samples), - features=iterable_ds.features, - ) - ds = ds.map(lambda x: extract_features(x)) - - if processor is not None: - ds = ds.map(processor) - - format = { - "type": "torch", - "format_kwargs": {"dtype": config.dtype}, - } - ds.set_format(**format, columns=["audio", "text"]) - - return ds - - def load_asr_evaluator( config: AsrEvalConfig, output_dir: Path ) -> HFEvaluator[Seq2SeqBatch]: diff --git a/src/fairseq2/recipes/hg/evaluator.py b/src/fairseq2/recipes/hg/evaluator.py index d84f207b3..c624523f7 100644 --- a/src/fairseq2/recipes/hg/evaluator.py +++ b/src/fairseq2/recipes/hg/evaluator.py @@ -174,7 +174,7 @@ def _do_run(self) -> None: predictions=predictions, references=references ) - del inputs + del inputs del targets del outputs torch.cuda.empty_cache() From 3a8404d3dde9bf3a8e126444163b997de6553747 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Fri, 9 Aug 2024 23:45:21 +0000 Subject: [PATCH 7/8] fix type hints --- src/fairseq2/recipes/hg/asr_eval.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 3bd61fa96..139900c65 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -28,6 +28,7 @@ from fairseq2.data.text.text_tokenizer import TextTokenizer from fairseq2.datasets.batching import StaticBatching from fairseq2.logging import get_log_writer +from fairseq2.models.model import Model from fairseq2.models.seq2seq import Seq2SeqBatch from fairseq2.models.sequence import SequenceBatch from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model @@ -125,14 +126,16 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch Seq2SeqBatch: A batch of audio and text sequences. """ source_data = cast(SequenceData, examples["audio"]) - target_data = cast(SequenceData, examples["text"]) + target_data = cast(torch.Tensor, examples["text"]) if model_type == "wav2vec2": source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) source_seqs = source_seqs.to(device) - source_padding_mask = source_padding_mask.to(device) + source_padding_mask = ( + source_padding_mask.to(device) if source_padding_mask is not None else None + ) elif model_type == "whisper": - source_seqs = source_data.to(device) + source_seqs = cast(torch.Tensor, source_data).to(device) source_padding_mask = None else: raise ValueError(f"Unknown model type: {model_type}") @@ -185,7 +188,7 @@ def evaluator_postprocesser( hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) predictions = [decoder(item) for item in hypotheses] - references = targets.seqs + references = cast(list[str], targets.seqs) return predictions, references @@ -256,10 +259,10 @@ def load_wav2vec2_asr_evaluator( class HGModelWrapper: - def __init__(self, model): + def __init__(self, model: WhisperForConditionalGeneration): self.model = model - def __call__(self, batch: SequenceBatch): + def __call__(self, batch: SequenceBatch) -> Any: return self.model.generate(batch.seqs) @@ -316,7 +319,7 @@ def load_hg_asr_evaluator( wall_watch = Stopwatch(start=True, device=init_device) return HFEvaluator[Seq2SeqBatch]( - model=HGModelWrapper(model), + model=cast(Model, HGModelWrapper(model)), metrics=["bleu"], gang=gang, data_reader=pipeline_reader, @@ -324,6 +327,6 @@ def load_hg_asr_evaluator( preprocessor=evaluator_preprocessor, postprocessor=lambda x, y: ( processor.batch_decode(x, skip_special_tokens=True), - y.seqs, + cast(list[str], y.seqs), ), ) From 87255a52691650d2c8f821c63d33441d6fb416be Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Sun, 11 Aug 2024 21:06:49 +0000 Subject: [PATCH 8/8] Update docstrings --- src/fairseq2/recipes/hg/asr_eval.py | 37 +++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 139900c65..c57af7ef4 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -102,15 +102,13 @@ def _default_asr_config() -> AsrEvalConfig: def extract_features(example: Example) -> Example: """ - Preprocesses an individual example by converting the audio array to a PyTorch tensor - and encoding the text. + Extracts source and target features from dataset examples. Args: - example (dict): A dictionary containing "audio" and "text" keys. - device (torch.device): The device to store the tensors. + example (dict): A dictionary containing source and target keys. Returns: - dict: A dictionary with "audio" and "text" as PyTorch tensors. + dict: A dictionary containing the extracted features. """ return {"audio": example["audio"]["array"], "text": example["text"].lower()} @@ -155,6 +153,17 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch def prepare_dataset( config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None ) -> Dataset: + """ + Prepares a dataset for evaluation. The dataset is loaded from the + HF datasets and preprocessed using the provided processor. + + Args: + config (AsrEvalConfig): The configuration for the evaluation. + processor (Callable): A function to preprocess examples. + + Returns: + Dataset: The prepared dataset. + """ iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) ds = Dataset.from_generator( lambda: itertools.islice(iterable_ds, 0, config.max_samples), @@ -196,6 +205,16 @@ def evaluator_postprocesser( def load_asr_evaluator( config: AsrEvalConfig, output_dir: Path ) -> HFEvaluator[Seq2SeqBatch]: + """ + Load the evaluator based on the model type. + + Args: + config (AsrEvalConfig): The configuration for the evaluation. + output_dir (Path): The output directory to store the evaluation results. + + Returns: + HFEvaluator: Evaluation process. + """ try: retrieve_asset_card(config.model_name) return load_wav2vec2_asr_evaluator(config, output_dir) @@ -207,15 +226,14 @@ def load_wav2vec2_asr_evaluator( config: AsrEvalConfig, output_dir: Path ) -> HFEvaluator[Seq2SeqBatch]: """ - Load the evaluator used for downstream evaluation of the model - in a downstream dataset and report BLEU scores + Load the evaluator used for downstream evaluation of the wav2vec2 model. Args: config (HFEvalConfig): The configuration for the evaluation. output_dir (Path): The output directory to store the evaluation results. Returns: - HFEvaluator: Evaluation process results. + HFEvaluator: Evaluation process. """ if not isinstance(config, AsrEvalConfig): raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") @@ -271,14 +289,13 @@ def load_hg_asr_evaluator( ) -> HFEvaluator[Seq2SeqBatch]: """ Load the evaluator used for downstream evaluation of the whisper model - in a downstream dataset and report BLEU scores Args: config (HFEvalConfig): The configuration for the evaluation. output_dir (Path): The output directory to store the evaluation results. Returns: - HFEvaluator: Evaluation process results. + HFEvaluator: Evaluation process. """ if not isinstance(config, AsrEvalConfig): raise ValueError(f"Expect AsrEvalConfig, get {type(config)}")