Merge pull request #121 from wolny/imbalanced_labels
Refactor and improvements
wolny authored Jan 6, 2025
2 parents 3d67c4e + 9551cef commit fd774f9
Showing 20 changed files with 534 additions and 255 deletions.
1 change: 1 addition & 0 deletions conda-recipe/meta.yaml
@@ -17,6 +17,7 @@ requirements:
build:
- python >=3.9
- pip
- setuptools

run:
- python >=3.9
41 changes: 34 additions & 7 deletions pytorch3dunet/augment/transforms.py
@@ -4,7 +4,7 @@
import numpy as np
import torch
from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve
from skimage import measure
from skimage import measure, exposure
from skimage.filters import gaussian
from skimage.segmentation import find_boundaries

@@ -133,6 +133,27 @@ def __call__(self, m):
return m


class RandomGammaCorrection:
"""
Adjust contrast by rescaling intensities to [0, 1] and raising each voxel to `v ** gamma`.
"""

def __init__(self, random_state, gamma=(0.5, 1.5), execution_probability=0.1, **kwargs):
self.random_state = random_state
assert len(gamma) == 2
self.gamma = gamma
self.execution_probability = execution_probability

def __call__(self, m):
if self.random_state.uniform() < self.execution_probability:
# rescale intensity values to [0, 1]
m = exposure.rescale_intensity(m, out_range=(0, 1))
gamma = self.random_state.uniform(self.gamma[0], self.gamma[1])
return exposure.adjust_gamma(m, gamma)

return m
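
For context, a minimal sketch of how the new RandomGammaCorrection transform could be exercised on its own (the seed, patch size, and parameter values below are illustrative, not taken from any repository config):

import numpy as np
from pytorch3dunet.augment.transforms import RandomGammaCorrection

# illustrative 3D patch and fixed seed; execution_probability=1.0 forces the augmentation
random_state = np.random.RandomState(47)
patch = random_state.rand(32, 64, 64).astype(np.float32)

transform = RandomGammaCorrection(random_state, gamma=(0.5, 1.5), execution_probability=1.0)
augmented = transform(patch)

# intensities are rescaled to [0, 1] before v ** gamma is applied, so the output stays in [0, 1]
assert 0.0 <= augmented.min() <= augmented.max() <= 1.0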


# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader
# remember to use spline_order=0 when transforming the labels
class ElasticDeformation:
@@ -576,12 +597,12 @@ def __call__(self, m):
# if self.min_value / self.max_value are provided, copy any entry that is not 'None'
# into min_value / max_value
if self.min_value is not None:
for i,v in enumerate(self.min_value):
for i, v in enumerate(self.min_value):
if v != 'None':
min_value[i] = v

if self.max_value is not None:
for i,v in enumerate(self.max_value):
for i, v in enumerate(self.max_value):
if v != 'None':
max_value[i] = v
else:
@@ -600,9 +621,9 @@ def __call__(self, m):
norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)

if self.norm01 is True:
return np.clip(norm_0_1, 0, 1)
else:
return np.clip(2 * norm_0_1 - 1, -1, 1)
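
As a quick numeric check of the two output ranges produced above (the input values are illustrative):

import numpy as np

# worked example of the two clipping branches (illustrative values)
m = np.array([10.0, 20.0, 30.0])
min_value, max_value, eps = 10.0, 30.0, 1e-10

norm_0_1 = (m - min_value) / (max_value - min_value + eps)
print(np.clip(norm_0_1, 0, 1))           # ~[0.0, 0.5, 1.0]  -> norm01 is True
print(np.clip(2 * norm_0_1 - 1, -1, 1))  # ~[-1.0, 0.0, 1.0] -> norm01 is False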


class AdditiveGaussianNoise:
@@ -640,18 +661,24 @@ class ToTensor:
Args:
expand_dims (bool): if True, adds a channel dimension to the input data
dtype (np.dtype): the desired output data type
normalize (bool): zero-one normalization of the input data
"""

def __init__(self, expand_dims, dtype=np.float32, **kwargs):
def __init__(self, expand_dims, dtype=np.float32, normalize=False, **kwargs):
self.expand_dims = expand_dims
self.dtype = dtype
self.normalize = normalize

def __call__(self, m):
assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
# add channel dimension
if self.expand_dims and m.ndim == 3:
m = np.expand_dims(m, axis=0)

if self.normalize:
# avoid division by zero
m = (m - np.min(m)) / (np.max(m) - np.min(m) + 1e-10)

return torch.from_numpy(m.astype(dtype=self.dtype))
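
A minimal sketch of the new normalize flag in action (the patch shape and intensity range are made up for illustration):

import numpy as np
from pytorch3dunet.augment.transforms import ToTensor

# illustrative 3D patch with an arbitrary intensity range
patch = np.random.uniform(low=-500.0, high=1500.0, size=(16, 32, 32))

to_tensor = ToTensor(expand_dims=True, normalize=True)
tensor = to_tensor(patch)

# expand_dims adds a channel axis -> (1, 16, 32, 32);
# normalize=True min-max scales the values to roughly [0, 1]
print(tensor.shape, float(tensor.min()), float(tensor.max()))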


@@ -706,7 +733,7 @@ def __call__(self, m):


class GaussianBlur3D:
def __init__(self, sigma=[.1, 2.], execution_probability=0.5, **kwargs):
def __init__(self, sigma=(.1, 2.), execution_probability=0.5, **kwargs):
self.sigma = sigma
self.execution_probability = execution_probability

97 changes: 64 additions & 33 deletions pytorch3dunet/datasets/hdf5.py
@@ -1,12 +1,13 @@
import glob
import os
from abc import abstractmethod
from concurrent.futures.process import ProcessPoolExecutor
from itertools import chain

import h5py

import pytorch3dunet.augment.transforms as transforms
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad, RandomScaler
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('HDF5Dataset')
@@ -44,10 +45,14 @@ class AbstractHDF5Dataset(ConfigDataset):
label_internal_path (str or list): H5 internal path to the label dataset
weight_internal_path (str or list): H5 internal path to the per pixel weights (optional)
global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset
random_scale (int): if not None, the raw data will be randomly shifted by a value in the range
[-random_scale, random_scale] in each dimension and then scaled to the original patch shape
random_scale_probability (float): probability of executing the random scale on a patch
"""

def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
label_internal_path='label', weight_internal_path=None, global_normalization=True):
label_internal_path='label', weight_internal_path=None, global_normalization=True,
random_scale=None, random_scale_probability=0.5):
assert phase in ['train', 'val', 'test']

self.phase = phase
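
For reference, a hedged sketch of how the two new options might be wired into a dataset configuration; the option keys come from this diff, while the slice-builder shapes and concrete values are illustrative only:

# hypothetical loader configuration illustrating the new options
dataset_config = {
    'raw_internal_path': 'raw',
    'label_internal_path': 'label',
    'global_normalization': True,
    # shift each training patch by up to +/-10 voxels per axis, then rescale it
    # back to patch_shape; random_scale must be smaller than every stride entry
    'random_scale': 10,
    'random_scale_probability': 0.5,
}

slice_builder_config = {
    'patch_shape': (80, 170, 170),
    'stride_shape': (40, 90, 90),
}

In the YAML training configs these keys would presumably sit alongside the existing slice_builder settings for the train loader.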
@@ -94,6 +99,10 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r

with h5py.File(file_path, 'r') as f:
raw = f[raw_internal_path]
if raw.ndim == 3:
self.volume_shape = raw.shape
else:
self.volume_shape = raw.shape[1:]
label = f[label_internal_path] if phase != 'test' else None
weight_map = f[weight_internal_path] if weight_internal_path is not None else None
# build slice indices for raw and label data sets
@@ -102,8 +111,18 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
self.label_slices = slice_builder.label_slices
self.weight_slices = slice_builder.weight_slices

if random_scale is not None:
assert isinstance(random_scale, int), 'random_scale must be an integer'
stride_shape = slice_builder_config.get('stride_shape')
assert all(random_scale < stride for stride in stride_shape), \
f"random_scale {random_scale} must be smaller than each of the strides {stride_shape}"
patch_shape = slice_builder_config.get('patch_shape')
self.random_scaler = RandomScaler(random_scale, patch_shape, self.volume_shape, random_scale_probability)
logger.info(f"Using RandomScaler with offset range {random_scale}")
else:
self.random_scaler = None

self.patch_count = len(self.raw_slices)
logger.info(f'Number of patches: {self.patch_count}')

@abstractmethod
def get_raw_patch(self, idx):
@@ -121,14 +140,6 @@ def get_weight_patch(self, idx):
def get_raw_padded_patch(self, idx):
raise NotImplementedError

def volume_shape(self):
with h5py.File(self.file_path, 'r') as f:
raw = f[self.raw_internal_path]
if raw.ndim == 3:
return raw.shape
else:
return raw.shape[1:]

def __getitem__(self, idx):
if idx >= len(self):
raise StopIteration
@@ -146,15 +157,24 @@ def __getitem__(self, idx):
raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
return raw_patch_transformed, raw_idx
else:
raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))

# get the slice for a given index 'idx'
label_idx = self.label_slices[idx]

if self.random_scaler is not None:
# randomize the indices
raw_idx, label_idx = self.random_scaler.randomize_indices(raw_idx, label_idx)

raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
if self.weight_internal_path is not None:
weight_idx = self.weight_slices[idx]
weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
return raw_patch_transformed, label_patch_transformed, weight_patch_transformed

if self.random_scaler is not None:
# scale patches back to the original patch size
raw_patch_transformed, label_patch_transformed = self.random_scaler.rescale_patches(
raw_patch_transformed, label_patch_transformed
)
# return the transformed raw and label patches
return raw_patch_transformed, label_patch_transformed

@@ -192,22 +212,31 @@ def create_datasets(cls, dataset_config, phase):
# are going to be included in the final file_paths
file_paths = traverse_h5_paths(file_paths)

datasets = []
for file_path in file_paths:
try:
# create datasets concurrently
with ProcessPoolExecutor() as executor:
futures = []
for file_path in file_paths:
logger.info(f'Loading {phase} set from: {file_path}...')
dataset = cls(file_path=file_path,
phase=phase,
slice_builder_config=slice_builder_config,
transformer_config=transformer_config,
raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
label_internal_path=dataset_config.get('label_internal_path', 'label'),
weight_internal_path=dataset_config.get('weight_internal_path', None),
global_normalization=dataset_config.get('global_normalization', None))
datasets.append(dataset)
except Exception:
logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
return datasets
future = executor.submit(cls, file_path=file_path,
phase=phase,
slice_builder_config=slice_builder_config,
transformer_config=transformer_config,
raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
label_internal_path=dataset_config.get('label_internal_path', 'label'),
weight_internal_path=dataset_config.get('weight_internal_path', None),
global_normalization=dataset_config.get('global_normalization', None),
random_scale=dataset_config.get('random_scale', None),
random_scale_probability=dataset_config.get('random_scale_probability', 0.5))
futures.append(future)

datasets = []
for future in futures:
try:
dataset = future.result()
datasets.append(dataset)
except Exception as e:
logger.error(f'Failed to load dataset: {e}')
return datasets
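
The create_datasets rewrite above swaps the sequential per-file loop for a ProcessPoolExecutor, so the datasets are built concurrently and a failing file no longer blocks the rest. A stripped-down sketch of the same submit/collect pattern, using a hypothetical build_dataset function in place of the dataset class:

from concurrent.futures import ProcessPoolExecutor

def build_dataset(file_path):
    # hypothetical stand-in for the dataset constructor; may raise on a bad file
    if file_path.endswith('broken.h5'):
        raise ValueError(f'cannot read {file_path}')
    return {'file_path': file_path}

if __name__ == '__main__':
    file_paths = ['train_a.h5', 'train_b.h5', 'broken.h5']
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(build_dataset, p) for p in file_paths]

    # exiting the 'with' block waits for all workers; failures surface via result()
    datasets = []
    for future in futures:
        try:
            datasets.append(future.result())
        except Exception as e:
            print(f'Failed to load dataset: {e}')
    print(f'{len(datasets)} of {len(file_paths)} datasets loaded')

Note that whatever is submitted must be picklable, since each dataset is constructed in a separate worker process.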


class StandardHDF5Dataset(AbstractHDF5Dataset):
@@ -218,11 +247,12 @@ class StandardHDF5Dataset(AbstractHDF5Dataset):

def __init__(self, file_path, phase, slice_builder_config, transformer_config,
raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
global_normalization=True):
global_normalization=True, random_scale=None, random_scale_probability=0.5):
super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)
global_normalization=global_normalization, random_scale=random_scale,
random_scale_probability=random_scale_probability)
self._raw = None
self._raw_padded = None
self._label = None
@@ -262,11 +292,12 @@ class LazyHDF5Dataset(AbstractHDF5Dataset):

def __init__(self, file_path, phase, slice_builder_config, transformer_config,
raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
global_normalization=False):
global_normalization=False, random_scale=None, random_scale_probability=0.5):
super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
transformer_config=transformer_config, raw_internal_path=raw_internal_path,
label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
global_normalization=global_normalization)
global_normalization=global_normalization, random_scale=random_scale,
random_scale_probability=random_scale_probability)

logger.info("Using LazyHDF5Dataset")

(diffs for the remaining 17 changed files not expanded)
