diff --git a/TrackToLearn/algorithms/ddpg.py b/TrackToLearn/algorithms/ddpg.py
index e79eacd..67cb4fb 100644
--- a/TrackToLearn/algorithms/ddpg.py
+++ b/TrackToLearn/algorithms/ddpg.py
@@ -11,7 +11,7 @@
 from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer
 from TrackToLearn.algorithms.shared.utils import add_item_to_means
 from TrackToLearn.environments.env import BaseEnv
-
+from TrackToLearn.utils.torch_utils import get_device

 class DDPG(RLAlgorithm):
     """
@@ -44,7 +44,7 @@ def __init__(
         batch_size: int = 2**12,
         replay_size: int = 1e6,
         rng: np.random.RandomState = None,
-        device: torch.device = "cuda:0",
+        device: torch.device = get_device(),
     ):
         """
         Parameters
diff --git a/TrackToLearn/algorithms/rl.py b/TrackToLearn/algorithms/rl.py
index 5029d3b..1d79b7f 100644
--- a/TrackToLearn/algorithms/rl.py
+++ b/TrackToLearn/algorithms/rl.py
@@ -2,7 +2,7 @@
 import torch

 from TrackToLearn.environments.env import BaseEnv
-
+from TrackToLearn.utils.torch_utils import get_device

 class RLAlgorithm(object):
     """
@@ -18,7 +18,7 @@ def __init__(
         gamma: float = 0.99,
         batch_size: int = 10000,
         rng: np.random.RandomState = None,
-        device: torch.device = "cuda:0",
+        device: torch.device = get_device(),
     ):
         """
         Parameters
diff --git a/TrackToLearn/algorithms/sac.py b/TrackToLearn/algorithms/sac.py
index ba1d00e..cff8b26 100644
--- a/TrackToLearn/algorithms/sac.py
+++ b/TrackToLearn/algorithms/sac.py
@@ -7,7 +7,7 @@
 from TrackToLearn.algorithms.ddpg import DDPG
 from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
 from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer
-
+from TrackToLearn.utils.torch_utils import get_device

 class SAC(DDPG):
     """
@@ -41,7 +41,7 @@ def __init__(
         batch_size: int = 2**12,
         replay_size: int = 1e6,
         rng: np.random.RandomState = None,
-        device: torch.device = "cuda:0",
+        device: torch.device = get_device(),
     ):
         """ Initialize the algorithm. This includes the replay buffer,
         the policy and the target policy.
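Note on the `get_device()` defaults above: a Python default argument is
evaluated once, when the `def` statement runs at import time, not on every
call. That is harmless here, since device availability does not change within
a process, but it is worth keeping in mind. A minimal illustration (the
function `f` is hypothetical, not part of the patch):

    def f(device=get_device()):  # resolved at import time, frozen thereafter
        return device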
diff --git a/TrackToLearn/algorithms/sac_auto.py b/TrackToLearn/algorithms/sac_auto.py
index 80f860f..d669d71 100644
--- a/TrackToLearn/algorithms/sac_auto.py
+++ b/TrackToLearn/algorithms/sac_auto.py
@@ -9,7 +9,7 @@
 from TrackToLearn.algorithms.sac import SAC
 from TrackToLearn.algorithms.shared.offpolicy import SACActorCritic
 from TrackToLearn.algorithms.shared.replay import OffPolicyReplayBuffer
-
+from TrackToLearn.utils.torch_utils import get_device

 LOG_STD_MAX = 2
 LOG_STD_MIN = -20
@@ -46,7 +46,7 @@ def __init__(
         batch_size: int = 2**12,
         replay_size: int = 1e6,
         rng: np.random.RandomState = None,
-        device: torch.device = "cuda:0",
+        device: torch.device = get_device(),
     ):
         """
         Parameters
diff --git a/TrackToLearn/algorithms/shared/replay.py b/TrackToLearn/algorithms/shared/replay.py
index c55fde4..dbd558d 100644
--- a/TrackToLearn/algorithms/shared/replay.py
+++ b/TrackToLearn/algorithms/shared/replay.py
@@ -2,9 +2,9 @@
 import torch

 from typing import Tuple
+from TrackToLearn.utils.torch_utils import get_device, get_device_str

-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = get_device()


 class OffPolicyReplayBuffer(object):
@@ -33,16 +33,25 @@ def __init__(
         self.size = 0

         # Buffers "filled with zeros"
+
+
         self.state = torch.zeros(
-            (self.max_size, state_dim), dtype=torch.float32).pin_memory()
+            (self.max_size, state_dim), dtype=torch.float32)
         self.action = torch.zeros(
-            (self.max_size, action_dim), dtype=torch.float32).pin_memory()
+            (self.max_size, action_dim), dtype=torch.float32)
         self.next_state = torch.zeros(
-            (self.max_size, state_dim), dtype=torch.float32).pin_memory()
+            (self.max_size, state_dim), dtype=torch.float32)
         self.reward = torch.zeros(
-            (self.max_size, 1), dtype=torch.float32).pin_memory()
+            (self.max_size, 1), dtype=torch.float32)
         self.not_done = torch.zeros(
-            (self.max_size, 1), dtype=torch.float32).pin_memory()
+            (self.max_size, 1), dtype=torch.float32)
+
+        if get_device_str() == "cuda":
+            self.state = self.state.pin_memory()
+            self.action = self.action.pin_memory()
+            self.next_state = self.next_state.pin_memory()
+            self.reward = self.reward.pin_memory()
+            self.not_done = self.not_done.pin_memory()

     def add(
         self,
@@ -112,12 +121,19 @@ def sample(
         ind = torch.randperm(self.size, dtype=torch.long)[
             :min(self.size, batch_size)]

-        s = self.state.index_select(0, ind).pin_memory()
-        a = self.action.index_select(0, ind).pin_memory()
-        ns = self.next_state.index_select(0, ind).pin_memory()
-        r = self.reward.index_select(0, ind).squeeze(-1).pin_memory()
+        s = self.state.index_select(0, ind)
+        a = self.action.index_select(0, ind)
+        ns = self.next_state.index_select(0, ind)
+        r = self.reward.index_select(0, ind).squeeze(-1)
         d = self.not_done.index_select(0, ind).to(
-            dtype=torch.float32).squeeze(-1).pin_memory()
+            dtype=torch.float32).squeeze(-1)
+
+        if get_device_str() == "cuda":
+            s = s.pin_memory()
+            a = a.pin_memory()
+            ns = ns.pin_memory()
+            r = r.pin_memory()
+            d = d.pin_memory()

         # Return tensors on the same device as the buffer in pinned memory
         return (s.to(device=self.device, non_blocking=True),
diff --git a/TrackToLearn/environments/env.py b/TrackToLearn/environments/env.py
index 163bc69..a241abd 100644
--- a/TrackToLearn/environments/env.py
+++ b/TrackToLearn/environments/env.py
@@ -34,6 +34,8 @@
 # from dipy.io.utils import get_reference_info

+def collate_fn(data):
+    return data

 class BaseEnv(object):
     """
@@ -84,9 +86,6 @@ def __init__(
         self.dataset_file = subject_data
         self.split = split_id

-        def collate_fn(data):
-            return data
-
         self.dataset = SubjectDataset(
             self.dataset_file, self.split)
         self.loader = DataLoader(self.dataset, 1, shuffle=True,
diff --git a/TrackToLearn/oracles/oracle.py b/TrackToLearn/oracles/oracle.py
index f33d1ef..0614c8f 100644
--- a/TrackToLearn/oracles/oracle.py
+++ b/TrackToLearn/oracles/oracle.py
@@ -3,7 +3,10 @@
 from dipy.tracking.streamline import set_number_of_points

 from TrackToLearn.oracles.transformer_oracle import TransformerOracle
+from TrackToLearn.utils.torch_utils import get_device_str, get_device
+import contextlib

+autocast_context = torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext

 class OracleSingleton:
     _self = None
@@ -15,7 +18,7 @@ def __new__(cls, *args, **kwargs):
         return cls._self

     def __init__(self, checkpoint: str, device: str, batch_size=4096):
-        self.checkpoint = torch.load(checkpoint)
+        self.checkpoint = torch.load(checkpoint, map_location=get_device())

         hyper_parameters = self.checkpoint["hyper_parameters"]
         # The model's class is saved in hparams
@@ -38,7 +41,7 @@ def predict(self, streamlines):
         N = len(streamlines)
         # Placeholders for input and output data
         placeholder = torch.zeros(
-            (self.batch_size, 127, 3), pin_memory=True)
+            (self.batch_size, 127, 3), pin_memory=get_device_str() == "cuda")
         result = torch.zeros((N), dtype=torch.float, device=self.device)

         # Get the first batch
@@ -70,7 +73,7 @@ def predict(self, streamlines):
         # Put the directions in pinned memory
         placeholder[:end-start] = torch.from_numpy(dirs)

-        with torch.cuda.amp.autocast():
+        with autocast_context():
             with torch.no_grad():
                 predictions = self.model(input_data)
             result[
diff --git a/TrackToLearn/runners/ttl_track.py b/TrackToLearn/runners/ttl_track.py
index 777b870..04fa720 100755
--- a/TrackToLearn/runners/ttl_track.py
+++ b/TrackToLearn/runners/ttl_track.py
@@ -23,6 +23,7 @@
 from TrackToLearn.experiment.experiment import Experiment
 from TrackToLearn.tracking.tracker import Tracker
+from TrackToLearn.utils.torch_utils import get_device

 # Define the example model paths from the install folder.
 # Hackish ? I'm not aware of a better solution but I'm
@@ -79,9 +80,7 @@ def __init__(
         self.compute_reward = False
         self.render = False

-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available()
-            else "cpu")
+        self.device = get_device()

         self.fa_map = None
         if 'fa_map_file' in track_dto:
diff --git a/TrackToLearn/runners/ttl_track_from_hdf5.py b/TrackToLearn/runners/ttl_track_from_hdf5.py
index 9f70818..38ba2c4 100755
--- a/TrackToLearn/runners/ttl_track_from_hdf5.py
+++ b/TrackToLearn/runners/ttl_track_from_hdf5.py
@@ -23,7 +23,7 @@
     add_tractometer_args)
 from TrackToLearn.tracking.tracker import Tracker
 from TrackToLearn.experiment.experiment import Experiment
-
+from TrackToLearn.utils.torch_utils import get_device

 class TrackToLearnValidation(Experiment):
     """ TrackToLearn validing script.
     Should work on any model trained with a
@@ -98,8 +98,7 @@ def __init__(
         self.comet_experiment = None

-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
+        self.device = get_device()

         self.random_seed = valid_dto['rng_seed']
         torch.manual_seed(self.random_seed)
diff --git a/TrackToLearn/searchers/sac_auto_searcher.py b/TrackToLearn/searchers/sac_auto_searcher.py
index d3ae92d..3622f80 100644
--- a/TrackToLearn/searchers/sac_auto_searcher.py
+++ b/TrackToLearn/searchers/sac_auto_searcher.py
@@ -5,9 +5,9 @@
 from TrackToLearn.trainers.sac_auto_train import (
     parse_args, SACAutoTrackToLearnTraining)

-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-assert torch.cuda.is_available()
+from TrackToLearn.utils.torch_utils import get_device, assert_accelerator
+device = get_device()
+assert_accelerator()


 def main():
diff --git a/TrackToLearn/searchers/sac_auto_searcher_oracle.py b/TrackToLearn/searchers/sac_auto_searcher_oracle.py
index 5b7fd4f..e0b749a 100644
--- a/TrackToLearn/searchers/sac_auto_searcher_oracle.py
+++ b/TrackToLearn/searchers/sac_auto_searcher_oracle.py
@@ -5,9 +5,9 @@
 from TrackToLearn.trainers.sac_auto_train import (
     parse_args, SACAutoTrackToLearnTraining)

-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-assert torch.cuda.is_available()
+from TrackToLearn.utils.torch_utils import get_device, assert_accelerator
+device = get_device()
+assert_accelerator()


 def main():
diff --git a/TrackToLearn/trainers/ddpg_train.py b/TrackToLearn/trainers/ddpg_train.py
index 14996e9..aa43c9c 100644
--- a/TrackToLearn/trainers/ddpg_train.py
+++ b/TrackToLearn/trainers/ddpg_train.py
@@ -9,9 +9,10 @@
 from TrackToLearn.algorithms.ddpg import DDPG
 from TrackToLearn.experiment.train import (
     add_training_args, TrackToLearnTraining)
+from TrackToLearn.utils.torch_utils import get_device, assert_accelerator

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-assert torch.cuda.is_available()
+device = get_device()
+assert_accelerator()


 class DDPGTrackToLearnTraining(TrackToLearnTraining):
diff --git a/TrackToLearn/trainers/sac_auto_train.py b/TrackToLearn/trainers/sac_auto_train.py
index bbe7702..094b576 100755
--- a/TrackToLearn/trainers/sac_auto_train.py
+++ b/TrackToLearn/trainers/sac_auto_train.py
@@ -12,8 +12,8 @@
 from TrackToLearn.algorithms.sac_auto import SACAuto
 from TrackToLearn.trainers.train import (TrackToLearnTraining,
                                          add_training_args)
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from TrackToLearn.utils.torch_utils import get_device
+device = get_device()


 class SACAutoTrackToLearnTraining(TrackToLearnTraining):
diff --git a/TrackToLearn/trainers/sac_train.py b/TrackToLearn/trainers/sac_train.py
index dedf04f..cc7a3d3 100644
--- a/TrackToLearn/trainers/sac_train.py
+++ b/TrackToLearn/trainers/sac_train.py
@@ -16,9 +16,9 @@
 from TrackToLearn.experiment.train import (
     add_rl_args, TrackToLearnTraining)

-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-assert torch.cuda.is_available()
+from TrackToLearn.utils.torch_utils import get_device, assert_accelerator
+device = get_device()
+assert_accelerator()


 class SACTrackToLearnTraining(TrackToLearnTraining):
diff --git a/TrackToLearn/trainers/td3_train.py b/TrackToLearn/trainers/td3_train.py
index 690fbe2..b5832b3 100644
--- a/TrackToLearn/trainers/td3_train.py
+++ b/TrackToLearn/trainers/td3_train.py
@@ -16,9 +16,10 @@
 from TrackToLearn.experiment.train import (
     add_rl_args, TrackToLearnTraining)
+from TrackToLearn.utils.torch_utils import get_device, assert_accelerator

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-assert torch.cuda.is_available()
+device = get_device()
+assert_accelerator()


 class TD3TrackToLearnTraining(TrackToLearnTraining):
diff --git a/TrackToLearn/trainers/train.py b/TrackToLearn/trainers/train.py
index d33dd46..3f80c17 100644
--- a/TrackToLearn/trainers/train.py
+++ b/TrackToLearn/trainers/train.py
@@ -21,6 +21,7 @@
 from TrackToLearn.experiment.tractometer_validator import TractometerValidator
 from TrackToLearn.experiment.experiment import Experiment
 from TrackToLearn.tracking.tracker import Tracker
+from TrackToLearn.utils.torch_utils import get_device, assert_accelerator


 class TrackToLearnTraining(Experiment):
@@ -99,8 +100,8 @@ def __init__(
         self.comet_experiment = comet_experiment
         self.last_episode = 0

-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
+        self.device = get_device()
+
         self.use_comet = train_dto['use_comet']

@@ -353,8 +354,8 @@ def run(self):
         """
         training loop
         """

-        assert torch.cuda.is_available(), \
-            "Training is only supported on CUDA devices."
+        assert_accelerator(
+            "Training is only supported with hardware accelerated devices.")

         # Instantiate environment. Actions will be fed to it and new
         # states will be returned. The environment updates the streamline
diff --git a/TrackToLearn/utils/torch_utils.py b/TrackToLearn/utils/torch_utils.py
new file mode 100644
index 0000000..6e78ba4
--- /dev/null
+++ b/TrackToLearn/utils/torch_utils.py
@@ -0,0 +1,15 @@
+import torch
+
+def get_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
+
+def assert_accelerator(msg="A CUDA or MPS device is required."):
+    assert torch.cuda.is_available() or torch.backends.mps.is_available(), msg
+
+def get_device_str():
+    return str(get_device())
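For reference, a minimal usage sketch of the new helpers, mirroring the
patterns introduced above (the tensor shape, batch size, and variable names
are illustrative, not taken from the patch):

    import contextlib

    import torch

    from TrackToLearn.utils.torch_utils import (
        assert_accelerator, get_device, get_device_str)

    assert_accelerator()   # raises unless CUDA or MPS is present
    device = get_device()  # resolves to CUDA, then MPS, then CPU

    # Pinned (page-locked) memory only speeds up host-to-GPU copies,
    # so it is only requested when the resolved device is CUDA.
    batch = torch.zeros((4096, 127, 3), dtype=torch.float32,
                        pin_memory=get_device_str() == "cuda")

    # Autocast is CUDA-only in this codebase; elsewhere a no-op
    # context manager stands in for it.
    autocast_context = torch.cuda.amp.autocast \
        if torch.cuda.is_available() else contextlib.nullcontext

    with autocast_context():
        with torch.no_grad():
            out = batch.to(device=device, non_blocking=True).sum()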