Standardize all inference-based methods
TimKoornstra committed Nov 25, 2024
1 parent 5f47892 commit b8f4957
Showing 6 changed files with 234 additions and 338 deletions.
6 changes: 2 additions & 4 deletions src/main.py
@@ -161,15 +161,13 @@ def main():
         logging.warning("Test results are without special markdown tags")

         tick = time.time()
-        perform_test(config, model, data_manager.datasets["test"],
-                     data_manager)
+        perform_test(config, model, data_manager)
         timestamps['Test'] = time.time() - tick

     # Infer with the model
     if config["inference_list"]:
         tick = time.time()
-        perform_inference(config, model, data_manager.datasets["inference"],
-                          data_manager)
+        perform_inference(config, model, data_manager)
         timestamps['Inference'] = time.time() - tick

     # Log the timestamps
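Both call sites in main() now hand over only the Config, the model, and the DataManager; each mode looks up its own split internally. A quick recap sketch of the resulting flow, assembled from the hunk above (timing bookkeeping unchanged):

    tick = time.time()
    perform_test(config, model, data_manager)           # resolves data_manager.datasets["test"] itself
    timestamps['Test'] = time.time() - tick

    if config["inference_list"]:
        tick = time.time()
        perform_inference(config, model, data_manager)  # resolves data_manager.datasets["inference"] itself
        timestamps['Inference'] = time.time() - tick
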
69 changes: 22 additions & 47 deletions src/modes/inference.py
@@ -1,78 +1,53 @@
 # Imports

 # > Standard library
-from typing import List
 import logging

 # > Third-party dependencies
 import tensorflow as tf

 # > Local dependencies
 from data.manager import DataManager
 from setup.config import Config
-from utils.threading import DecodingWorker, ResultWriter
+from modes.utils import setup_workers, process_batches
+from utils.threading import ResultWriter
 from utils.wbs import setup_word_beam_search


 def perform_inference(config: Config,
                       model: tf.keras.Model,
-                      inference_dataset: tf.data.Dataset,
                       data_manager: DataManager) -> None:
     """
-    Performs inference using parallel processing with direct communication between
-    decoders and the result writer.
+    Generic evaluation function for inference.
     Parameters
     ----------
     config : Config
-        Configuration object containing parameters for inference.
+        Configuration object.
     model : tf.keras.Model
-        Trained TensorFlow model for making predictions.
-    inference_dataset : tf.data.Dataset
-        Dataset for performing inference.
+        Keras model to be inferenced on.
     data_manager : DataManager
-        Object managing dataset and filename retrieval.
-    Notes
-    -----
-    This optimized implementation:
-    - Uses separate threads for CTC decoding and result writing.
-    - Implements direct communication between decoders and the result writer.
-    - Prefetches the next batch while processing the current one.
-    - Minimizes GPU blocking time.
+        Data manager containing datasets and tokenizer.
     """
-    # Initialize result writer first
+    logging.info("Starting %s...", 'inference')
+
+    dataset = data_manager.datasets['inference']
+    tokenizer = data_manager.tokenizer
+    wbs = setup_word_beam_search(
+        config, tokenizer.token_list) if config["corpus_file"] else None
+
+    # Initialize result writer
     result_writer = ResultWriter(config["results_file"],
                                  maxsize=config["batch_size"] * 5)
     result_writer.start()

-    # Initialize decoders with direct access to the result writer
-    num_decode_workers: int = 2  # Adjust based on available CPU cores
-    decode_workers: List[DecodingWorker] = [
-        DecodingWorker(data_manager.tokenizer, config,
-                       result_writer.queue, maxsize=5)
-        for _ in range(num_decode_workers)
-    ]
-
-    # Start all decode workers
-    for worker in decode_workers:
-        worker.start()
+    # Initialize decode workers
+    decode_workers = setup_workers(config, data_manager,
+                                   result_writer.queue, wbs)

     try:
-        for batch_no, batch in enumerate(inference_dataset):
-            # Get predictions (GPU operation)
-            predictions: tf.Tensor = model.predict_on_batch(batch[0])
-
-            # Prepare filenames for the batch
-            batch_filenames: List[str] = [
-                data_manager.get_filename('inference',
-                                          (batch_no * config["batch_size"]) + idx)
-                for idx in range(len(predictions))
-            ]
-
-            # Distribute work to decode workers
-            worker_idx: int = batch_no % num_decode_workers
-            decode_workers[worker_idx].input_queue.put(
-                (predictions, batch_no, batch_filenames, None)
-            )
-
+        process_batches(dataset, model, config, data_manager,
+                        decode_workers, 'inference')
     finally:
         # Clean up workers
         for worker in decode_workers:
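The wiring above relies on the worker interface from utils/threading.py: each DecodingWorker is constructed with the tokenizer, the config, a downstream queue, and an optional word-beam-search decoder, and it consumes (predictions, batch_no, filenames, ground_truth) tuples from its input_queue. That implementation is not part of this diff; as a rough, self-contained illustration of the assumed queue contract (class name and decode_fn are placeholders, not the repository's code):

    import threading
    from queue import Queue, Empty


    class MiniDecodingWorker(threading.Thread):
        """Illustrative stand-in for utils.threading.DecodingWorker."""

        def __init__(self, decode_fn, output_queue: Queue, maxsize: int = 5):
            super().__init__(daemon=True)
            self.decode_fn = decode_fn        # stands in for CTC decoding + token lookup
            self.input_queue = Queue(maxsize=maxsize)
            self.output_queue = output_queue  # e.g. the ResultWriter or MetricsCalculator queue
            self._stop_event = threading.Event()

        def run(self):
            while not self._stop_event.is_set():
                try:
                    predictions, batch_no, filenames, ground_truth = \
                        self.input_queue.get(timeout=0.1)
                except Empty:
                    continue
                texts = self.decode_fn(predictions)
                # Hand decoded text straight to the writer/metrics thread.
                self.output_queue.put((batch_no, filenames, texts, ground_truth))
                self.input_queue.task_done()

        def stop(self):
            self._stop_event.set()
            self.join()

The real worker presumably performs the CTC decoding with the tokenizer and optional WBS decoder before forwarding results, as its constructor arguments in the calls above suggest.
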
119 changes: 18 additions & 101 deletions src/modes/test.py
@@ -2,136 +2,53 @@

 # > Standard library
 import logging
-import os
-from typing import List

 # > Third-party dependencies
 import tensorflow as tf

 # > Local dependencies
 from data.manager import DataManager
 from setup.config import Config
-from utils.calculate import calc_95_confidence_interval
-from utils.threading import DecodingWorker, MetricsCalculator
+from modes.utils import setup_workers, process_batches, output_statistics
+from utils.threading import MetricsCalculator
 from utils.wbs import setup_word_beam_search


 def perform_test(config: Config,
                  model: tf.keras.Model,
-                 test_dataset: tf.data.Dataset,
-                 data_manager: DataManager) -> None:
+                 data_manager: DataManager):
     """
-    Performs test run on a dataset using a given model and calculates various
-    metrics like Character Error Rate (CER).
+    Generic evaluation function for test and validation.
     Parameters
     ----------
     config : Config
-        A Config object containing arguments for the validation process such as
-        mask usage and file paths.
+        Configuration object.
     model : tf.keras.Model
-        The Keras model to be validated.
-    test_dataset : tf.data.Dataset
-        The dataset to be used for testing.
+        Keras model to be evaluated.
     data_manager : DataManager
-        A data data_manager object for additional operations like normalization
-        and Word Beam Search setup.
-    Notes
-    -----
-    The function processes each batch in the validation dataset, calculates
-    CER, and optionally processes Word Beam Search (WBS) results if enabled.
-    It also handles the display and logging of statistical information
-    throughout the validation process.
+        Data manager containing datasets and tokenizers.
     """
+    logging.info("Performing %s...", 'test')

-    logging.info("Performing test...")
+    dataset = data_manager.datasets['test']
     tokenizer = data_manager.tokenizer
+    wbs = setup_word_beam_search(
+        config, tokenizer.token_list) if config["corpus_file"] else None

-    # Setup WordBeamSearch if needed
-    wbs = setup_word_beam_search(config, tokenizer.token_list) \
-        if config["corpus_file"] else None
-
-    # Intialize the metric calculator
-    metrics_calculator = MetricsCalculator(
-        config, log_results=False, maxsize=10)
+    metrics_calculator = MetricsCalculator(config, log_results=False,
+                                           maxsize=100)
     metrics_calculator.start()
+    decode_workers = setup_workers(config, data_manager,
+                                   metrics_calculator.queue, wbs)

-    # Initialize decoders with direct access to the result writer
-    num_decode_workers: int = 2  # Adjust based on available CPU cores
-    decode_workers: List[DecodingWorker] = [
-        DecodingWorker(data_manager.tokenizer, config,
-                       metrics_calculator.queue, wbs=wbs,
-                       maxsize=5)
-        for _ in range(num_decode_workers)
-    ]
-
-    # Start all decode workers
-    for worker in decode_workers:
-        worker.start()
-
-    # Process each batch in the validation dataset
     try:
-        for batch_no, batch in enumerate(test_dataset):
-            logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))
-
-            X = batch[0]
-            y = [data_manager.get_ground_truth('test', i)
-                 for i in range(batch_no * config["batch_size"],
-                                batch_no * config["batch_size"] + len(X))]
-
-            # Get predictions (GPU operation)
-            predictions: tf.Tensor = model.predict_on_batch(X)
-
-            # Prepare filenames for the batch
-            batch_filenames: List[str] = [
-                data_manager.get_filename('test',
-                                          (batch_no * config["batch_size"]) + idx)
-                for idx in range(len(predictions))
-            ]
-
-            # Distribute work to decode workers
-            worker_idx: int = batch_no % num_decode_workers
-            decode_workers[worker_idx].input_queue.put(
-                (predictions, batch_no, batch_filenames, y)
-            )
+        process_batches(dataset, model, config, data_manager,
+                        decode_workers, 'test')
     finally:
         # Clean up workers
         for worker in decode_workers:
             worker.stop()
         metrics_calculator.stop()

-    # Print the final test statistics
-    logging.info("")
-    logging.info("--------------------------------------------------------")
-    logging.info("")
-    logging.info("Final test statistics")
-    logging.info("---------------------------")
-
-    # Calculate the CER confidence intervals on all metrics except Items
-    intervals = [calc_95_confidence_interval(cer_metric, metrics_calculator.n_items)
-                 for cer_metric in metrics_calculator.total_stats[:-1]]
-
-    # Print the final statistics
-    for metric, total_value, interval in zip(metrics_calculator.metrics[:-1],
-                                             metrics_calculator.total_stats[:-1],
-                                             intervals):
-        logging.info("%s = %.4f +/- %.4f", metric, total_value, interval)
-
-    logging.info("Items = %s", metrics_calculator.total_stats[-1])
-    logging.info("")
-
-    # Output the validation statistics to a csv file
-    with open(os.path.join(config["output"], 'test.csv'),
-              'w', encoding="utf-8") as f:
-        header = "cer,cer_lower,cer_simple"
-        if config["normalization_file"]:
-            header += ",normalized_cer,normalized_cer_lower," \
-                "normalized_cer_simple"
-        if wbs:
-            header += ",wbs_cer,wbs_cer_lower,wbs_cer_simple"
-
-        f.write(header + "\n")
-        results = ",".join([str(metrics_calculator.total_stats[i])
-                            for i in range(len(metrics_calculator.metrics)-1)])
-        f.write(results + "\n")
+    output_statistics(metrics_calculator, config, 'test', wbs)
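The statistics now go through output_statistics (see src/modes/utils.py below), which reports each metric together with calc_95_confidence_interval(metric, n_items). That helper is not shown in this diff; assuming it is the usual normal-approximation interval for an error rate in [0, 1], it would amount to something like:

    import math


    def approx_95_confidence_interval(rate: float, n_items: int) -> float:
        """Half-width of a 95% normal-approximation interval for a rate in [0, 1].

        Sketch of what calc_95_confidence_interval is assumed to compute; the
        actual definition lives in utils/calculate.py and may differ.
        """
        if n_items <= 0:
            return float("nan")
        return 1.96 * math.sqrt(rate * (1.0 - rate) / n_items)

For example, under that assumption a CER of 0.05 measured over 2,000 items would be reported as roughly 0.05 +/- 0.0096.
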
91 changes: 91 additions & 0 deletions src/modes/utils.py
@@ -0,0 +1,91 @@
# > Standard library
import logging
import os
from typing import List, Optional
from queue import Queue

# > Third-party dependencies
import tensorflow as tf

# > Local dependencies
from data.manager import DataManager
from setup.config import Config
from utils.calculate import calc_95_confidence_interval
from utils.threading import DecodingWorker


def setup_workers(config: Config,
                  data_manager: DataManager,
                  result_queue: Optional[Queue] = None,
                  wbs=None) -> List[DecodingWorker]:
    """Sets up decode workers with optional result writer integration."""
    num_decode_workers: int = 2  # Adjust based on available CPU cores
    decode_workers: List[DecodingWorker] = [
        DecodingWorker(data_manager.tokenizer, config,
                       result_queue if result_queue else None,
                       wbs=wbs, maxsize=5)
        for _ in range(num_decode_workers)
    ]

    for worker in decode_workers:
        worker.start()

    return decode_workers


def process_batches(dataset, model, config, data_manager, decode_workers, mode: str):
    """Processes batches from the dataset and distributes work to decode workers."""
    for batch_no, batch in enumerate(dataset):
        X = batch[0]
        if mode != 'inference':
            y = [data_manager.get_ground_truth(mode, i)
                 for i in range(batch_no * config["batch_size"],
                                batch_no * config["batch_size"] + len(X))]
        else:
            y = None

        predictions: tf.Tensor = model.predict_on_batch(X)

        batch_filenames: List[str] = [
            data_manager.get_filename(mode,
                                      (batch_no * config["batch_size"]) + idx)
            for idx in range(len(predictions))
        ]

        worker_idx: int = batch_no % len(decode_workers)
        decode_workers[worker_idx].input_queue.put(
            (predictions, batch_no, batch_filenames, y)
        )


def output_statistics(metrics_calculator, config, mode: str, wbs=None):
    """Logs final statistics and writes them to a CSV file."""
    logging.info("--------------------------------------------------------")
    logging.info("")
    logging.info("Final %s statistics", mode)
    logging.info("---------------------------")

    intervals = [calc_95_confidence_interval(cer_metric, metrics_calculator.n_items)
                 for cer_metric in metrics_calculator.total_stats[:-1]]

    for metric, total_value, interval in zip(metrics_calculator.metrics[:-1],
                                             metrics_calculator.total_stats[:-1],
                                             intervals):
        logging.info("%s = %.4f +/- %.4f", metric, total_value, interval)

    logging.info("Items = %s", metrics_calculator.total_stats[-1])
    logging.info("")

    output_file = os.path.join(config["output"], f"{mode}.csv")
    with open(output_file, 'w', encoding="utf-8") as f:
        header = "cer,cer_lower,cer_simple"
        if config["normalization_file"]:
            header += ",normalized_cer,normalized_cer_lower," \
                "normalized_cer_simple"
        if wbs:
            header += ",wbs_cer,wbs_cer_lower,wbs_cer_simple"

        f.write(header + "\n")
        results = ",".join([str(metrics_calculator.total_stats[i])
                            for i in range(len(metrics_calculator.metrics) - 1)])
        f.write(results + "\n")
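
With these three helpers, any evaluation-style mode reduces to the same skeleton. A hypothetical usage sketch (not part of this commit) wiring up a 'validation' run the same way perform_test does above, assuming the DataManager exposes a matching 'validation' split and ground truth, with MetricsCalculator arguments simply mirroring test.py:

    # Hypothetical example only; 'validation' split and perform_validation are assumed.
    from modes.utils import setup_workers, process_batches, output_statistics
    from utils.threading import MetricsCalculator
    from utils.wbs import setup_word_beam_search


    def perform_validation(config, model, data_manager):
        dataset = data_manager.datasets['validation']
        tokenizer = data_manager.tokenizer
        wbs = setup_word_beam_search(
            config, tokenizer.token_list) if config["corpus_file"] else None

        metrics_calculator = MetricsCalculator(config, log_results=False,
                                               maxsize=100)
        metrics_calculator.start()
        decode_workers = setup_workers(config, data_manager,
                                       metrics_calculator.queue, wbs)
        try:
            process_batches(dataset, model, config, data_manager,
                            decode_workers, 'validation')
        finally:
            for worker in decode_workers:
                worker.stop()
            metrics_calculator.stop()

        output_statistics(metrics_calculator, config, 'validation', wbs)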