
Integrate Whisper model with hg evaluation CLI interface #740

Closed
wants to merge 10 commits
13 changes: 5 additions & 8 deletions src/fairseq2/recipes/hg/__init__.py
@@ -31,19 +31,16 @@ def _setup_hg_cli(cli: Cli) -> None:
 
     group = cli.add_group("hg", help="Hugging Face recipes")
 
-    from fairseq2.recipes.hg.asr_eval import (
-        asr_eval_presets,
-        load_wav2vec2_asr_evaluator,
-    )
+    from fairseq2.recipes.hg.asr_eval import asr_eval_presets, load_asr_evaluator
 
     handler = RecipeCommandHandler(
-        load_wav2vec2_asr_evaluator,
+        load_asr_evaluator,
         preset_configs=asr_eval_presets,
-        default_preset="librispeech_asr",
+        default_preset="default_asr",
     )
 
     group.add_command(
-        "wav2vec2_asr",
+        "asr",
         handler,
-        help="evaluate a wav2vec 2.0 ASR model on a downstream benchmark (default: librispeech_asr)",
+        help="evaluate an ASR model (default: wav2vec2) on a downstream benchmark (default: librispeech_asr)",
     )
226 changes: 169 additions & 57 deletions src/fairseq2/recipes/hg/asr_eval.py
@@ -5,31 +5,39 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
+from collections.abc import Callable
 from dataclasses import dataclass
-from functools import lru_cache
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Optional, cast
 
 import torch
 from datasets import (  # type: ignore[attr-defined,import-untyped,import-not-found]
     Dataset,
     load_dataset,
+    load_dataset_builder,
 )
+from transformers import (  # type: ignore[attr-defined,import-untyped,import-not-found]
+    WhisperForConditionalGeneration,
+    WhisperProcessor,
+)
 
+from fairseq2.assets.metadata_provider import AssetNotFoundError
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.data.data_pipeline import SequenceData
 from fairseq2.data.text import load_text_tokenizer
 from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.datasets.batching import StaticBatching
 from fairseq2.logging import get_log_writer
+from fairseq2.models.model import Model
 from fairseq2.models.seq2seq import Seq2SeqBatch
 from fairseq2.models.sequence import SequenceBatch
 from fairseq2.models.wav2vec2.asr import load_wav2vec2_asr_model
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.recipes.hg.dataset import Example, create_hf_reader
 from fairseq2.recipes.hg.evaluator import HFEvaluator
+from fairseq2.recipes.utils.asset import retrieve_asset_card
 from fairseq2.recipes.utils.setup import setup_root_gang
-from fairseq2.typing import META, DataType
+from fairseq2.typing import META, DataType, Device
 from fairseq2.utils.profiler import Stopwatch
 
 log = get_log_writer(__name__)
@@ -47,9 +55,6 @@ class AsrEvalConfig:
     model_name: str
     """The name of the model to evaluate."""
 
-    # converter: Callable[[Example], Seq2SeqBatch]
-    # """The converter function to convert collated data into Seq2SeqBatch"""
-
     tokenizer_name: str = "librispeech_asr"
     """The tokenizer to use."""
 
@@ -83,21 +88,32 @@ class AsrEvalConfig:
 
 
 asr_eval_presets = ConfigRegistry[AsrEvalConfig]()
 
 asr_eval_preset = asr_eval_presets.decorator
 
 
-@asr_eval_preset("librispeech_asr")
-def _librispeech_asr_config() -> AsrEvalConfig:
+@asr_eval_preset("default_asr")
+def _default_asr_config() -> AsrEvalConfig:
     return AsrEvalConfig(
         dataset_name="librispeech_asr",
         model_name="wav2vec2_asr_base_10h",
         split="test.other",
-        # converter=librispeech_asr_to_batch,
     )
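
Presets are one small decorated function each, so additional models can be pre-registered the same way. A sketch of what a Whisper preset could look like (the preset name and HF model id are illustrative, not part of this PR):

@asr_eval_preset("whisper_librispeech")  # hypothetical preset name
def _whisper_asr_config() -> AsrEvalConfig:
    return AsrEvalConfig(
        dataset_name="librispeech_asr",
        model_name="openai/whisper-tiny",  # no fairseq2 asset card, so this falls through to the HF path
        split="test.other",
    )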


-def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch:
+def extract_features(example: Example) -> Example:
+    """
+    Extracts source and target features from dataset examples.
+
+    Args:
+        example (dict): A dictionary containing source and target keys.
+
+    Returns:
+        dict: A dictionary containing the extracted features.
+    """
+    return {"audio": example["audio"]["array"], "text": example["text"].lower()}


+def to_batch(examples: Example, model_type: str, device: Device) -> Seq2SeqBatch:
     """
     Converts a collated batch of examples into a Seq2SeqBatch.
 
@@ -108,10 +124,22 @@ def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch:
         Seq2SeqBatch: A batch of audio and text sequences.
     """
     source_data = cast(SequenceData, examples["audio"])
-    target_data = cast(SequenceData, examples["text"])
+    target_data = cast(torch.Tensor, examples["text"])
 
+    if model_type == "wav2vec2":
+        source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data)
+        source_seqs = source_seqs.to(device)
+        source_padding_mask = (
+            source_padding_mask.to(device) if source_padding_mask is not None else None
+        )
+    elif model_type == "whisper":
+        source_seqs = cast(torch.Tensor, source_data).to(device)
+        source_padding_mask = None
+    else:
+        raise ValueError(f"Unknown model type: {model_type}")
+
-    source_seqs, source_padding_mask = get_seqs_and_padding_mask(source_data)
-    target_seqs, target_padding_mask = get_seqs_and_padding_mask(target_data)
+    target_seqs = target_data
+    target_padding_mask = None
 
     return Seq2SeqBatch(
         source_seqs,
@@ -122,77 +150,95 @@ def _librispeech_asr_to_batch(examples: Example) -> Seq2SeqBatch:
     )
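
The branch above mirrors how each model consumes audio: the wav2vec2 path keeps ragged waveforms plus a padding mask, while the whisper path assumes the processor already padded every clip to fixed-size features, so the mask is None. A rough usage sketch (collated_examples is illustrative, not a name from this PR):

# With model_type="whisper", collation yields a dense feature tensor under "audio",
# so to_batch only moves it to the device and leaves source_padding_mask as None.
batch = to_batch(collated_examples, model_type="whisper", device=torch.device("cpu"))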


-@lru_cache(maxsize=None)
-def get_cached_tokenizer(tokenizer_name: str) -> TextTokenizer:
-    return load_text_tokenizer(tokenizer_name)
-
-
-def _preprocess_example(
-    example: Example, tokenizer_name: str, device: torch.device
-) -> Example:
+def prepare_dataset(
+    config: AsrEvalConfig, processor: Optional[Callable[[Example], Example]] = None
+) -> Dataset:
     """
-    Preprocesses an individual example by converting the audio array to a PyTorch tensor
-    and encoding the text.
+    Prepares a dataset for evaluation. The dataset is loaded from the
+    HF datasets and preprocessed using the provided processor.
 
     Args:
-        example (dict): A dictionary containing "audio" and "text" keys.
-        tokenizer_name (str): The name of the tokenizer to use.
-        device (torch.device): The device to store the tensors.
+        config (AsrEvalConfig): The configuration for the evaluation.
+        processor (Callable): A function to preprocess examples.
 
     Returns:
-        dict: A dictionary with "audio" and "text" as PyTorch tensors.
+        Dataset: The prepared dataset.
     """
-    tokenizer = get_cached_tokenizer(tokenizer_name)
-    encoder = tokenizer.create_encoder(device=device)
-    audio_tensor = (
-        torch.from_numpy(example["audio"]["array"]).to(torch.float16).to(device)
+    iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True)
+    ds = Dataset.from_generator(
+        lambda: itertools.islice(iterable_ds, 0, config.max_samples),
+        features=iterable_ds.features,
     )
-    text_tensor = encoder(example["text"].lower()).to(device)
-    return {"audio": audio_tensor, "text": text_tensor}
+    ds = ds.map(lambda x: extract_features(x))
+
+    if processor is not None:
+        ds = ds.map(processor)
+
+    format = {
+        "type": "torch",
+        "format_kwargs": {"dtype": config.dtype},
+    }
+    ds.set_format(**format, columns=["audio", "text"])
+
+    return ds
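
A minimal usage sketch, assuming the default preset; streaming plus islice keeps at most config.max_samples examples in memory:

config = _default_asr_config()
ds = prepare_dataset(config)
sample = ds[0]
# sample["audio"] is a torch tensor in config.dtype; sample["text"] is the lowercased transcript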


-def seq2seq_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]:
+def evaluator_preprocessor(batch: Seq2SeqBatch) -> tuple[SequenceBatch, SequenceBatch]:
     return SequenceBatch(batch.source_seqs, batch.source_padding_mask), SequenceBatch(
         batch.target_seqs, batch.target_padding_mask
     )
 
 
-def postprocesser(
+def evaluator_postprocesser(
     outputs: Any, targets: SequenceBatch, tokenizer: TextTokenizer
 ) -> tuple[list[str], list[str]]:
     decoder = tokenizer.create_decoder()
     pad_idx = tokenizer.vocab_info.pad_idx
 
     hypotheses, _ = outputs.generate_hypotheses(pad_idx=pad_idx)
     predictions = [decoder(item) for item in hypotheses]
-    references = [decoder(item) for item in targets.seqs.to(torch.int32)]
+    references = cast(list[str], targets.seqs)

     return predictions, references
 
 
+def load_asr_evaluator(
+    config: AsrEvalConfig, output_dir: Path
+) -> HFEvaluator[Seq2SeqBatch]:
+    """
+    Load the evaluator based on the model type.
+
+    Args:
+        config (AsrEvalConfig): The configuration for the evaluation.
+        output_dir (Path): The output directory to store the evaluation results.
+
+    Returns:
+        HFEvaluator: Evaluation process.
+    """
+    try:
+        retrieve_asset_card(config.model_name)
+        return load_wav2vec2_asr_evaluator(config, output_dir)
+    except AssetNotFoundError:
+        return load_hg_asr_evaluator(config, output_dir)
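
The dispatch hinges on whether the model name resolves to a fairseq2 asset card; anything else, such as a HF Hub id, falls through to the transformers path. For example (model id illustrative):

from dataclasses import replace

cfg = replace(_default_asr_config(), model_name="openai/whisper-tiny")
evaluator = load_asr_evaluator(cfg, Path("/tmp/asr_eval"))  # no asset card -> load_hg_asr_evaluator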


 def load_wav2vec2_asr_evaluator(
     config: AsrEvalConfig, output_dir: Path
 ) -> HFEvaluator[Seq2SeqBatch]:
     """
-    Load the evaluator used for downstream evaluation of the model
-    in a downstream dataset and report BLEU scores
+    Load the evaluator used for downstream evaluation of the wav2vec2 model.
 
     Args:
         config (HFEvalConfig): The configuration for the evaluation.
         output_dir (Path): The output directory to store the evaluation results.
 
     Returns:
-        HFEvaluator: Evaluation process results.
+        HFEvaluator: Evaluation process.
     """
     if not isinstance(config, AsrEvalConfig):
         raise ValueError(f"Expect AsrEvalConfig, get {type(config)}")
 
-    iterable_ds = load_dataset(config.dataset_name, split=config.split, streaming=True)
-    # Load a subset of the dataset if max_samples is set
-    ds = Dataset.from_generator(
-        lambda: itertools.islice(iterable_ds, 0, config.max_samples),
-        features=iterable_ds.features,
-    )
+    ds = prepare_dataset(config)
 
     gang = setup_root_gang(log)
 
@@ -201,19 +247,12 @@ def load_wav2vec2_asr_evaluator(
     else:
         init_device = META
 
-    ds = ds.map(lambda x: _preprocess_example(x, config.tokenizer_name, init_device))
-    format = {
-        "type": "torch",
-        "format_kwargs": {"dtype": torch.float16, "device": init_device},
-    }
-    ds.set_format(**format, columns=["audio", "text"])
-
-    tokenizer = get_cached_tokenizer(config.tokenizer_name)
+    tokenizer = load_text_tokenizer(config.tokenizer_name)
 
     pipeline_reader = create_hf_reader(
         dataset=ds,
         gang=gang,
-        converter=_librispeech_asr_to_batch,
+        converter=lambda x: to_batch(x, "wav2vec2", init_device),
         batching=StaticBatching(config.max_num_elements),
         num_prefetch=config.num_prefetch,
         pad_value=tokenizer.vocab_info.pad_idx,
@@ -232,6 +271,79 @@ def load_wav2vec2_asr_evaluator(
         gang=gang,
         data_reader=pipeline_reader,
         wall_watch=wall_watch,
-        preprocessor=seq2seq_preprocessor,
-        postprocessor=lambda x, y: postprocesser(x, y, tokenizer),
+        preprocessor=evaluator_preprocessor,
+        postprocessor=lambda x, y: evaluator_postprocesser(x, y, tokenizer),
     )
+
+
+class HGModelWrapper:

Reviewer comment (Contributor): Thought: A better way to handle this is to have a HG model loader with HG model config, where we can specify the transformers class and the preprocessor class. But this can be left to another PR.

+    def __init__(self, model: WhisperForConditionalGeneration):
+        self.model = model
+
+    def __call__(self, batch: SequenceBatch) -> Any:
+        return self.model.generate(batch.seqs)
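
A rough sketch of the loader the reviewer suggests above, assuming a follow-up PR adds a config that names the transformers classes (all names here are hypothetical):

from dataclasses import dataclass

@dataclass
class HGModelConfig:  # hypothetical config for a generic HF model loader
    model_name: str
    model_cls: type = WhisperForConditionalGeneration
    processor_cls: type = WhisperProcessor

def load_hg_model(config: HGModelConfig, device: Device) -> tuple[HGModelWrapper, Any]:
    model = config.model_cls.from_pretrained(config.model_name).to(device)
    processor = config.processor_cls.from_pretrained(config.model_name)
    return HGModelWrapper(model), processor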


+def load_hg_asr_evaluator(
+    config: AsrEvalConfig, output_dir: Path
+) -> HFEvaluator[Seq2SeqBatch]:
+    """
+    Load the evaluator used for downstream evaluation of the whisper model
+
+    Args:
+        config (HFEvalConfig): The configuration for the evaluation.
+        output_dir (Path): The output directory to store the evaluation results.
+
+    Returns:
+        HFEvaluator: Evaluation process.
+    """
+    if not isinstance(config, AsrEvalConfig):
+        raise ValueError(f"Expect AsrEvalConfig, get {type(config)}")
+
+    gang = setup_root_gang(log)
+
+    if gang.rank == 0:
+        init_device = gang.device
+    else:
+        init_device = META
+
+    processor = WhisperProcessor.from_pretrained(config.model_name)
+    model = WhisperForConditionalGeneration.from_pretrained(config.model_name).to(
+        init_device
+    )
+
+    ds_builder = load_dataset_builder(config.dataset_name)
+    ds = prepare_dataset(
+        config,
+        lambda x: {
+            "audio": processor(
+                x["audio"],
+                sampling_rate=ds_builder.info.features["audio"].sampling_rate,
+                return_tensors="pt",
+            ).input_features.squeeze(0)
+        },
+    )
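
Per the standard transformers Whisper API, the mapping above turns each raw waveform into log-Mel input_features, padded or truncated to 30 s, so every example comes out the same size. A standalone sketch (example and the hardcoded rate are illustrative; shape noted for the common 80-mel checkpoints):

feats = processor(
    example["audio"], sampling_rate=16000, return_tensors="pt"
).input_features                # shape (1, 80, 3000) for 30 s of 16 kHz audio
audio = feats.squeeze(0)        # drop the batch dim before collation, as above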

+    pipeline_reader = create_hf_reader(
+        dataset=ds,
+        gang=gang,
+        converter=lambda x: to_batch(x, "whisper", init_device),
+        batching=StaticBatching(config.max_num_elements),
+        num_prefetch=config.num_prefetch,
+        max_seq_len=config.max_audio_len,
+    )
+
+    wall_watch = Stopwatch(start=True, device=init_device)
+
+    return HFEvaluator[Seq2SeqBatch](
+        model=cast(Model, HGModelWrapper(model)),

Reviewer comment (Contributor): See my other comment. This forced casting is artificial IMHO - the better option is to create a model loader that creates a HG model out of its name. But let's make it in a separate PR.

+        metrics=["bleu"],
+        gang=gang,
+        data_reader=pipeline_reader,
+        wall_watch=wall_watch,
+        preprocessor=evaluator_preprocessor,
+        postprocessor=lambda x, y: (
+            processor.batch_decode(x, skip_special_tokens=True),
+            cast(list[str], y.seqs),
+        ),
+    )
7 changes: 7 additions & 0 deletions src/fairseq2/recipes/hg/evaluator.py
@@ -12,6 +12,8 @@
 from pathlib import Path
 from typing import Any, Generic, TypeVar, final
 
+import torch
+
 from fairseq2.datasets import DataReader
 from fairseq2.gang import FakeGang, Gang
 from fairseq2.logging import get_log_writer
@@ -172,6 +174,11 @@ def _do_run(self) -> None:
                 predictions=predictions, references=references
             )
 
+            del inputs
+            del targets
+            del outputs
+            torch.cuda.empty_cache()
+
         self._root_gang.barrier()
 
         self._elapsed_time = watch.get_elapsed_time()