Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Should be merged after: #40
Allow the support of Metal Performance Shaders (MPS) backend to train on a Mac M1. Whilst this is not meant to perform a complete training (could still be possible), this is meant to at least be able to debug and run the training process without access to CUDA.
I basically changed everything that was hardcoded "cuda" or "cuda:0" by utility functions so that we can add multiple backends easily. However, to actually use MPS to train a model, PyTorch 2.4.0 will be needed once it's out (there was a known bug that just got fixed). For now, we can just install a recent nighty build (torch==2.4.0.dev20240520, torchaudio==2.2.0.dev20240520, torchvision==0.19.0.dev20240520. With index: https://download.pytorch.org/whl/nightly/) manually and everything should work fine.
So:
pip install torch==2.4.0.dev20240520 torchaudio==2.2.0.dev20240520 torchvision==0.19.0.dev20240520 --index=https://download.pytorch.org/whl/nightly/
I propose we merge this to support MPS everywhere in the code and so it's available for all subsequent branches, and if we want to use it locally we can just install the torch nightly build.
I'll take note of upgrading to torch 2.4.0 once the official release is out in late July.