Skip to content

Commit

Permalink
Support Torch nightly builds
Browse files Browse the repository at this point in the history
  • Loading branch information
nevillelyh committed Oct 2, 2024
1 parent c9841eb commit e4ac360
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/monogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
cuda={'12.4': '12.4.1_550.54.15'},
cudnn={'9': '9.1.0.70'},
python={'3.12': '3.12.6'},
torch=['2.4.1'],
torch=['2.4.1', '2.6.0.dev20240918'],
pip_pkgs=['cog==0.9.23', 'opencv-python==4.10.0.84']
),
]
Expand Down
6 changes: 6 additions & 0 deletions src/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

# https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
torch_specs_dict = {
# Nightly
'2.6': ('3.11', '3.12', ['12.4']),
# Releases
'2.5': ('3.9', '3.12', ['11.8', '12.1', '12.4']),
'2.4': ('3.8', '3.12', ['11.8', '12.1', '12.4']),
'2.3': ('3.8', '3.11', ['11.8', '12.1']),
Expand All @@ -17,6 +20,9 @@
}

torch_deps_dict = {
# Nightly
'2.6.0.dev20240918': TorchDeps('2.5.0.dev20240918', '0.20.0.dev20240918'),
# Releases
'2.4.1': TorchDeps('2.4.1', '0.19.1'),
'2.4.0': TorchDeps('2.4.0', '0.19.0'),
'2.3.1': TorchDeps('2.3.1', '0.18.1'),
Expand Down
7 changes: 4 additions & 3 deletions src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
logger.addHandler(handler)


class Version(namedtuple('Version', ['major', 'minor', 'patch', 'repr'])):
p = re.compile(r'^(?P<major>\d+)(\.(?P<minor>\d+)(\.(?P<patch>\d+))?)?')
class Version(namedtuple('Version', ['major', 'minor', 'patch', 'extra','repr'])):
p = re.compile(r'^(?P<major>\d+)(\.(?P<minor>\d+)(\.(?P<patch>\d+)(\.(?P<extra>.+))?)?)?')

@classmethod
def parse(cls, s):
m = Version.p.search(s)
major = int(m.group('major'))
minor = int(m.group('minor') or 0)
patch = int(m.group('patch') or 0)
return cls(major, minor, patch, s)
extra = m.group('extra')
return cls(major, minor, patch, extra, s)

def __repr__(self):
return self.repr
Expand Down
14 changes: 9 additions & 5 deletions src/uv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ def cuda_suffix(cuda_version: str) -> str:
return f'cu{cuda_version.replace('.', '')}'


def pip_index_url(cuda_version: str):
return f'https://download.pytorch.org/whl/{cuda_suffix(cuda_version)}'
def pip_index_url(torch_version: Version, cuda_version: str):
prefix = 'https://download.pytorch.org/whl'
if torch_version.extra:
prefix = f'{prefix}/nightly'
return f'{prefix}/{cuda_suffix(cuda_version)}'


def pip_packages(torch_version: Version,
Expand Down Expand Up @@ -59,6 +62,7 @@ def update_venv(
subprocess.run(cmd, check=True)

logger.info(f'Running pip compile in {venv}...')
url = pip_index_url(t, cuda_version)
cmd = [
'docker', 'run', '-i', '--rm',
'--env', f'VIRTUAL_ENV={venv_dir}',
Expand All @@ -67,15 +71,15 @@ def update_venv(
'ghcr.io/astral-sh/uv:debian',
'uv', 'pip', 'compile',
'--python-platform', 'x86_64-unknown-linux-gnu',
'--extra-index-url', pip_index_url(cuda_version),
'--extra-index-url', url,
'--emit-index-url',
'--emit-find-links',
'--emit-build-options',
'--emit-index-annotation',
'-',
]
pkgs = pip_packages(t, cuda_version, pip_pkgs)
proc = subprocess.run(cmd, input='\n'.join(pkgs), capture_output=True, text=True)
proc = subprocess.run(cmd, check=True, input='\n'.join(pkgs), capture_output=True, text=True)

requirements = os.path.join(rdir, f'{venv}.txt')
with open(requirements, 'w') as f:
Expand Down Expand Up @@ -111,7 +115,7 @@ def install_venv(args: argparse.Namespace,

logger.info(f'Installing Torch {t} in {venv}...')

url = pip_index_url(cuda_version)
url = pip_index_url(t, cuda_version)
requirements = os.path.join(rdir, f'{venv}.txt')
cmd = [
uv, 'pip', 'install',
Expand Down

0 comments on commit e4ac360

Please sign in to comment.