Commit

Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed authored Sep 4, 2024
2 parents c645950 + 379e4c5 commit 1485f6d
Showing 51 changed files with 1,706 additions and 315 deletions.
3 changes: 3 additions & 0 deletions native/python/src/fairseq2n/bindings/data/data_pipeline.cc
@@ -484,13 +484,15 @@ def_data_pipeline(py::module_ &data_module)
             data_pipeline_builder &self,
             float64 threshold,
             cost_fn fn,
+            std::optional<bucket_creation_fn> maybe_bucket_fn,
             std::optional<std::size_t> maybe_min_num_examples,
             std::optional<std::size_t> maybe_max_num_examples,
             bool drop_remainder) -> data_pipeline_builder &
         {
             self = std::move(self).dynamic_bucket(
                 threshold,
                 std::move(fn),
+                std::move(maybe_bucket_fn),
                 maybe_min_num_examples,
                 maybe_max_num_examples,
                 drop_remainder);
@@ -499,6 +501,7 @@ def_data_pipeline(py::module_ &data_module)
         },
         py::arg("threshold"),
         py::arg("fn"),
+        py::arg("bucket_creation_fn") = std::nullopt,
         py::arg("min_num_examples") = std::nullopt,
         py::arg("max_num_examples") = std::nullopt,
         py::arg("drop_remainder") = false)
8 changes: 7 additions & 1 deletion native/src/fairseq2n/data/data_pipeline.cc
@@ -407,6 +407,7 @@ data_pipeline_builder
 data_pipeline_builder::dynamic_bucket(
     float64 threshold,
     cost_fn fn,
+    std::optional<bucket_creation_fn> maybe_bucket_fn,
     std::optional<std::size_t> maybe_min_num_examples,
     std::optional<std::size_t> maybe_max_num_examples,
     bool drop_remainder) &&
@@ -421,12 +422,17 @@ data_pipeline_builder::dynamic_bucket(
         throw_<std::invalid_argument>("`max_num_examples` must be greater than or equal to `min_num_examples`.");
     }

-    factory_ = [=, fn = std::move(fn), inner = std::move(factory_)]() mutable
+    factory_ = [
+        =,
+        fn = std::move(fn),
+        maybe_bucket_fn = std::move(maybe_bucket_fn),
+        inner = std::move(factory_)]() mutable
     {
         return std::make_unique<dynamic_bucket_data_source>(
             inner(),
             threshold,
             std::move(fn),
+            std::move(maybe_bucket_fn),
             maybe_min_num_examples,
             maybe_max_num_examples,
             drop_remainder);
8 changes: 6 additions & 2 deletions native/src/fairseq2n/data/data_pipeline.h
@@ -8,6 +8,7 @@

 #include <cstddef>
 #include <cstdint>
+#include <deque>
 #include <filesystem>
 #include <functional>
 #include <memory>
@@ -111,14 +112,16 @@ class FAIRSEQ2_API data_pipeline {
     mutable bool is_broken_ = false;
 };

+using bucket_creation_fn = std::function<std::pair<std::deque<data>, data_list>(data_list &&)>;
+
+using cost_fn = std::function<float64(const data &)>;
+
 using data_length_fn = std::function<std::size_t(const data &)>;

 using map_fn = std::function<data(data &&)>;

 using predicate_fn = std::function<bool(const data &)>;

-using cost_fn = std::function<float64(const data &)>;
-
 using yield_fn = std::function<data_pipeline(const data &)>;
@@ -152,6 +155,7 @@ class FAIRSEQ2_API data_pipeline_builder {
     dynamic_bucket(
         float64 threshold,
         cost_fn fn,
+        std::optional<bucket_creation_fn> maybe_bucket_fn = std::nullopt,
         std::optional<std::size_t> maybe_min_num_examples = std::nullopt,
         std::optional<std::size_t> maybe_max_num_examples = std::nullopt,
         bool drop_remainder = false) &&;
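In Python terms, the new bucket_creation_fn contract maps the accumulated buffer to a pair: the buckets to yield (front one first) and the remainder to keep buffering. A minimal illustrative example of a conforming function (names are hypothetical, not part of the commit):

from collections import deque

def split_in_two(bucket: list) -> tuple[deque, list]:
    mid = len(bucket) // 2
    # Queue two buckets for yielding; carry nothing over to the next round.
    return deque([bucket[:mid], bucket[mid:]]), []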
45 changes: 38 additions & 7 deletions native/src/fairseq2n/data/dynamic_bucket_data_source.cc
@@ -14,12 +14,14 @@ dynamic_bucket_data_source::dynamic_bucket_data_source(
     std::unique_ptr<data_source> &&inner,
     float64 threshold,
     cost_fn &&fn,
+    std::optional<bucket_creation_fn> &&maybe_bucket_fn,
     std::optional<std::size_t> maybe_min_num_examples,
     std::optional<std::size_t> maybe_max_num_examples,
     bool drop_remainder) noexcept
   : inner_{std::move(inner)},
     threshold_{threshold},
     cost_fn_{std::move(fn)},
+    maybe_bucket_creation_fn_{std::move(maybe_bucket_fn)},
     maybe_min_num_examples_{maybe_min_num_examples},
     maybe_max_num_examples_{maybe_max_num_examples},
     drop_remainder_{drop_remainder}
@@ -28,10 +30,14 @@ dynamic_bucket_data_source::dynamic_bucket_data_source(
 std::optional<data>
 dynamic_bucket_data_source::next()
 {
-    data_list output{};
+    if (!return_buffer_.empty()) {
+        data output{return_buffer_.front()};
+        return_buffer_.pop_front();
+        return output;
+    }

     if (maybe_min_num_examples_)
-        output.reserve(*maybe_min_num_examples_);
+        buffer_.reserve(*maybe_min_num_examples_);

     float64 cost = 0;

@@ -40,13 +46,13 @@ dynamic_bucket_data_source::next()

         bool minimum_size_met = true;
         if (maybe_min_num_examples_)
-            minimum_size_met = output.size() >= *maybe_min_num_examples_;
+            minimum_size_met = buffer_.size() >= *maybe_min_num_examples_;

         if (cost_threshold_met && minimum_size_met) return true;

         bool maximum_size_met = false;
         if (maybe_max_num_examples_)
-            maximum_size_met = output.size() >= *maybe_max_num_examples_;
+            maximum_size_met = buffer_.size() >= *maybe_max_num_examples_;

         return maximum_size_met;
     };
@@ -56,33 +62,58 @@ dynamic_bucket_data_source::next()
         if (!maybe_example)
             break;
         cost += cost_fn_(*maybe_example);
-        output.push_back(*std::move(maybe_example));
+        buffer_.push_back(*std::move(maybe_example));
     }

-    if (output.empty())
+    if (buffer_.empty())
         return std::nullopt;

-    if (drop_remainder_ && !bucket_ready())
+    if (bucket_ready()) {
+        if (maybe_bucket_creation_fn_) {
+            const bucket_creation_fn& fn = *maybe_bucket_creation_fn_;
+            auto&& [return_buffer, new_buffer] = fn(std::move(buffer_));
+
+            buffer_ = std::move(new_buffer);
+
+            data output{return_buffer.front()};
+            return_buffer.pop_front();
+
+            return_buffer_ = std::move(return_buffer);
+
+            return output;
+        }
+    } else if (drop_remainder_) {
+        buffer_.clear();
         return std::nullopt;
+    }

+    data_list output = std::move(buffer_);
+    buffer_.clear();
     return output;
 }

 void
 dynamic_bucket_data_source::reset(bool reset_rng)
 {
+    buffer_.clear();
     inner_->reset(reset_rng);
 }

 void
 dynamic_bucket_data_source::record_position(tape &t, bool strict) const
 {
+    if (maybe_bucket_creation_fn_) {
+        t.record(buffer_);
+    }
     inner_->record_position(t, strict);
 }

 void
 dynamic_bucket_data_source::reload_position(tape &t, bool strict)
 {
+    if (maybe_bucket_creation_fn_) {
+        buffer_ = t.read<data_list>();
+    }
     inner_->reload_position(t, strict);
 }
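Taken together, the reworked next() does three things: drain buckets queued by an earlier call, accumulate examples until bucket_ready(), then optionally hand the buffer to the bucket creation function. A simplified pure-Python model of that control flow (an illustrative sketch under the same contract, not the fairseq2n implementation):

from collections import deque
from typing import Any, Callable, Iterable, Optional

def dynamic_bucket_model(
    source: Iterable[Any],
    threshold: float,
    cost_fn: Callable[[Any], float],
    bucket_creation_fn=None,
    min_num_examples: Optional[int] = None,
    max_num_examples: Optional[int] = None,
    drop_remainder: bool = False,
):
    pending: deque = deque()  # mirrors `return_buffer_`
    buffer: list = []         # mirrors `buffer_`
    it = iter(source)
    while True:
        # Buckets queued by an earlier bucket_creation_fn call are drained first.
        if pending:
            yield pending.popleft()
            continue

        cost = 0.0

        def bucket_ready() -> bool:
            size_ok = min_num_examples is None or len(buffer) >= min_num_examples
            if cost >= threshold and size_ok:
                return True
            return max_num_examples is not None and len(buffer) >= max_num_examples

        while not bucket_ready():
            example = next(it, None)
            if example is None:
                break
            cost += cost_fn(example)
            buffer.append(example)

        if not buffer:
            return

        if bucket_ready() and bucket_creation_fn is not None:
            # Assumes the fn returns at least one bucket, as the C++ code does.
            new_buckets, remainder = bucket_creation_fn(buffer)
            buffer = list(remainder)
            pending.extend(new_buckets)
            yield pending.popleft()
        elif not bucket_ready() and drop_remainder:
            buffer.clear()
            return
        else:
            bucket, buffer = buffer, []
            yield bucket

The record_position()/reload_position() changes follow from this: with a bucket creation function set, remainder examples can sit in buffer_ across next() calls, so the buffer now has to be recorded and restored when checkpointing.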
6 changes: 6 additions & 0 deletions native/src/fairseq2n/data/dynamic_bucket_data_source.h
@@ -22,6 +22,7 @@ class dynamic_bucket_data_source final : public data_source {
         std::unique_ptr<data_source> &&inner,
         float64 threshold,
         cost_fn &&fn,
+        std::optional<bucket_creation_fn> &&maybe_bucket_fn,
         std::optional<std::size_t> maybe_min_num_examples,
         std::optional<std::size_t> maybe_max_num_examples,
         bool drop_remainder) noexcept;
@@ -45,9 +46,14 @@ class dynamic_bucket_data_source final : public data_source {
     std::unique_ptr<data_source> inner_;
     float64 threshold_;
     cost_fn cost_fn_;
+    std::optional<bucket_creation_fn> maybe_bucket_creation_fn_;
     std::optional<std::size_t> maybe_min_num_examples_;
     std::optional<std::size_t> maybe_max_num_examples_;
     bool drop_remainder_;
+
+    data_list buffer_{};
+    std::deque<data> return_buffer_{};
+
 };

 } // namespace fairseq2n::detail
2 changes: 1 addition & 1 deletion src/fairseq2/assets/cards/datasets/librispeech.yaml
@@ -7,7 +7,7 @@
 name: librispeech_asr
 dataset_family: generic_asr
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/wav2vec/librispeech_asr.model"
-tokenizer_family: librispeech_asr
+tokenizer_family: char_tokenizer

 ---

11 changes: 11 additions & 0 deletions src/fairseq2/data/data_pipeline.py
@@ -231,6 +231,10 @@ def dynamic_bucket(
         self,
         threshold: float,
         cost_fn: Callable[[Any], float],
+        bucket_creation_fn: Callable[
+            [Sequence[Any]], tuple[Sequence[Sequence[Any]], Sequence[Any]]
+        ]
+        | None = None,
         min_num_examples: int | None = None,
         max_num_examples: int | None = None,
         drop_remainder: bool = False,
@@ -246,6 +250,13 @@ def dynamic_bucket(
             Threshold for cumulative cost to trigger bucketing.
         :param cost_fn:
             Cost function that outputs cost for a particular example.
+        :param bucket_creation_fn:
+            Function for customizing bucket creation. Called with the bucket of
+            examples that caused the cost threshold to be exceeded. Expected to
+            return a tuple of ``(new_buckets, remainder)``, where the internal
+            buffer is set to ``remainder`` and ``new_buckets`` is a list of
+            buckets to be yielded. If ``None``, defaults to the identity
+            function.
         :param min_num_examples:
             Minimum number of examples per bucket.
         :param max_num_examples:
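An end-to-end usage sketch of the new argument (the values and pairing scheme are invented for illustration; the keyword name follows the binding above, and exact behavior depends on the installed fairseq2 version):

from fairseq2.data import read_sequence

def cost(example: int) -> float:
    return float(example)

def make_buckets(bucket):
    # Emit fixed-size pairs; an unsplittable bucket is emitted whole, and any
    # odd leftover stays in the internal buffer toward the next bucket.
    if len(bucket) < 2:
        return [bucket], []
    pairs = [bucket[i : i + 2] for i in range(0, len(bucket) - len(bucket) % 2, 2)]
    remainder = bucket[len(pairs) * 2 :]
    return pairs, remainder

pipeline = (
    read_sequence(list(range(10)))
    .dynamic_bucket(5.0, cost, bucket_creation_fn=make_buckets)
    .and_return()
)

for bucket in pipeline:
    print(bucket)  # each yielded item is one bucket (a list of examples)

Each finished bucket is passed to make_buckets, which returns the buckets to emit plus a remainder that stays buffered toward the next threshold crossing.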
1 change: 1 addition & 0 deletions src/fairseq2/data/text/__init__.py
@@ -31,6 +31,7 @@
 from fairseq2.data.text.sentencepiece import (
     default_raw_sentencepiece_tokenizer_loader as default_raw_sentencepiece_tokenizer_loader,
 )
+from fairseq2.data.text.sentencepiece import load_char_tokenizer as load_char_tokenizer
 from fairseq2.data.text.sentencepiece import (
     vocab_info_from_sentencepiece as vocab_info_from_sentencepiece,
 )
6 changes: 6 additions & 0 deletions src/fairseq2/data/text/sentencepiece.py
@@ -20,6 +20,7 @@
     AbstractTextTokenizerLoader,
     TextTokenDecoder,
     TextTokenEncoder,
+    load_text_tokenizer,
 )
 from fairseq2.data.vocabulary_info import VocabularyInfo
 from fairseq2.typing import Device
@@ -314,3 +315,8 @@ def vocab_info_from_sentencepiece(model: SentencePieceModel) -> VocabularyInfo:
         model.eos_idx,
         model.pad_idx,
     )
+
+
+load_char_tokenizer = default_raw_sentencepiece_tokenizer_loader
+
+load_text_tokenizer.register("char_tokenizer", load_char_tokenizer)
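With this registration, asset cards whose tokenizer_family is char_tokenizer (such as the updated librispeech_asr card above) resolve through the raw SentencePiece loader. A usage sketch (assumes the librispeech_asr asset card and its tokenizer model are reachable):

from fairseq2.data.text import load_text_tokenizer

tokenizer = load_text_tokenizer("librispeech_asr")
encoder = tokenizer.create_encoder()
print(encoder("hello world"))  # character-level token indices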
1 change: 1 addition & 0 deletions src/fairseq2/datasets/__init__.py
@@ -6,6 +6,7 @@

 from __future__ import annotations

+from fairseq2.datasets.batching import Batching as Batching
 from fairseq2.datasets.batching import LengthBatching as LengthBatching
 from fairseq2.datasets.batching import StaticBatching as StaticBatching
 from fairseq2.datasets.data_reader import DataPipelineReader as DataPipelineReader
22 changes: 7 additions & 15 deletions src/fairseq2/datasets/asr.py
@@ -27,14 +27,8 @@
     read_sequence,
 )
 from fairseq2.data.audio import AudioDecoder
-from fairseq2.data.text import (
-    StrSplitter,
-    TextTokenizer,
-    default_raw_sentencepiece_tokenizer_loader,
-    load_text_tokenizer,
-    read_text,
-)
-from fairseq2.datasets.batching import LengthBatching, StaticBatching
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching
 from fairseq2.datasets.data_reader import DataPipelineReader, DataReader
 from fairseq2.datasets.error import DatasetError
 from fairseq2.datasets.loader import AbstractDatasetLoader, DelegatingDatasetLoader
@@ -54,7 +48,7 @@ def create_reader(
         tokenizer: TextTokenizer,
         gang: Gang,
         max_audio_len: int,
-        batching: StaticBatching | LengthBatching,
+        batching: Batching,
         *,
         dtype: DataType = torch.float32,
         min_audio_len: int = 1,
@@ -173,7 +167,7 @@ def create_reader(
         tokenizer: TextTokenizer,
         gang: Gang,
         max_audio_len: int,
-        batching: StaticBatching | LengthBatching,
+        batching: Batching,
         *,
         dtype: DataType = torch.float32,
         min_audio_len: int = 1,
@@ -231,7 +225,7 @@ def create_reader(
                 skip_above_max_examples=True,
                 drop_remainder=drop_remainder,
             )
-        else:
+        elif isinstance(batching, StaticBatching):
            # Filter out out-of-range audios.
            def skip(example: dict[str, Any]) -> bool:
                audio_len = cast(int, example["audio_size"])
@@ -242,6 +236,8 @@ def skip(example: dict[str, Any]) -> bool:

            # Bucket `batch_size` examples.
            builder.bucket(batching.batch_size, drop_remainder=drop_remainder)
+        else:
+            raise RuntimeError(f"`{batching}` is not supported.")

         # Shuffle buckets.
         if batch_shuffle_window != 1:
@@ -392,7 +388,3 @@ def _load(self, path: Path, card: AssetCard) -> GenericAsrDataset:
 load_generic_asr_dataset = GenericAsrDatasetLoader()

 load_asr_dataset.register("generic_asr", load_generic_asr_dataset)
-
-load_librispeech_asr_tokenizer = default_raw_sentencepiece_tokenizer_loader
-
-load_text_tokenizer.register("librispeech_asr", load_librispeech_asr_tokenizer)
4 changes: 4 additions & 0 deletions src/fairseq2/datasets/batching.py
@@ -7,6 +7,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
+from typing import TypeAlias


 @dataclass
@@ -23,3 +24,6 @@ class LengthBatching:

     max_num_elements: int
     """The maximum number of elements (e.g. tokens) in each batch."""
+
+
+Batching: TypeAlias = StaticBatching | LengthBatching
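The new Batching alias lets call sites accept either strategy and dispatch on the concrete type, which is exactly the isinstance pattern asr.py adopts above. A minimal sketch (StaticBatching.batch_size is taken from its use in asr.py):

from fairseq2.datasets.batching import Batching, LengthBatching, StaticBatching

def describe(batching: Batching) -> str:
    if isinstance(batching, LengthBatching):
        return f"buckets of up to {batching.max_num_elements} elements"
    if isinstance(batching, StaticBatching):
        return f"fixed batches of {batching.batch_size} examples"
    raise RuntimeError(f"`{batching}` is not supported.")

print(describe(LengthBatching(max_num_elements=4096)))
print(describe(StaticBatching(batch_size=8)))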