Skip to content

Commit

Permalink
Support empty torch versions
Browse files Browse the repository at this point in the history
* Allow sending torch version as none
* If this is none then torch will not be installed
in the venv
* This is due to not all cog users wanting to
install torch, so if they don’t we may as well not
install it into the venv
  • Loading branch information
8W9aG authored and nevillelyh committed Jan 6, 2025
1 parent e526edf commit a0748a8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/monobase/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def assert_env(e: str) -> None:

assert_env('R8_COG_VERSION')
assert_env('R8_PYTHON_VERSION')
assert_env('R8_TORCH_VERSION')
if not args.skip_cuda:
assert_env('R8_CUDA_VERSION')
assert_env('R8_CUDNN_VERSION')
Expand All @@ -208,6 +207,12 @@ def pick(
raise
return {}

torch = []
if 'R8_TORCH_VERSION' in os.environ:
torch.append(os.environ['R8_TORCH_VERSION'])
else:
torch.append(None)

monogens = [
MonoGen(
id=mg.id,
Expand All @@ -216,7 +221,7 @@ def pick(
mg.cudnn, 'R8_CUDNN_VERSION', fail_on_empty=not args.skip_cuda
),
python=pick(mg.python, 'R8_PYTHON_VERSION'),
torch=[os.environ['R8_TORCH_VERSION']],
torch=torch,
pip_pkgs=mg.pip_pkgs,
)
]
Expand Down
8 changes: 6 additions & 2 deletions src/monobase/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def build_user_venv(args: argparse.Namespace) -> None:
log.info(f'Building user venv {udir}...')

python_version = os.environ['R8_PYTHON_VERSION']
torch_version = os.environ['R8_TORCH_VERSION']
torch_version = os.environ.get('R8_TORCH_VERSION')
cuda_version = os.environ.get('R8_CUDA_VERSION', 'cpu')

uv = os.path.join(args.prefix, 'bin', 'uv')
Expand All @@ -48,7 +48,11 @@ def build_user_venv(args: argparse.Namespace) -> None:
cog_versions = parse_requirements(cog_req)

gdir = os.path.realpath(os.path.join(args.prefix, 'monobase', 'latest'))
venv = f'python{python_version}-torch{torch_version}-{cuda_suffix(cuda_version)}'
venv_components = [f'python{python_version}']
if torch_version is not None:
venv_components.append(f'torch{torch_version}')
venv_components.append(f'{cuda_suffix(cuda_version)}')
venv = '-'.join(venv_components)
vdir = os.path.join(gdir, venv)
log.info(f'Freezing monobase venv {vdir}...')
mono_req = freeze(uv, vdir)
Expand Down

0 comments on commit a0748a8

Please sign in to comment.