diff --git a/src/autoembedder/evaluator.py b/src/autoembedder/evaluator.py
index fde110d..2f2b7ca 100644
--- a/src/autoembedder/evaluator.py
+++ b/src/autoembedder/evaluator.py
@@ -37,9 +37,11 @@ def _predict(
     device = torch.device(
         "cuda"
         if torch.cuda.is_available()
-        else "mps"
-        if torch.backends.mps.is_available() and parameters.get("use_mps", False)
-        else "cpu"
+        else (
+            "mps"
+            if torch.backends.mps.is_available() and parameters.get("use_mps", False)
+            else "cpu"
+        )
     )

     with torch.no_grad():
diff --git a/src/autoembedder/learner.py b/src/autoembedder/learner.py
index 3bc5610..cf2533b 100644
--- a/src/autoembedder/learner.py
+++ b/src/autoembedder/learner.py
@@ -362,9 +362,12 @@ def fit(
         torch.device(
             "cuda"
             if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available() and parameters.get("use_mps", False)
-            else "cpu"
+            else (
+                "mps"
+                if torch.backends.mps.is_available()
+                and parameters.get("use_mps", False)
+                else "cpu"
+            )
         )
     )
     if (
@@ -437,10 +440,12 @@
             map_location=torch.device(
                 "cuda"
                 if torch.cuda.is_available()
-                else "mps"
-                if torch.backends.mps.is_available()
-                and parameters.get("use_mps", False)
-                else "cpu"
+                else (
+                    "mps"
+                    if torch.backends.mps.is_available()
+                    and parameters.get("use_mps", False)
+                    else "cpu"
+                )
             ),
         )
         Checkpoint.load_objects(
diff --git a/src/autoembedder/model.py b/src/autoembedder/model.py
index 5a1296a..700ea2f 100644
--- a/src/autoembedder/model.py
+++ b/src/autoembedder/model.py
@@ -58,9 +58,11 @@ def model_input(
     device = torch.device(
         "cuda"
         if torch.cuda.is_available()
-        else "mps"
-        if torch.backends.mps.is_available() and parameters.get("use_mps", False)
-        else "cpu"
+        else (
+            "mps"
+            if torch.backends.mps.is_available() and parameters.get("use_mps", False)
+            else "cpu"
+        )
     )
     cat = []
     cont = []
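
The same device-selection expression is touched in all three files, and the change is purely formatting: the nested conditional is wrapped in parentheses, but the device that gets picked is unchanged. For reference, a minimal sketch of the reformatted expression pulled out into a standalone helper — `_select_device` is a hypothetical name introduced here for illustration and is not part of the patch, and `parameters` is assumed to be the same plain dict the existing functions already receive:

```python
from typing import Any, Dict

import torch


def _select_device(parameters: Dict[str, Any]) -> torch.device:
    """Prefer CUDA, then MPS (only when `use_mps` is truthy), else CPU.

    Hypothetical helper mirroring the conditional expression reformatted in
    this diff; the project itself keeps the expression inline in each function.
    """
    return torch.device(
        "cuda"
        if torch.cuda.is_available()
        else (
            "mps"
            if torch.backends.mps.is_available() and parameters.get("use_mps", False)
            else "cpu"
        )
    )


if __name__ == "__main__":
    # On a machine with neither CUDA nor MPS available this prints "cpu".
    print(_select_device({"use_mps": False}))
```

The added parentheses only group the inner `"mps" if ... else "cpu"` branch explicitly; evaluation order and results are identical to the pre-change code.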