From feff2cd3577c322fe828ea4f7a9b0549cae010fe Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:44:10 +0000 Subject: [PATCH 01/12] refactor the hg interface to support multiple models through presets --- src/fairseq2/recipes/hg/__init__.py | 10 +++++----- src/fairseq2/recipes/hg/asr_eval.py | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/fairseq2/recipes/hg/__init__.py b/src/fairseq2/recipes/hg/__init__.py index 36d4ba38a..6435cbd89 100644 --- a/src/fairseq2/recipes/hg/__init__.py +++ b/src/fairseq2/recipes/hg/__init__.py @@ -33,17 +33,17 @@ def _setup_hg_cli(cli: Cli) -> None: from fairseq2.recipes.hg.asr_eval import ( asr_eval_presets, - load_wav2vec2_asr_evaluator, + load_asr_evaluator, ) handler = RecipeCommandHandler( - load_wav2vec2_asr_evaluator, + load_asr_evaluator, preset_configs=asr_eval_presets, - default_preset="librispeech_asr", + default_preset="default_asr", ) group.add_command( - "wav2vec2_asr", + "asr", handler, - help="evaluate a wav2vec 2.0 ASR model on a downstream benchmark (default: librispeech_asr)", + help="evaluate an ASR model (default: wav2vec2) on a downstream benchmark (default: librispeech_asr)", ) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 17b1f3a02..f37787df2 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -16,6 +16,7 @@ load_dataset, ) +from fairseq2.assets.metadata_provider import AssetNotFoundError from fairseq2.config_registry import ConfigRegistry from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import load_text_tokenizer @@ -28,6 +29,7 @@ from fairseq2.nn.padding import get_seqs_and_padding_mask from fairseq2.recipes.hg.dataset import Example, create_hf_reader from fairseq2.recipes.hg.evaluator import HFEvaluator +from fairseq2.recipes.utils.asset import retrieve_asset_card from fairseq2.recipes.utils.setup import setup_root_gang from fairseq2.typing import META, DataType from fairseq2.utils.profiler import Stopwatch @@ -87,8 +89,8 @@ class AsrEvalConfig: asr_eval_preset = asr_eval_presets.decorator -@asr_eval_preset("librispeech_asr") -def _librispeech_asr_config() -> AsrEvalConfig: +@asr_eval_preset("default_asr") +def _default_asr_config() -> AsrEvalConfig: return AsrEvalConfig( dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", @@ -96,7 +98,6 @@ def _librispeech_asr_config() -> AsrEvalConfig: # converter=librispeech_asr_to_batch, ) - def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: """ Converts a collated batch of examples into a Seq2SeqBatch. From d9f753eeffd4b168b0a7893bc7efab1495453a2f Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:48:38 +0000 Subject: [PATCH 02/12] Refactor ASR evaluation code for improved extensibility --- src/fairseq2/recipes/hg/asr_eval.py | 99 +++++++++++++++-------------- 1 file changed, 50 insertions(+), 49 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index f37787df2..909db5987 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -5,10 +5,10 @@ # LICENSE file in the root directory of this source tree. 
import itertools +from collections.abc import Callable from dataclasses import dataclass -from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any, Optional, cast import torch from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] @@ -31,7 +31,7 @@ from fairseq2.recipes.hg.evaluator import HFEvaluator from fairseq2.recipes.utils.asset import retrieve_asset_card from fairseq2.recipes.utils.setup import setup_root_gang -from fairseq2.typing import META, DataType +from fairseq2.typing import META, DataType, Device from fairseq2.utils.profiler import Stopwatch log = get_log_writer(__name__) @@ -49,9 +49,6 @@ class AsrEvalConfig: model_name: str """The name of the model to evaluate.""" - # converter: Callable[[Example], Seq2SeqBatch] - # """The converter function to convert collated data into Seq2SeqBatch""" - tokenizer_name: str = "librispeech_asr" """The tokenizer to use.""" @@ -95,10 +92,10 @@ def _default_asr_config() -> AsrEvalConfig: dataset_name="librispeech_asr", model_name="wav2vec2_asr_base_10h", split="test.other", - # converter=librispeech_asr_to_batch, ) -def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: + +def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch: """ Converts a collated batch of examples into a Seq2SeqBatch. @@ -111,11 +108,19 @@ def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: source_data = cast(SequenceData, examples["audio"]) target_data = cast(SequenceData, examples["text"]) - source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) - target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data) + if model_type == "wav2vec2": + source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + elif model_type == "whisper": + source_seqs = cast(source_data) + source_padding_mask = None + else: + raise ValueError(f"Unknown model type: {model_type}") + + target_seqs = target_data + target_padding_mask = None return Seq2SeqBatch( - source_seqs, + source_seqs.to(device), source_padding_mask, target_seqs, target_padding_mask, @@ -123,42 +128,28 @@ def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch: ) -@lru_cache(maxsize=None) -def get_cached_tokenizer(tokenizer_name: str) -> TextTokenizer: - return load_text_tokenizer(tokenizer_name) - - -def _preprocess_example( - example: Example, tokenizer_name: str, device: torch.device -) -> Example: +def extract_features(example: Example) -> Example: """ Preprocesses an individual example by converting the audio array to a PyTorch tensor and encoding the text. Args: example (dict): A dictionary containing "audio" and "text" keys. - tokenizer_name (str): The name of the tokenizer to use. device (torch.device): The device to store the tensors. Returns: dict: A dictionary with "audio" and "text" as PyTorch tensors. 
""" - tokenizer = get_cached_tokenizer(tokenizer_name) - encoder = tokenizer.create_encoder(device=device) - audio_tensor = ( - torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device) - ) - text_tensor = encoder(example["text"].lower()).to(device) - return {"audio": audio_tensor, "text": text_tensor} + return {"audio": example["audio"]["array"], "text": example["text"].lower()} -def seq2seq_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: +def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch( batch.target_seqs, batch.target_padding_mask ) -def postprocesser( +def evaluator_postprocesser( outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer ) -> tuple[list[str], list[str]]: decoder = tokenizer.create_decoder() @@ -166,12 +157,34 @@ def postprocesser( hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) predictions = [decoder(item) for item in hypotheses] - references = [decoder(item) for item in targets.seqs.to(torch.int32)] + references = targets.seqs return predictions, references -def load_wav2vec2_asr_evaluator( +def prepare_dataset( + config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None +) -> Dataset: + iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) + ds = Dataset.from_generator( + lambda: itertools.islice(iterable_ds, 0, config.max_samples), + features=iterable_ds.features, + ) + ds = ds.map(lambda x: extract_features(x)) + + if processor is not None: + ds = ds.map(processor) + + format = { + "type": "torch", + "format_kwargs": {"dtype": config.dtype}, + } + ds.set_format(**format, columns=["audio", "text"]) + + return ds + + +def load_asr_evaluator( config: AsrEvalConfig, output_dir: Path ) -> HFEvaluator[Seq2SeqBatch]: """ @@ -188,12 +201,7 @@ def load_wav2vec2_asr_evaluator( if not isinstance(config, AsrEvalConfig): raise ValueError(f"Expect AsrEvalConfig, get {type(config)}") - iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) - # Load a subset of the dataset if max_samples is set - ds = Dataset.from_generator( - lambda: itertools.islice(iterable_ds, 0, config.max_samples), - features=iterable_ds.features, - ) + ds = prepare_dataset(config) gang = setup_root_gang(log) @@ -202,19 +210,12 @@ def load_wav2vec2_asr_evaluator( else: init_device = META - ds = ds.map(lambda x: _preprocess_example(x, config.tokenizer_name, init_device)) - format = { - "type": "torch", - "format_kwargs": {"dtype": torch.float16, "device": init_device}, - } - ds.set_format(**format, columns=["audio", "text"]) - - tokenizer = get_cached_tokenizer(config.tokenizer_name) + tokenizer = load_text_tokenizer(config.tokenizer_name) pipeline_reader = create_hf_reader( dataset=ds, gang=gang, - converter=_librispeech_asr_to_batch, + converter=lambda x: to_batch(x, "wav2vec2", init_device), batching=StaticBatching(config.max_num_elements), num_prefetch=config.num_prefetch, pad_value=tokenizer.vocab_info.pad_idx, @@ -233,6 +234,6 @@ def load_wav2vec2_asr_evaluator( gang=gang, data_reader=pipeline_reader, wall_watch=wall_watch, - preprocessor=seq2seq_preprocessor, - postprocessor=lambda x, y: postprocesser(x, y, tokenizer), - ) + preprocessor=evaluator_preprocessor, + postprocessor=lambda x, y: evaluator_postprocesser(x, y, tokenizer), + ) \ No newline at end of file From a3329b792b5a19bdbb80666daf9f7afe6cf91bd9 Mon Sep 17 00:00:00 2001 From: Ahmed 
Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:51:01 +0000 Subject: [PATCH 03/12] correctly manage and free cuda memory --- src/fairseq2/recipes/hg/asr_eval.py | 6 ++++-- src/fairseq2/recipes/hg/evaluator.py | 7 +++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 909db5987..6566af86e 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -110,8 +110,10 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch if model_type == "wav2vec2": source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + source_seqs = source_seqs.to(device) + source_padding_mask = source_padding_mask.to(device) elif model_type == "whisper": - source_seqs = cast(source_data) + source_seqs = source_data.to(device) source_padding_mask = None else: raise ValueError(f"Unknown model type: {model_type}") @@ -120,7 +122,7 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch target_padding_mask = None return Seq2SeqBatch( - source_seqs.to(device), + source_seqs, source_padding_mask, target_seqs, target_padding_mask, diff --git a/src/fairseq2/recipes/hg/evaluator.py b/src/fairseq2/recipes/hg/evaluator.py index b30b0886d..d84f207b3 100644 --- a/src/fairseq2/recipes/hg/evaluator.py +++ b/src/fairseq2/recipes/hg/evaluator.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any, Generic, TypeVar, final +import torch + from fairseq2.datasets import DataReader from fairseq2.gang import FakeGang, Gang from fairseq2.logging import get_log_writer @@ -172,6 +174,11 @@ def _do_run(self) -> None: predictions=predictions, references=references ) + del inputs + del targets + del outputs + torch.cuda.empty_cache() + self._root_gang.barrier() self._elapsed_time = watch.get_elapsed_time() From da5d8d3e8d04361838a282824b0e6ca3b8afc869 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:54:28 +0000 Subject: [PATCH 04/12] fix type hints and lint --- src/fairseq2/recipes/hg/__init__.py | 5 +- src/fairseq2/recipes/hg/asr_eval.py | 81 ++++++++++++++-------------- src/fairseq2/recipes/hg/evaluator.py | 2 +- 3 files changed, 42 insertions(+), 46 deletions(-) diff --git a/src/fairseq2/recipes/hg/__init__.py b/src/fairseq2/recipes/hg/__init__.py index 6435cbd89..dfc7847d6 100644 --- a/src/fairseq2/recipes/hg/__init__.py +++ b/src/fairseq2/recipes/hg/__init__.py @@ -31,10 +31,7 @@ def _setup_hg_cli(cli: Cli) -> None: group = cli.add_group("hg", help="Hugging Face recipes") - from fairseq2.recipes.hg.asr_eval import ( - asr_eval_presets, - load_asr_evaluator, - ) + from fairseq2.recipes.hg.asr_eval import asr_eval_presets, load_asr_evaluator handler = RecipeCommandHandler( load_asr_evaluator, diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 6566af86e..39599e2d7 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -16,7 +16,6 @@ load_dataset, ) -from fairseq2.assets.metadata_provider import AssetNotFoundError from fairseq2.config_registry import ConfigRegistry from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import load_text_tokenizer @@ -29,7 +28,6 @@ from fairseq2.nn.padding import get_seqs_and_padding_mask from fairseq2.recipes.hg.dataset import Example, create_hf_reader from fairseq2.recipes.hg.evaluator import HFEvaluator -from 
fairseq2.recipes.utils.asset import retrieve_asset_card from fairseq2.recipes.utils.setup import setup_root_gang from fairseq2.typing import META, DataType, Device from fairseq2.utils.profiler import Stopwatch @@ -82,7 +80,6 @@ class AsrEvalConfig: asr_eval_presets = ConfigRegistry[AsrEvalConfig]() - asr_eval_preset = asr_eval_presets.decorator @@ -95,6 +92,21 @@ def _default_asr_config() -> AsrEvalConfig: ) +def extract_features(example: Example) -> Example: + """ + Preprocesses an individual example by converting the audio array to a PyTorch tensor + and encoding the text. + + Args: + example (dict): A dictionary containing "audio" and "text" keys. + device (torch.device): The device to store the tensors. + + Returns: + dict: A dictionary with "audio" and "text" as PyTorch tensors. + """ + return {"audio": example["audio"]["array"], "text": example["text"].lower()} + + def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch: """ Converts a collated batch of examples into a Seq2SeqBatch. @@ -106,14 +118,16 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch Seq2SeqBatch: A batch of audio and text sequences. """ source_data = cast(SequenceData, examples["audio"]) - target_data = cast(SequenceData, examples["text"]) + target_data = cast(torch.Tensor, examples["text"]) if model_type == "wav2vec2": source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) source_seqs = source_seqs.to(device) - source_padding_mask = source_padding_mask.to(device) + source_padding_mask = ( + source_padding_mask.to(device) if source_padding_mask is not None else None + ) elif model_type == "whisper": - source_seqs = source_data.to(device) + source_seqs = cast(torch.Tensor, source_data).to(device) source_padding_mask = None else: raise ValueError(f"Unknown model type: {model_type}") @@ -130,40 +144,6 @@ def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch ) -def extract_features(example: Example) -> Example: - """ - Preprocesses an individual example by converting the audio array to a PyTorch tensor - and encoding the text. - - Args: - example (dict): A dictionary containing "audio" and "text" keys. - device (torch.device): The device to store the tensors. - - Returns: - dict: A dictionary with "audio" and "text" as PyTorch tensors. 
-    """
-    return {"audio": example["audio"]["array"], "text": example["text"].lower()}
-
-
-def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]:
-    return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch(
-        batch.target_seqs, batch.target_padding_mask
-    )
-
-
-def evaluator_postprocesser(
-    outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer
-) -> tuple[list[str], list[str]]:
-    decoder = tokenizer.create_decoder()
-    pad_idx = tokenizer.vocab_info.pad_idx
-
-    hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx)
-    predictions = [decoder(item) for item in hypotheses]
-    references = targets.seqs
-
-    return predictions, references
-
-
 def prepare_dataset(
     config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None
 ) -> Dataset:
@@ -186,6 +166,25 @@ def prepare_dataset(
     return ds
 
 
+def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]:
+    return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch(
+        batch.target_seqs, batch.target_padding_mask
+    )
+
+
+def evaluator_postprocesser(
+    outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer
+) -> tuple[list[str], list[str]]:
+    decoder = tokenizer.create_decoder()
+    pad_idx = tokenizer.vocab_info.pad_idx
+
+    hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx)
+    predictions = [decoder(item) for item in hypotheses]
+    references = cast(list[str], targets.seqs)
+
+    return predictions, references
+
+
 def load_asr_evaluator(
     config: AsrEvalConfig, output_dir: Path
 ) -> HFEvaluator[Seq2SeqBatch]:
@@ -238,4 +237,4 @@ def load_asr_evaluator(
         wall_watch=wall_watch,
         preprocessor=evaluator_preprocessor,
         postprocessor=lambda x, y: evaluator_postprocesser(x, y, tokenizer),
-    )
\ No newline at end of file
+    )
diff --git a/src/fairseq2/recipes/hg/evaluator.py b/src/fairseq2/recipes/hg/evaluator.py
index d84f207b3..c624523f7 100644
--- a/src/fairseq2/recipes/hg/evaluator.py
+++ b/src/fairseq2/recipes/hg/evaluator.py
@@ -174,7 +174,7 @@ def _do_run(self) -> None:
             predictions=predictions, references=references
         )
 
-        del inputs 
+        del inputs
         del targets
         del outputs
         torch.cuda.empty_cache()

From 107aa5e9c6ccc8c6f17c6344940635d9830aa78f Mon Sep 17 00:00:00 2001
From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com>
Date: Thu, 15 Aug 2024 03:19:13 +0000
Subject: [PATCH 05/12] Use custom `EvalSeqBatch` instead of Seq2SeqBatch

---
 src/fairseq2/recipes/hg/asr_eval.py  | 51 ++++++++++++++--------------
 src/fairseq2/recipes/hg/evaluator.py |  8 ++---
 2 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py
index 39599e2d7..15d55a0ec 100644
--- a/src/fairseq2/recipes/hg/asr_eval.py
+++ b/src/fairseq2/recipes/hg/asr_eval.py
@@ -8,7 +8,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional, cast
+from typing import Any, List, Optional, cast
 
 import torch
 from datasets import (  # type: ignore[attr-defined,import-untyped,import-not-found]
@@ -22,10 +22,9 @@
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.datasets.batching import StaticBatching
 from fairseq2.logging import get_log_writer
-from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model
-from fairseq2.nn.padding import get_seqs_and_padding_mask
+from 
fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask from fairseq2.recipes.hg.dataset import Example, create_hf_reader from fairseq2.recipes.hg.evaluator import HFEvaluator from fairseq2.recipes.utils.setup import setup_root_gang @@ -79,6 +78,18 @@ class AsrEvalConfig: """The data type of the model.""" +@dataclass +class EvalSeqBatch: + source_seqs: torch.Tensor + """The source sequences.""" + + source_padding_mask: Optional[PaddingMask] + """The source padding mask.""" + + target_seqs: List[str] + """The target sequences.""" + + asr_eval_presets = ConfigRegistry[AsrEvalConfig]() asr_eval_preset = asr_eval_presets.decorator @@ -107,40 +118,32 @@ def extract_features(example: Example) -> Example: return {"audio": example["audio"]["array"], "text": example["text"].lower()} -def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch: +def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch: """ - Converts a collated batch of examples into a Seq2SeqBatch. + Converts a collated batch of examples into a EvalSeqBatch. Args: examples (dict): A dictionary containing "audio" and "text" keys. Returns: - Seq2SeqBatch: A batch of audio and text sequences. + EvalSeqBatch: A batch of audio and text sequences. """ - source_data = cast(SequenceData, examples["audio"]) - target_data = cast(torch.Tensor, examples["text"]) if model_type == "wav2vec2": - source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data) + source_seqs, source_padding_mask = get_seqs_and_padding_mask(examples["audio"]) source_seqs = source_seqs.to(device) source_padding_mask = ( source_padding_mask.to(device) if source_padding_mask is not None else None ) - elif model_type == "whisper": - source_seqs = cast(torch.Tensor, source_data).to(device) - source_padding_mask = None else: raise ValueError(f"Unknown model type: {model_type}") - target_seqs = target_data - target_padding_mask = None + target_seqs = examples['text'] - return Seq2SeqBatch( + return EvalSeqBatch( source_seqs, source_padding_mask, target_seqs, - target_padding_mask, - examples, ) @@ -166,28 +169,26 @@ def prepare_dataset( return ds -def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]: - return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch( - batch.target_seqs, batch.target_padding_mask - ) +def evaluator_preprocessor(batch: EvalSeqBatch) -> tuple[SequenceBatch, List[str]]: + return SequenceBatch(batch.source_seqs, batch.source_padding_mask), batch.target_seqs def evaluator_postprocesser( - outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer + outputs: Any, targets: list[str], tokenizer: TextTokenizer ) -> tuple[list[str], list[str]]: decoder = tokenizer.create_decoder() pad_idx = tokenizer.vocab_info.pad_idx hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx) predictions = [decoder(item) for item in hypotheses] - references = cast(list[str], targets.seqs) + references = targets return predictions, references def load_asr_evaluator( config: AsrEvalConfig, output_dir: Path -) -> HFEvaluator[Seq2SeqBatch]: +) -> HFEvaluator[EvalSeqBatch]: """ Load the evaluator used for downstream evaluation of the model in a downstream dataset and report BLEU scores @@ -229,7 +230,7 @@ def load_asr_evaluator( wall_watch = Stopwatch(start=True, device=init_device) - return HFEvaluator[Seq2SeqBatch]( + return HFEvaluator[EvalSeqBatch]( model=model, metrics=["bleu"], gang=gang, diff --git a/src/fairseq2/recipes/hg/evaluator.py 
b/src/fairseq2/recipes/hg/evaluator.py index c624523f7..e11acae59 100644 --- a/src/fairseq2/recipes/hg/evaluator.py +++ b/src/fairseq2/recipes/hg/evaluator.py @@ -39,8 +39,8 @@ class HFEvaluator(Generic[BatchT]): """Evaluate a machine learning model with HuggingFace's evaluate.Metric library""" _model: Model - _preprocessor: Callable[[BatchT], tuple[SequenceBatch, SequenceBatch]] - _postprocessor: Callable[[Any, SequenceBatch], tuple[list[str], list[str]]] + _preprocessor: Callable[[BatchT], tuple[Any, Any]] + _postprocessor: Callable[[Any, Any], tuple[Any, Any]] _root_gang: Gang _dp_gang: Gang _tp_gang: Gang @@ -57,8 +57,8 @@ def __init__( gang: Gang, data_reader: DataReader[BatchT], wall_watch: Stopwatch, - preprocessor: Callable[[BatchT], tuple[SequenceBatch, SequenceBatch]], - postprocessor: Callable[[Any, SequenceBatch], tuple[list[str], list[str]]], + preprocessor: Callable[[BatchT], tuple[Any, Any]], + postprocessor: Callable[[Any, Any], tuple[Any, Any]], dp_gang: Gang | None = None, tp_gang: Gang | None = None, tb_dir: Path | None = None, From 721baa4cc32dfd5ad51f26256387f8e926e55a24 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 03:23:15 +0000 Subject: [PATCH 06/12] Introduce AsrDatasetConfig to handle different ASR datasets --- src/fairseq2/recipes/hg/asr_eval.py | 75 +++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 15d55a0ec..c9e4c8951 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -6,9 +6,9 @@ import itertools from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, cast +from typing import Any, List, Optional, Union, cast import torch from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] @@ -33,22 +33,64 @@ log = get_log_writer(__name__) +@dataclass +class AsrDatasetConfig: + """Configuration for an automatic speech recognition dataset.""" + + dataset_path: str + """The name of the dataset.""" + + dataset_name: Optional[str] = None + + source_column: List[str] = field(default_factory=list) + """The path of the column containing the source audio.""" + + target_column: List[str] = field(default_factory=list) + """The path of the column containing the target text.""" + + tokenizer_name: Optional[str] = None + """The name of the tokenizer to use.""" + + @classmethod + def from_dict(cls, config_dict: dict) -> 'AsrDatasetConfig': + """Create an AsrDatasetConfig instance from a configuration dictionary.""" + return cls( + dataset_path=config_dict.get('dataset_path', ''), + dataset_name=config_dict.get('dataset_name'), + source_column=config_dict.get('source_column', []), + target_column=config_dict.get('target_column', []), + tokenizer_name=config_dict.get('tokenizer_name') + ) + + def get_source_data(self, ds: dict) -> Union[list, dict]: + """Retrieve the source (audio) data from the dataset.""" + return self._get_data(ds, self.source_column) + + def get_target_data(self, ds: dict) -> Union[list, dict]: + """Retrieve the target (text) data from the dataset.""" + return self._get_data(ds, self.target_column) + + @staticmethod + def _get_data(ds: dict, path: List[str]) -> Union[list, dict]: + """Retrieve data from the dataset using the specified path.""" + current = ds + for key in path: + if key in 
current: + current = current[key] + else: + raise ValueError(f"Invalid path: {path}") + return current @dataclass(kw_only=True) class AsrEvalConfig: """Holds the configuration of a ASR evaluation recipe.""" - # Data - dataset_name: str + dataset_config: AsrDatasetConfig """The HF dataset to evaluate with.""" - # Model model_name: str """The name of the model to evaluate.""" - tokenizer_name: str = "librispeech_asr" - """The tokenizer to use.""" - split: str = "test" """The name of the dataset split to evaluate with.""" @@ -97,13 +139,18 @@ class EvalSeqBatch: @asr_eval_preset("default_asr") def _default_asr_config() -> AsrEvalConfig: return AsrEvalConfig( - dataset_name="librispeech_asr", + dataset_config=AsrDatasetConfig.from_dict({ + 'dataset_path': 'librispeech_asr', + 'source_column': ['audio', 'array'], + 'target_column': ['text'], + 'tokenizer_name': 'librispeech_asr', + }), model_name="wav2vec2_asr_base_10h", split="test.other", ) -def extract_features(example: Example) -> Example: +def extract_features(example: Example, dataset_config: AsrDatasetConfig) -> Example: """ Preprocesses an individual example by converting the audio array to a PyTorch tensor and encoding the text. @@ -115,7 +162,7 @@ def extract_features(example: Example) -> Example: Returns: dict: A dictionary with "audio" and "text" as PyTorch tensors. """ - return {"audio": example["audio"]["array"], "text": example["text"].lower()} + return {"audio": dataset_config.get_source_data(example), "text": dataset_config.get_target_data(example).lower()} def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch: @@ -150,12 +197,12 @@ def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch def prepare_dataset( config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None ) -> Dataset: - iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True) + iterable_ds = load_dataset(path=config.dataset_config.dataset_path, name=config.dataset_config.dataset_name, split=config.split, streaming=True) ds = Dataset.from_generator( lambda: itertools.islice(iterable_ds, 0, config.max_samples), features=iterable_ds.features, ) - ds = ds.map(lambda x: extract_features(x)) + ds = ds.map(lambda x: extract_features(x, config.dataset_config)) if processor is not None: ds = ds.map(processor) @@ -212,7 +259,7 @@ def load_asr_evaluator( else: init_device = META - tokenizer = load_text_tokenizer(config.tokenizer_name) + tokenizer = load_text_tokenizer(config.dataset_config.tokenizer_name) pipeline_reader = create_hf_reader( dataset=ds, From 1f6772c77e896361b43e0362a2b14e08d629ab10 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 03:53:37 +0000 Subject: [PATCH 07/12] lint and fix type hints --- src/fairseq2/recipes/hg/asr_eval.py | 73 ++++++++++++++++++---------- src/fairseq2/recipes/hg/evaluator.py | 1 - 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index c9e4c8951..4268e9475 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -7,8 +7,9 @@ import itertools from collections.abc import Callable from dataclasses import dataclass, field +from functools import partial from pathlib import Path -from typing import Any, List, Optional, Union, cast +from typing import Any, List, Optional, Union import torch from datasets import ( # type: 
ignore[attr-defined,import-untyped,import-not-found] @@ -17,7 +18,6 @@ ) from fairseq2.config_registry import ConfigRegistry -from fairseq2.data.data_pipeline import SequenceData from fairseq2.data.text import load_text_tokenizer from fairseq2.data.text.text_tokenizer import TextTokenizer from fairseq2.datasets.batching import StaticBatching @@ -33,6 +33,7 @@ log = get_log_writer(__name__) + @dataclass class AsrDatasetConfig: """Configuration for an automatic speech recognition dataset.""" @@ -52,26 +53,32 @@ class AsrDatasetConfig: """The name of the tokenizer to use.""" @classmethod - def from_dict(cls, config_dict: dict) -> 'AsrDatasetConfig': + def from_dict(cls, config_dict: dict[str, Any]) -> "AsrDatasetConfig": """Create an AsrDatasetConfig instance from a configuration dictionary.""" return cls( - dataset_path=config_dict.get('dataset_path', ''), - dataset_name=config_dict.get('dataset_name'), - source_column=config_dict.get('source_column', []), - target_column=config_dict.get('target_column', []), - tokenizer_name=config_dict.get('tokenizer_name') + dataset_path=config_dict.get("dataset_path", ""), + dataset_name=config_dict.get("dataset_name"), + source_column=config_dict.get("source_column", []), + target_column=config_dict.get("target_column", []), + tokenizer_name=config_dict.get("tokenizer_name"), ) - def get_source_data(self, ds: dict) -> Union[list, dict]: + def get_source_data(self, ds: Example) -> List[int]: """Retrieve the source (audio) data from the dataset.""" - return self._get_data(ds, self.source_column) + results = self._get_data(ds, self.source_column) + if not isinstance(results, list): + raise ValueError(f"Invalid source data: {results}") + return results - def get_target_data(self, ds: dict) -> Union[list, dict]: + def get_target_data(self, ds: Example) -> str: """Retrieve the target (text) data from the dataset.""" - return self._get_data(ds, self.target_column) + results = self._get_data(ds, self.target_column) + if not isinstance(results, str): + raise ValueError(f"Invalid target data: {results}") + return results @staticmethod - def _get_data(ds: dict, path: List[str]) -> Union[list, dict]: + def _get_data(ds: Example, path: List[str]) -> Union[Example, List[int], str]: """Retrieve data from the dataset using the specified path.""" current = ds for key in path: @@ -81,6 +88,7 @@ def _get_data(ds: dict, path: List[str]) -> Union[list, dict]: raise ValueError(f"Invalid path: {path}") return current + @dataclass(kw_only=True) class AsrEvalConfig: """Holds the configuration of a ASR evaluation recipe.""" @@ -139,12 +147,14 @@ class EvalSeqBatch: @asr_eval_preset("default_asr") def _default_asr_config() -> AsrEvalConfig: return AsrEvalConfig( - dataset_config=AsrDatasetConfig.from_dict({ - 'dataset_path': 'librispeech_asr', - 'source_column': ['audio', 'array'], - 'target_column': ['text'], - 'tokenizer_name': 'librispeech_asr', - }), + dataset_config=AsrDatasetConfig.from_dict( + { + "dataset_path": "librispeech_asr", + "source_column": ["audio", "array"], + "target_column": ["text"], + "tokenizer_name": "librispeech_asr", + } + ), model_name="wav2vec2_asr_base_10h", split="test.other", ) @@ -162,7 +172,10 @@ def extract_features(example: Example, dataset_config: AsrDatasetConfig) -> Exam Returns: dict: A dictionary with "audio" and "text" as PyTorch tensors. 
""" - return {"audio": dataset_config.get_source_data(example), "text": dataset_config.get_target_data(example).lower()} + return { + "audio": dataset_config.get_source_data(example), + "text": dataset_config.get_target_data(example).lower(), + } def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch: @@ -185,7 +198,7 @@ def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch else: raise ValueError(f"Unknown model type: {model_type}") - target_seqs = examples['text'] + target_seqs = examples["text"] return EvalSeqBatch( source_seqs, @@ -197,7 +210,12 @@ def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch def prepare_dataset( config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None ) -> Dataset: - iterable_ds = load_dataset(path=config.dataset_config.dataset_path, name=config.dataset_config.dataset_name, split=config.split, streaming=True) + iterable_ds = load_dataset( + path=config.dataset_config.dataset_path, + name=config.dataset_config.dataset_name, + split=config.split, + streaming=True, + ) ds = Dataset.from_generator( lambda: itertools.islice(iterable_ds, 0, config.max_samples), features=iterable_ds.features, @@ -217,7 +235,10 @@ def prepare_dataset( def evaluator_preprocessor(batch: EvalSeqBatch) -> tuple[SequenceBatch, List[str]]: - return SequenceBatch(batch.source_seqs, batch.source_padding_mask), batch.target_seqs + return ( + SequenceBatch(batch.source_seqs, batch.source_padding_mask), + batch.target_seqs, + ) def evaluator_postprocesser( @@ -259,12 +280,14 @@ def load_asr_evaluator( else: init_device = META + if config.dataset_config.tokenizer_name is None: + raise ValueError("Tokenizer name is not provided but required.") tokenizer = load_text_tokenizer(config.dataset_config.tokenizer_name) pipeline_reader = create_hf_reader( dataset=ds, gang=gang, - converter=lambda x: to_batch(x, "wav2vec2", init_device), + converter=partial(to_batch, model_type="wav2vec2", device=init_device), batching=StaticBatching(config.max_num_elements), num_prefetch=config.num_prefetch, pad_value=tokenizer.vocab_info.pad_idx, @@ -284,5 +307,5 @@ def load_asr_evaluator( data_reader=pipeline_reader, wall_watch=wall_watch, preprocessor=evaluator_preprocessor, - postprocessor=lambda x, y: evaluator_postprocesser(x, y, tokenizer), + postprocessor=partial(evaluator_postprocesser, tokenizer=tokenizer), ) diff --git a/src/fairseq2/recipes/hg/evaluator.py b/src/fairseq2/recipes/hg/evaluator.py index e11acae59..7b234f36f 100644 --- a/src/fairseq2/recipes/hg/evaluator.py +++ b/src/fairseq2/recipes/hg/evaluator.py @@ -24,7 +24,6 @@ record_metrics, ) from fairseq2.models.model import Model -from fairseq2.models.sequence import SequenceBatch from fairseq2.recipes.utils.cli import create_rich_progress from fairseq2.utils.profiler import Stopwatch From 45e65710996ecd020cf3d897e7fbae0a0c13a3f9 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 04:24:32 +0000 Subject: [PATCH 08/12] move split to AsrDatasetConfig and move tokenizer back to AsrEvalConfig --- src/fairseq2/recipes/hg/asr_eval.py | 41 +++++++++++++---------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 4268e9475..f9b8ee9d8 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -42,6 +42,10 @@ class AsrDatasetConfig: """The name of the dataset.""" 
dataset_name: Optional[str] = None + """The name of the dataset split.""" + + split: str = "test" + """The name of the dataset split to evaluate with.""" source_column: List[str] = field(default_factory=list) """The path of the column containing the source audio.""" @@ -49,9 +53,6 @@ class AsrDatasetConfig: target_column: List[str] = field(default_factory=list) """The path of the column containing the target text.""" - tokenizer_name: Optional[str] = None - """The name of the tokenizer to use.""" - @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "AsrDatasetConfig": """Create an AsrDatasetConfig instance from a configuration dictionary.""" @@ -60,21 +61,17 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "AsrDatasetConfig": dataset_name=config_dict.get("dataset_name"), source_column=config_dict.get("source_column", []), target_column=config_dict.get("target_column", []), - tokenizer_name=config_dict.get("tokenizer_name"), + split=config_dict.get("split", "test"), ) - def get_source_data(self, ds: Example) -> List[int]: + def get_source_data(self, ds: Example) -> Any: """Retrieve the source (audio) data from the dataset.""" results = self._get_data(ds, self.source_column) - if not isinstance(results, list): - raise ValueError(f"Invalid source data: {results}") return results - def get_target_data(self, ds: Example) -> str: + def get_target_data(self, ds: Example) -> Any: """Retrieve the target (text) data from the dataset.""" results = self._get_data(ds, self.target_column) - if not isinstance(results, str): - raise ValueError(f"Invalid target data: {results}") return results @staticmethod @@ -99,8 +96,8 @@ class AsrEvalConfig: model_name: str """The name of the model to evaluate.""" - split: str = "test" - """The name of the dataset split to evaluate with.""" + tokenizer_name: Optional[str] = None + """The name of the tokenizer to use.""" min_audio_len: int = 1 """The minimum audio sequence length.""" @@ -147,16 +144,14 @@ class EvalSeqBatch: @asr_eval_preset("default_asr") def _default_asr_config() -> AsrEvalConfig: return AsrEvalConfig( - dataset_config=AsrDatasetConfig.from_dict( - { - "dataset_path": "librispeech_asr", - "source_column": ["audio", "array"], - "target_column": ["text"], - "tokenizer_name": "librispeech_asr", - } + dataset_config=AsrDatasetConfig( + dataset_path="librispeech_asr", + source_column=["audio", "array"], + target_column=["text"], + split="test", ), model_name="wav2vec2_asr_base_10h", - split="test.other", + tokenizer_name="librispeech_asr", ) @@ -213,7 +208,7 @@ def prepare_dataset( iterable_ds = load_dataset( path=config.dataset_config.dataset_path, name=config.dataset_config.dataset_name, - split=config.split, + split=config.dataset_config.split, streaming=True, ) ds = Dataset.from_generator( @@ -280,9 +275,9 @@ def load_asr_evaluator( else: init_device = META - if config.dataset_config.tokenizer_name is None: + if config.tokenizer_name is None: raise ValueError("Tokenizer name is not provided but required.") - tokenizer = load_text_tokenizer(config.dataset_config.tokenizer_name) + tokenizer = load_text_tokenizer(config.tokenizer_name) pipeline_reader = create_hf_reader( dataset=ds, From 7fb74ed9d8c5cf870250529d3e29cdb79931b2e1 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Thu, 15 Aug 2024 04:32:31 +0000 Subject: [PATCH 09/12] update default split --- src/fairseq2/recipes/hg/asr_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py 
b/src/fairseq2/recipes/hg/asr_eval.py index f9b8ee9d8..1077da02e 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -148,7 +148,7 @@ def _default_asr_config() -> AsrEvalConfig: dataset_path="librispeech_asr", source_column=["audio", "array"], target_column=["text"], - split="test", + split="test.other", ), model_name="wav2vec2_asr_base_10h", tokenizer_name="librispeech_asr", From dc684f03a6b91f858d8a5f106e30013cdcee7d79 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:57:44 +0000 Subject: [PATCH 10/12] Refactor AsrDatasetConfig --- src/fairseq2/recipes/hg/asr_eval.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 1077da02e..f6b7d2846 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from functools import partial from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import torch from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found] @@ -39,30 +39,19 @@ class AsrDatasetConfig: """Configuration for an automatic speech recognition dataset.""" dataset_path: str - """The name of the dataset.""" + """The path to the dataset.""" dataset_name: Optional[str] = None - """The name of the dataset split.""" + """The name of the dataset configuration.""" split: str = "test" - """The name of the dataset split to evaluate with.""" + """Which split of the data to load.""" source_column: List[str] = field(default_factory=list) - """The path of the column containing the source audio.""" + """The path to the column containing the source audio.""" target_column: List[str] = field(default_factory=list) - """The path of the column containing the target text.""" - - @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> "AsrDatasetConfig": - """Create an AsrDatasetConfig instance from a configuration dictionary.""" - return cls( - dataset_path=config_dict.get("dataset_path", ""), - dataset_name=config_dict.get("dataset_name"), - source_column=config_dict.get("source_column", []), - target_column=config_dict.get("target_column", []), - split=config_dict.get("split", "test"), - ) + """The path to the column containing the target text.""" def get_source_data(self, ds: Example) -> Any: """Retrieve the source (audio) data from the dataset.""" @@ -75,7 +64,7 @@ def get_target_data(self, ds: Example) -> Any: return results @staticmethod - def _get_data(ds: Example, path: List[str]) -> Union[Example, List[int], str]: + def _get_data(ds: Example, path: List[str]) -> Example | List[int] | str: """Retrieve data from the dataset using the specified path.""" current = ds for key in path: From 667194fd968c82f46ff43e85aaf57f13127e3be7 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:13:10 +0000 Subject: [PATCH 11/12] refactor code and lint --- src/fairseq2/recipes/hg/asr_eval.py | 48 ++++++++++++----------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index f6b7d2846..1fee7edb0 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found 
in the # LICENSE file in the root directory of this source tree. -import itertools from collections.abc import Callable from dataclasses import dataclass, field from functools import partial @@ -47,32 +46,22 @@ class AsrDatasetConfig: split: str = "test" """Which split of the data to load.""" - source_column: List[str] = field(default_factory=list) - """The path to the column containing the source audio.""" + source_column_path: List[str] = field(default_factory=list) + """The path of the column containing the source audio.""" - target_column: List[str] = field(default_factory=list) - """The path to the column containing the target text.""" + target_column_path: List[str] = field(default_factory=list) + """The path of the column containing the target text.""" - def get_source_data(self, ds: Example) -> Any: - """Retrieve the source (audio) data from the dataset.""" - results = self._get_data(ds, self.source_column) - return results - def get_target_data(self, ds: Example) -> Any: - """Retrieve the target (text) data from the dataset.""" - results = self._get_data(ds, self.target_column) - return results - - @staticmethod - def _get_data(ds: Example, path: List[str]) -> Example | List[int] | str: - """Retrieve data from the dataset using the specified path.""" - current = ds - for key in path: - if key in current: - current = current[key] - else: - raise ValueError(f"Invalid path: {path}") - return current +def _get_column_data(ds: Example, path: List[str]) -> Union[Example, List[int], str]: + """Retrieve data from the dataset using the specified path.""" + current = ds + for key in path: + if key in current: + current = current[key] + else: + raise ValueError(f"Invalid path: {path}") + return current @dataclass(kw_only=True) @@ -135,8 +124,8 @@ def _default_asr_config() -> AsrEvalConfig: return AsrEvalConfig( dataset_config=AsrDatasetConfig( dataset_path="librispeech_asr", - source_column=["audio", "array"], - target_column=["text"], + source_column_path=["audio", "array"], + target_column_path=["text"], split="test.other", ), model_name="wav2vec2_asr_base_10h", @@ -157,8 +146,8 @@ def extract_features(example: Example, dataset_config: AsrDatasetConfig) -> Exam dict: A dictionary with "audio" and "text" as PyTorch tensors. 
""" return { - "audio": dataset_config.get_source_data(example), - "text": dataset_config.get_target_data(example).lower(), + "audio": _get_column_data(example, dataset_config.source_column_path), + "text": _get_column_data(example, dataset_config.target_column_path).lower(), } @@ -199,9 +188,10 @@ def prepare_dataset( name=config.dataset_config.dataset_name, split=config.dataset_config.split, streaming=True, + trust_remote_code=True, ) ds = Dataset.from_generator( - lambda: itertools.islice(iterable_ds, 0, config.max_samples), + lambda: iterable_ds.take(config.max_samples), features=iterable_ds.features, ) ds = ds.map(lambda x: extract_features(x, config.dataset_config)) From 5fc5fed77430fefea7d0f9ed0b6d5b9cc9c5fe70 Mon Sep 17 00:00:00 2001 From: Ahmed Saed <37080003+Ahmedsaed@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:24:47 +0000 Subject: [PATCH 12/12] lint --- src/fairseq2/recipes/hg/asr_eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/hg/asr_eval.py b/src/fairseq2/recipes/hg/asr_eval.py index 1fee7edb0..3a119028a 100644 --- a/src/fairseq2/recipes/hg/asr_eval.py +++ b/src/fairseq2/recipes/hg/asr_eval.py @@ -53,7 +53,7 @@ class AsrDatasetConfig: """The path of the column containing the target text.""" -def _get_column_data(ds: Example, path: List[str]) -> Union[Example, List[int], str]: +def _get_column_data(ds: Dataset, path: List[str]) -> Any: """Retrieve data from the dataset using the specified path.""" current = ds for key in path: @@ -145,6 +145,7 @@ def extract_features(example: Example, dataset_config: AsrDatasetConfig) -> Exam Returns: dict: A dictionary with "audio" and "text" as PyTorch tensors. """ + return { "audio": _get_column_data(example, dataset_config.source_column_path), "text": _get_column_data(example, dataset_config.target_column_path).lower(),