Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for huggingface ASR datasets #749

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions src/fairseq2/recipes/hg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,16 @@ def _setup_hg_cli(cli: Cli) -> None:

group = cli.add_group("hg", help="Hugging Face recipes")

from fairseq2.recipes.hg.asr_eval import (
asr_eval_presets,
load_wav2vec2_asr_evaluator,
)
from fairseq2.recipes.hg.asr_eval import asr_eval_presets, load_asr_evaluator

handler = RecipeCommandHandler(
load_wav2vec2_asr_evaluator,
load_asr_evaluator,
preset_configs=asr_eval_presets,
default_preset="librispeech_asr",
default_preset="default_asr",
)

group.add_command(
"wav2vec2_asr",
"asr",
handler,
help="evaluate a wav2vec 2.0 ASR model on a downstream benchmark (default: librispeech_asr)",
help="evaluate an ASR model (default: wav2vec2) on a downstream benchmark (default: librispeech_asr)",
)
219 changes: 134 additions & 85 deletions src/fairseq2/recipes/hg/asr_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from dataclasses import dataclass
from functools import lru_cache
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, cast
from typing import Any, List, Optional

import torch
from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found]
Expand All @@ -17,44 +17,65 @@
)

from fairseq2.config_registry import ConfigRegistry
from fairseq2.data.data_pipeline import SequenceData
from fairseq2.data.text import load_text_tokenizer
from fairseq2.data.text.text_tokenizer import TextTokenizer
from fairseq2.datasets.batching import StaticBatching
from fairseq2.logging import get_log_writer
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.sequence import SequenceBatch
from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model
from fairseq2.nn.padding import get_seqs_and_padding_mask
from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
from fairseq2.recipes.hg.dataset import Example, create_hf_reader
from fairseq2.recipes.hg.evaluator import HFEvaluator
from fairseq2.recipes.utils.setup import setup_root_gang
from fairseq2.typing import META, DataType
from fairseq2.typing import META, DataType, Device
from fairseq2.utils.profiler import Stopwatch

log = get_log_writer(__name__)


@dataclass
class AsrDatasetConfig:
"""Configuration for an automatic speech recognition dataset."""

dataset_path: str
Ahmedsaed marked this conversation as resolved.
Show resolved Hide resolved
"""The path to the dataset."""

dataset_name: Optional[str] = None
"""The name of the dataset configuration."""

split: str = "test"
"""Which split of the data to load."""

source_column_path: List[str] = field(default_factory=list)
"""The path of the column containing the source audio."""

target_column_path: List[str] = field(default_factory=list)
"""The path of the column containing the target text."""


def _get_column_data(ds: Dataset, path: List[str]) -> Any:
"""Retrieve data from the dataset using the specified path."""
current = ds
for key in path:
if key in current:
current = current[key]
else:
raise ValueError(f"Invalid path: {path}")
return current


@dataclass(kw_only=True)
class AsrEvalConfig:
"""Holds the configuration of a ASR evaluation recipe."""

# Data
dataset_name: str
dataset_config: AsrDatasetConfig
"""The HF dataset to evaluate with."""

# Model
model_name: str
"""The name of the model to evaluate."""

# converter: Callable[[Example], Seq2SeqBatch]
# """The converter function to convert collated data into Seq2SeqBatch"""

tokenizer_name: str = "librispeech_asr"
"""The tokenizer to use."""

split: str = "test"
"""The name of the dataset split to evaluate with."""
tokenizer_name: Optional[str] = None
"""The name of the tokenizer to use."""

min_audio_len: int = 1
"""The minimum audio sequence length."""
Expand Down Expand Up @@ -82,97 +103,135 @@ class AsrEvalConfig:
"""The data type of the model."""


asr_eval_presets = ConfigRegistry[AsrEvalConfig]()
@dataclass
class EvalSeqBatch:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for not using Seq2SeqBatch here like the native ASR evaluator here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Seq2SeqBatch class expects targets to be tensors, which is useful for training. However, during evaluation, encoding the target only to decode it again without utilizing the tensors seems redundant.

Is there a better approach? another fairseq2 data-structure to use?

source_seqs: torch.Tensor
"""The source sequences."""

source_padding_mask: Optional[PaddingMask]
"""The source padding mask."""

target_seqs: List[str]
"""The target sequences."""


asr_eval_presets = ConfigRegistry[AsrEvalConfig]()
asr_eval_preset = asr_eval_presets.decorator


@asr_eval_preset("librispeech_asr")
def _librispeech_asr_config() -> AsrEvalConfig:
@asr_eval_preset("default_asr")
def _default_asr_config() -> AsrEvalConfig:
return AsrEvalConfig(
dataset_name="librispeech_asr",
dataset_config=AsrDatasetConfig(
dataset_path="librispeech_asr",
source_column_path=["audio", "array"],
target_column_path=["text"],
split="test.other",
),
model_name="wav2vec2_asr_base_10h",
split="test.other",
# converter=librispeech_asr_to_batch,
tokenizer_name="librispeech_asr",
)


def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch:
def extract_features(example: Example, dataset_config: AsrDatasetConfig) -> Example:
"""
Preprocesses an individual example by converting the audio array to a PyTorch tensor
and encoding the text.

Args:
example (dict): A dictionary containing "audio" and "text" keys.
device (torch.device): The device to store the tensors.

Returns:
dict: A dictionary with "audio" and "text" as PyTorch tensors.
"""

return {
"audio": _get_column_data(example, dataset_config.source_column_path),
"text": _get_column_data(example, dataset_config.target_column_path).lower(),
}


def to_batch(examples: Example, model_type: str, device: Device) -> EvalSeqBatch:
"""
Converts a collated batch of examples into a Seq2SeqBatch.
Converts a collated batch of examples into a EvalSeqBatch.

Args:
examples (dict): A dictionary containing "audio" and "text" keys.

Returns:
Seq2SeqBatch: A batch of audio and text sequences.
EvalSeqBatch: A batch of audio and text sequences.
"""
source_data = cast(SequenceData, examples["audio"])
target_data = cast(SequenceData, examples["text"])

source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data)
target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data)
if model_type == "wav2vec2":
Ahmedsaed marked this conversation as resolved.
Show resolved Hide resolved
source_seqs, source_padding_mask = get_seqs_and_padding_mask(examples["audio"])
source_seqs = source_seqs.to(device)
source_padding_mask = (
source_padding_mask.to(device) if source_padding_mask is not None else None
)
else:
raise ValueError(f"Unknown model type: {model_type}")

target_seqs = examples["text"]

return Seq2SeqBatch(
return EvalSeqBatch(
source_seqs,
source_padding_mask,
target_seqs,
target_padding_mask,
examples,
)


@lru_cache(maxsize=None)
def get_cached_tokenizer(tokenizer_name: str) -> TextTokenizer:
return load_text_tokenizer(tokenizer_name)

def prepare_dataset(
config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None
) -> Dataset:
iterable_ds = load_dataset(
path=config.dataset_config.dataset_path,
name=config.dataset_config.dataset_name,
split=config.dataset_config.split,
streaming=True,
trust_remote_code=True,
)
ds = Dataset.from_generator(
lambda: iterable_ds.take(config.max_samples),
features=iterable_ds.features,
)
ds = ds.map(lambda x: extract_features(x, config.dataset_config))

def _preprocess_example(
example: Example, tokenizer_name: str, device: torch.device
) -> Example:
"""
Preprocesses an individual example by converting the audio array to a PyTorch tensor
and encoding the text.
if processor is not None:
ds = ds.map(processor)

Args:
example (dict): A dictionary containing "audio" and "text" keys.
tokenizer_name (str): The name of the tokenizer to use.
device (torch.device): The device to store the tensors.
format = {
"type": "torch",
"format_kwargs": {"dtype": config.dtype},
}
ds.set_format(**format, columns=["audio", "text"])

Returns:
dict: A dictionary with "audio" and "text" as PyTorch tensors.
"""
tokenizer = get_cached_tokenizer(tokenizer_name)
encoder = tokenizer.create_encoder(device=device)
audio_tensor = (
torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device)
)
text_tensor = encoder(example["text"].lower()).to(device)
return {"audio": audio_tensor, "text": text_tensor}
return ds


def seq2seq_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]:
return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch(
batch.target_seqs, batch.target_padding_mask
def evaluator_preprocessor(batch: EvalSeqBatch) -> tuple[SequenceBatch, List[str]]:
return (
SequenceBatch(batch.source_seqs, batch.source_padding_mask),
batch.target_seqs,
)


def postprocesser(
outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer
def evaluator_postprocesser(
outputs: Any, targets: list[str], tokenizer: TextTokenizer
) -> tuple[list[str], list[str]]:
decoder = tokenizer.create_decoder()
pad_idx = tokenizer.vocab_info.pad_idx

hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx)
predictions = [decoder(item) for item in hypotheses]
references = [decoder(item) for item in targets.seqs.to(torch.int32)]
references = targets

return predictions, references


def load_wav2vec2_asr_evaluator(
def load_asr_evaluator(
config: AsrEvalConfig, output_dir: Path
) -> HFEvaluator[Seq2SeqBatch]:
) -> HFEvaluator[EvalSeqBatch]:
"""
Load the evaluator used for downstream evaluation of the model
in a downstream dataset and report BLEU scores
Expand All @@ -187,12 +246,7 @@ def load_wav2vec2_asr_evaluator(
if not isinstance(config, AsrEvalConfig):
raise ValueError(f"Expect AsrEvalConfig, get {type(config)}")

iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True)
# Load a subset of the dataset if max_samples is set
ds = Dataset.from_generator(
lambda: itertools.islice(iterable_ds, 0, config.max_samples),
features=iterable_ds.features,
)
ds = prepare_dataset(config)

gang = setup_root_gang(log)

Expand All @@ -201,19 +255,14 @@ def load_wav2vec2_asr_evaluator(
else:
init_device = META

ds = ds.map(lambda x: _preprocess_example(x, config.tokenizer_name, init_device))
format = {
"type": "torch",
"format_kwargs": {"dtype": torch.float16, "device": init_device},
}
ds.set_format(**format, columns=["audio", "text"])

tokenizer = get_cached_tokenizer(config.tokenizer_name)
if config.tokenizer_name is None:
raise ValueError("Tokenizer name is not provided but required.")
tokenizer = load_text_tokenizer(config.tokenizer_name)

pipeline_reader = create_hf_reader(
dataset=ds,
gang=gang,
converter=_librispeech_asr_to_batch,
converter=partial(to_batch, model_type="wav2vec2", device=init_device),
batching=StaticBatching(config.max_num_elements),
num_prefetch=config.num_prefetch,
pad_value=tokenizer.vocab_info.pad_idx,
Expand All @@ -226,12 +275,12 @@ def load_wav2vec2_asr_evaluator(

wall_watch = Stopwatch(start=True, device=init_device)

return HFEvaluator[Seq2SeqBatch](
return HFEvaluator[EvalSeqBatch](
model=model,
metrics=["bleu"],
gang=gang,
data_reader=pipeline_reader,
wall_watch=wall_watch,
preprocessor=seq2seq_preprocessor,
postprocessor=lambda x, y: postprocesser(x, y, tokenizer),
preprocessor=evaluator_preprocessor,
postprocessor=partial(evaluator_postprocesser, tokenizer=tokenizer),
)
Loading
Loading