diff --git a/src/monobase/build.py b/src/monobase/build.py index 8a9f496..42d514c 100644 --- a/src/monobase/build.py +++ b/src/monobase/build.py @@ -183,7 +183,6 @@ def assert_env(e: str) -> None: assert_env('R8_COG_VERSION') assert_env('R8_PYTHON_VERSION') - assert_env('R8_TORCH_VERSION') if not args.skip_cuda: assert_env('R8_CUDA_VERSION') assert_env('R8_CUDNN_VERSION') @@ -208,6 +207,12 @@ def pick( raise return {} + torch = [] + if 'R8_TORCH_VERSION' in os.environ: + torch.append(os.environ['R8_TORCH_VERSION']) + else: + torch.append(None) + monogens = [ MonoGen( id=mg.id, @@ -216,7 +221,7 @@ def pick( mg.cudnn, 'R8_CUDNN_VERSION', fail_on_empty=not args.skip_cuda ), python=pick(mg.python, 'R8_PYTHON_VERSION'), - torch=[os.environ['R8_TORCH_VERSION']], + torch=torch, pip_pkgs=mg.pip_pkgs, ) ] diff --git a/src/monobase/user.py b/src/monobase/user.py index c8600f0..3c77696 100644 --- a/src/monobase/user.py +++ b/src/monobase/user.py @@ -34,7 +34,7 @@ def build_user_venv(args: argparse.Namespace) -> None: log.info(f'Building user venv {udir}...') python_version = os.environ['R8_PYTHON_VERSION'] - torch_version = os.environ['R8_TORCH_VERSION'] + torch_version = os.environ.get('R8_TORCH_VERSION') cuda_version = os.environ.get('R8_CUDA_VERSION', 'cpu') uv = os.path.join(args.prefix, 'bin', 'uv') @@ -48,7 +48,11 @@ def build_user_venv(args: argparse.Namespace) -> None: cog_versions = parse_requirements(cog_req) gdir = os.path.realpath(os.path.join(args.prefix, 'monobase', 'latest')) - venv = f'python{python_version}-torch{torch_version}-{cuda_suffix(cuda_version)}' + venv_components = [f'python{python_version}'] + if torch_version is not None: + venv_components.append(f'torch{torch_version}') + venv_components.append(f'{cuda_suffix(cuda_version)}') + venv = '-'.join(venv_components) vdir = os.path.join(gdir, venv) log.info(f'Freezing monobase venv {vdir}...') mono_req = freeze(uv, vdir)