Skip to content

Commit

Permalink
fix: add tfrecords buffer size in bytes for tfrecords dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
nglehuy committed May 5, 2024
1 parent 0635cfe commit 316694f
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tensorflow_asr/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@

# An ASR dataset is some `.tsv` files in format: `PATH\tDURATION\tTRANSCRIPT`. You must create those files by your own with your own data and methods.

# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT` because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob:
# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT`
# because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob:

# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file. Default is [English](../featurizers/english.txt)
# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file.
# Default is [English](../featurizers/english.txt)

# **Inputs**

Expand Down Expand Up @@ -141,8 +143,9 @@ def get_global_shape(


BUFFER_SIZE = 100
TFRECORD_BUFFER_SIZE = 32 * 1024 * 1024
TFRECORD_SHARDS = 16
AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.experimental.AUTOTUNE)
AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.AUTOTUNE)


class BaseDataset:
Expand Down Expand Up @@ -416,6 +419,7 @@ def __init__(
indefinite: bool = True,
drop_remainder: bool = True,
buffer_size: int = BUFFER_SIZE,
tfrecords_buffer_size: int = TFRECORD_BUFFER_SIZE,
compression_type: str = "GZIP",
sample_rate: int = 16000,
name: str = "",
Expand All @@ -442,6 +446,7 @@ def __init__(
if tfrecords_shards <= 0:
raise ValueError("tfrecords_shards must be positive")
self.tfrecords_shards = tfrecords_shards
self.tfrecords_buffer_size = tfrecords_buffer_size
self.compression_type = compression_type

def write_tfrecord_file(self, splitted_entries: tuple):
Expand Down Expand Up @@ -506,7 +511,9 @@ def create(self, batch_size: int, padded_shapes=None):
ignore_order = tf.data.Options()
ignore_order.deterministic = False
files_ds = files_ds.with_options(ignore_order)
dataset = tf.data.TFRecordDataset(files_ds, compression_type=self.compression_type, num_parallel_reads=AUTOTUNE)
dataset = tf.data.TFRecordDataset(
files_ds, compression_type=self.compression_type, buffer_size=self.tfrecords_buffer_size, num_parallel_reads=AUTOTUNE
)

return self.process(dataset, batch_size, padded_shapes=padded_shapes)

Expand Down

0 comments on commit 316694f

Please sign in to comment.