Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Metal Performance Shaders #41

Merged
merged 12 commits into from
May 29, 2024
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer
from TrackToLearn.algorithms.shared.utils import add_item_to_means
from TrackToLearn.environments.env import BaseEnv

from TrackToLearn.utils.torch_utils import get_device

class DDPG(RLAlgorithm):
"""
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
batch_size: int = 2**12,
replay_size: int = 1e6,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
"""
Parameters
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from TrackToLearn.environments.env import BaseEnv

from TrackToLearn.utils.torch_utils import get_device

class RLAlgorithm(object):
"""
Expand All @@ -18,7 +18,7 @@ def __init__(
gamma: float = 0.99,
batch_size: int = 10000,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
"""
Parameters
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from TrackToLearn.algorithms.ddpg import DDPG
from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer

from TrackToLearn.utils.torch_utils import get_device

class SAC(DDPG):
"""
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
batch_size: int = 2**12,
replay_size: int = 1e6,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device(),
):
""" Initialize the algorithm. This includes the replay buffer,
the policy and the target policy.
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/algorithms/sac_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from TrackToLearn.algorithms.sac import SAC
from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer

from TrackToLearn.utils.torch_utils import get_device

LOG_STD_MAX = 2
LOG_STD_MIN = -20
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
batch_size: int = 2**12,
replay_size: int = 1e6,
rng: np.random.RandomState = None,
device: torch.device = "cuda:0",
device: torch.device = get_device,
):
"""
Parameters
Expand Down
40 changes: 28 additions & 12 deletions TrackToLearn/algorithms/shared/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import torch

from typing import Tuple
from TrackToLearn.utils.torch_utils import get_device, get_device_str


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_device()


class OffPolicyReplayBuffer(object):
Expand Down Expand Up @@ -33,16 +33,25 @@ def __init__(
self.size = 0

# Buffers "filled with zeros"


self.state = torch.zeros(
(self.max_size, state_dim), dtype=torch.float32).pin_memory()
(self.max_size, state_dim), dtype=torch.float32)
self.action = torch.zeros(
(self.max_size, action_dim), dtype=torch.float32).pin_memory()
(self.max_size, action_dim), dtype=torch.float32)
self.next_state = torch.zeros(
(self.max_size, state_dim), dtype=torch.float32).pin_memory()
(self.max_size, state_dim), dtype=torch.float32)
self.reward = torch.zeros(
(self.max_size, 1), dtype=torch.float32).pin_memory()
(self.max_size, 1), dtype=torch.float32)
self.not_done = torch.zeros(
(self.max_size, 1), dtype=torch.float32).pin_memory()
(self.max_size, 1), dtype=torch.float32)

if get_device_str() == "cuda":
self.state = self.state.pin_memory()
self.action = self.action.pin_memory()
self.next_state = self.next_state.pin_memory()
self.reward = self.reward.pin_memory()
self.not_done = self.not_done.pin_memory()

def add(
self,
Expand Down Expand Up @@ -112,12 +121,19 @@ def sample(
ind = torch.randperm(self.size, dtype=torch.long)[
:min(self.size, batch_size)]

s = self.state.index_select(0, ind).pin_memory()
a = self.action.index_select(0, ind).pin_memory()
ns = self.next_state.index_select(0, ind).pin_memory()
r = self.reward.index_select(0, ind).squeeze(-1).pin_memory()
s = self.state.index_select(0, ind)
a = self.action.index_select(0, ind)
ns = self.next_state.index_select(0, ind)
r = self.reward.index_select(0, ind).squeeze(-1)
d = self.not_done.index_select(0, ind).to(
dtype=torch.float32).squeeze(-1).pin_memory()
dtype=torch.float32).squeeze(-1)

if get_device_str() == "cuda":
s = s.pin_memory()
a = a.pin_memory()
ns = ns.pin_memory()
r = r.pin_memory()
d = d.pin_memory()

# Return tensors on the same device as the buffer in pinned memory
return (s.to(device=self.device, non_blocking=True),
Expand Down
4 changes: 2 additions & 2 deletions TrackToLearn/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dipy.data import get_sphere
from dipy.reconst.csdeconv import sph_harm_ind_list
from scilpy.reconst.utils import get_sh_order_and_fullness
from scilpy.reconst.multi_processes import convert_sh_basis
from scilpy.reconst.sh import convert_sh_basis


class MRIDataVolume(object):
Expand Down Expand Up @@ -155,7 +155,7 @@ def set_sh_order_basis(
_, orders = sph_harm_ind_list(sh_order, full_basis)
sh = sh[..., orders % 2 == 0]

# If SH are not of order 6, convert them
# If SH are not of target order, convert them
if sh_order != target_order:
print('SH coefficients are of order {}, '
'converting them to order {}.'.format(sh_order, target_order))
Expand Down
40 changes: 30 additions & 10 deletions TrackToLearn/environments/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
interpolate_volume_in_neighborhood
from dwi_ml.data.processing.space.neighborhood import \
get_neighborhood_vectors_axes
from scilpy.reconst.utils import (find_order_from_nb_coeff, get_b_matrix,
from scilpy.reconst.utils import (find_order_from_nb_coeff,
get_maximas)
from dipy.reconst.shm import sh_to_sf_matrix
from torch.utils.data import DataLoader

from TrackToLearn.datasets.SubjectDataset import SubjectDataset
from TrackToLearn.datasets.utils import (MRIDataVolume,
convert_length_mm2vox,
set_sh_order_basis)
set_sh_order_basis,
get_sh_order_and_fullness)
from TrackToLearn.environments.local_reward import PeaksAlignmentReward
from TrackToLearn.environments.oracle_reward import OracleReward
from TrackToLearn.environments.reward import RewardFunction
Expand All @@ -32,6 +34,8 @@

# from dipy.io.utils import get_reference_info

def collate_fn(data):
    """Return the batch exactly as received, without any collation.

    NOTE(review): presumably passed to the ``DataLoader`` as its
    ``collate_fn`` so samples are not stacked into tensors — confirm
    against the call site.

    Parameters
    ----------
    data:
        The object handed over for collation.

    Returns
    -------
    data:
        The very same object, untouched.
    """
    return data

class BaseEnv(object):
"""
Expand Down Expand Up @@ -82,9 +86,6 @@ def __init__(
self.dataset_file = subject_data
self.split = split_id

def collate_fn(data):
return data

self.dataset = SubjectDataset(
self.dataset_file, self.split)
self.loader = DataLoader(self.dataset, 1, shuffle=True,
Expand Down Expand Up @@ -134,6 +135,7 @@ def collate_fn(data):
# Other parameters
self.rng = env_dto['rng']
self.device = env_dto['device']
self.target_sh_order = env_dto['target_sh_order']

# Load one subject as an example
self.load_subject()
Expand All @@ -156,7 +158,7 @@ def load_subject(
self.loader_iter = iter(self.loader)
(sub_id, input_volume, tracking_mask, seeding_mask,
peaks, reference) = next(self.loader_iter)[0]

self.subject_id = sub_id
# Affines
self.reference = reference
Expand All @@ -179,6 +181,13 @@ def load_subject(

self.reference = reference

# The SH target order is taken from the hyperparameters in the case of tracking.
# Otherwise, the SH target order is taken from the input volume by default.
if self.target_sh_order is None:
n_coefs = input_volume.shape[-1]
sh_order, _ = get_sh_order_and_fullness(n_coefs)
self.target_sh_order = sh_order

self.tracking_mask = tracking_mask
self.peaks = peaks
mask_data = tracking_mask.data.astype(np.uint8)
Expand Down Expand Up @@ -322,13 +331,15 @@ def from_files(
in_mask = env_dto['in_mask']
sh_basis = env_dto['sh_basis']
reference = env_dto['reference']
target_sh_order = env_dto['target_sh_order']

(input_volume, peaks_volume, tracking_mask, seeding_mask) = \
BaseEnv._load_files(
in_odf,
in_seed,
in_mask,
sh_basis)
sh_basis,
target_sh_order)

subj_files = (input_volume, tracking_mask, seeding_mask,
peaks_volume, reference)
Expand All @@ -342,6 +353,7 @@ def _load_files(
in_seed,
in_mask,
sh_basis,
target_sh_order=6,
):
""" Load data volumes and masks from files. This is useful for
tracking from a trained model.
Expand All @@ -360,6 +372,8 @@ def _load_files(
Path to the tracking mask.
sh_basis: str
Basis of the SH coefficients.
target_sh_order: int
Target SH order. Should come from the hyperparameters file.

Returns
-------
Expand All @@ -385,7 +399,7 @@ def _load_files(

data = set_sh_order_basis(signal.get_fdata(dtype=np.float32),
sh_basis,
target_order=8,
target_order=target_sh_order,
target_basis='descoteaux07')

# Compute peaks from signal
Expand All @@ -398,8 +412,7 @@ def _load_files(
sphere = HemiSphere.from_sphere(get_sphere("repulsion724")
).subdivide(0)

b_matrix = get_b_matrix(
find_order_from_nb_coeff(data), sphere, "descoteaux07")
b_matrix, _ = sh_to_sf_matrix(sphere, find_order_from_nb_coeff(data), "descoteaux07")

for idx in np.argwhere(np.sum(data, axis=-1)):
idx = tuple(idx)
Expand Down Expand Up @@ -454,6 +467,13 @@ def get_action_size(self):
"""

return 3

def get_target_sh_order(self):
    """ Returns the SH order the environment's input signal is set to.

    The stored value comes from `env_dto['target_sh_order']` when one was
    provided (e.g. from hyperparameters.json when tracking). Otherwise it
    is derived from the number of SH coefficients of the loaded input
    volume.

    Returns
    -------
    target_sh_order: int
        Target SH order stored on the environment.
    """

    return self.target_sh_order

def get_voxel_size(self):
""" Returns the voxel size by taking the mean value of the diagonal
Expand Down
12 changes: 8 additions & 4 deletions TrackToLearn/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def _get_env_dict_and_dto(
'tractometer_validator': self.tractometer_validator,
'binary_stopping_threshold': self.binary_stopping_threshold,
'compute_reward': self.compute_reward,
'device': self.device
'device': self.device,
'target_sh_order': self.target_sh_order if hasattr(self, 'target_sh_order') else None,
}

if noisy:
Expand Down Expand Up @@ -257,7 +258,8 @@ def save_rasmm_tractogram(
tractogram,
subject_id: str,
affine: np.ndarray,
reference: nib.Nifti1Image
reference: nib.Nifti1Image,
path_prefix: str = ''
) -> str:
"""
Saves a non-stateful tractogram from the training/validation
Expand All @@ -277,8 +279,9 @@ def save_rasmm_tractogram(
# Save tractogram so it can be looked at, used by the tractometer
# and more
filename = pjoin(
self.experiment_path, "tractogram_{}_{}_{}.trk".format(
self.experiment, self.name, subject_id))
path_prefix,
self.experiment_path,
"tractogram_{}_{}_{}.trk".format(self.experiment, self.name, subject_id))

# Prune empty streamlines, keep only streamlines that have more
# than the seed.
Expand Down Expand Up @@ -390,6 +393,7 @@ def add_experiment_args(parser: ArgumentParser):
help='Seed to fix general randomness')
parser.add_argument('--use_comet', action='store_true',
help='Use comet to display training or not')
parser.add_argument('--comet_offline_dir', type=str, help='Comet offline directory. If enabled, logs will be saved to this directory and the experiment will be ran offline.')


def add_data_args(parser: ArgumentParser):
Expand Down
9 changes: 6 additions & 3 deletions TrackToLearn/oracles/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from dipy.tracking.streamline import set_number_of_points

from TrackToLearn.oracles.transformer_oracle import TransformerOracle
from TrackToLearn.utils.torch_utils import get_device_str, get_device
import contextlib

autocast_context = torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext

class OracleSingleton:
_self = None
Expand All @@ -15,7 +18,7 @@ def __new__(cls, *args, **kwargs):
return cls._self

def __init__(self, checkpoint: str, device: str, batch_size=4096):
self.checkpoint = torch.load(checkpoint)
self.checkpoint = torch.load(checkpoint, map_location=get_device())

hyper_parameters = self.checkpoint["hyper_parameters"]
# The model's class is saved in hparams
Expand All @@ -38,7 +41,7 @@ def predict(self, streamlines):
N = len(streamlines)
# Placeholders for input and output data
placeholder = torch.zeros(
(self.batch_size, 127, 3), pin_memory=True)
(self.batch_size, 127, 3), pin_memory=get_device_str() == "cuda")
result = torch.zeros((N), dtype=torch.float, device=self.device)

# Get the first batch
Expand Down Expand Up @@ -70,7 +73,7 @@ def predict(self, streamlines):
# Put the directions in pinned memory
placeholder[:end-start] = torch.from_numpy(dirs)

with torch.cuda.amp.autocast():
with autocast_context():
with torch.no_grad():
predictions = self.model(input_data)
result[
Expand Down
Loading
Loading