Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 29, 2024
1 parent fcd5467 commit 33d5a52
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
8 changes: 5 additions & 3 deletions src/autoembedder/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ def _predict(
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
else "cpu"
else (
"mps"
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
else "cpu"
)
)

with torch.no_grad():
Expand Down
19 changes: 12 additions & 7 deletions src/autoembedder/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,12 @@ def fit(
torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
else "cpu"
else (
"mps"
if torch.backends.mps.is_available()
and parameters.get("use_mps", False)
else "cpu"
)
)
)
if (
Expand Down Expand Up @@ -437,10 +440,12 @@ def fit(
map_location=torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
and parameters.get("use_mps", False)
else "cpu"
else (
"mps"
if torch.backends.mps.is_available()
and parameters.get("use_mps", False)
else "cpu"
)
),
)
Checkpoint.load_objects(
Expand Down
8 changes: 5 additions & 3 deletions src/autoembedder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ def model_input(
device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
else "cpu"
else (
"mps"
if torch.backends.mps.is_available() and parameters.get("use_mps", False)
else "cpu"
)
)
cat = []
cont = []
Expand Down

0 comments on commit 33d5a52

Please sign in to comment.