From 316694f0590815a3626fe0131c56a001e45876af Mon Sep 17 00:00:00 2001 From: Huy Le Nguyen Date: Sun, 5 May 2024 22:22:31 +0700 Subject: [PATCH] fix: add tfrecords buffer size in bytes for tfrecords dataset --- tensorflow_asr/datasets.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tensorflow_asr/datasets.py b/tensorflow_asr/datasets.py index 236e0189f9..c05f76b2c2 100755 --- a/tensorflow_asr/datasets.py +++ b/tensorflow_asr/datasets.py @@ -25,9 +25,11 @@ # An ASR dataset is some `.tsv` files in format: `PATH\tDURATION\tTRANSCRIPT`. You must create those files by your own with your own data and methods. -# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT` because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob: +# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT` +# because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob: -# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file. Default is [English](../featurizers/english.txt) +# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file. +# Default is [English](../featurizers/english.txt) # **Inputs** @@ -141,8 +143,9 @@ def get_global_shape( BUFFER_SIZE = 100 +TFRECORD_BUFFER_SIZE = 32 * 1024 * 1024 TFRECORD_SHARDS = 16 -AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.experimental.AUTOTUNE) +AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.AUTOTUNE) class BaseDataset: @@ -416,6 +419,7 @@ def __init__( indefinite: bool = True, drop_remainder: bool = True, buffer_size: int = BUFFER_SIZE, + tfrecords_buffer_size: int = TFRECORD_BUFFER_SIZE, compression_type: str = "GZIP", sample_rate: int = 16000, name: str = "", @@ -442,6 +446,7 @@ def __init__( if tfrecords_shards <= 0: raise ValueError("tfrecords_shards must be positive") self.tfrecords_shards = tfrecords_shards + self.tfrecords_buffer_size = tfrecords_buffer_size self.compression_type = compression_type def write_tfrecord_file(self, splitted_entries: tuple): @@ -506,7 +511,9 @@ def create(self, batch_size: int, padded_shapes=None): ignore_order = tf.data.Options() ignore_order.deterministic = False files_ds = files_ds.with_options(ignore_order) - dataset = tf.data.TFRecordDataset(files_ds, compression_type=self.compression_type, num_parallel_reads=AUTOTUNE) + dataset = tf.data.TFRecordDataset( + files_ds, compression_type=self.compression_type, buffer_size=self.tfrecords_buffer_size, num_parallel_reads=AUTOTUNE + ) return self.process(dataset, batch_size, padded_shapes=padded_shapes)