From 6bd28e5f3dae11c8ce4035c2684781ba41f47f3a Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 18 Dec 2024 17:10:25 -0500 Subject: [PATCH] Support no CUDA version specification (#43) * Allow monobases to be created with no defined cuda version --- src/monobase/build.py | 24 +++++++++++++++++------- src/monobase/user.py | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/monobase/build.py b/src/monobase/build.py index 699f449..8a9f496 100644 --- a/src/monobase/build.py +++ b/src/monobase/build.py @@ -182,10 +182,11 @@ def assert_env(e: str) -> None: assert os.environ.get(e) is not None, f'{e} is required for mini mono' assert_env('R8_COG_VERSION') - assert_env('R8_CUDA_VERSION') - assert_env('R8_CUDNN_VERSION') assert_env('R8_PYTHON_VERSION') assert_env('R8_TORCH_VERSION') + if not args.skip_cuda: + assert_env('R8_CUDA_VERSION') + assert_env('R8_CUDNN_VERSION') assert ( args.cog_versions is None @@ -196,15 +197,24 @@ def assert_env(e: str) -> None: args.cog_versions = [os.environ['R8_COG_VERSION']] args.default_cog_version = os.environ['R8_COG_VERSION'] - def pick(d: dict[str, str], env: str) -> dict[str, str]: - key = os.environ[env] - return {key: d[key]} + def pick( + d: dict[str, str], env: str, fail_on_empty: bool = True + ) -> dict[str, str]: + try: + key = os.environ[env] + return {key: d[key]} + except KeyError: + if fail_on_empty: + raise + return {} monogens = [ MonoGen( id=mg.id, - cuda=pick(mg.cuda, 'R8_CUDA_VERSION'), - cudnn=pick(mg.cudnn, 'R8_CUDNN_VERSION'), + cuda=pick(mg.cuda, 'R8_CUDA_VERSION', fail_on_empty=not args.skip_cuda), + cudnn=pick( + mg.cudnn, 'R8_CUDNN_VERSION', fail_on_empty=not args.skip_cuda + ), python=pick(mg.python, 'R8_PYTHON_VERSION'), torch=[os.environ['R8_TORCH_VERSION']], pip_pkgs=mg.pip_pkgs, diff --git a/src/monobase/user.py b/src/monobase/user.py index 0bd47c6..c8600f0 100644 --- a/src/monobase/user.py +++ b/src/monobase/user.py @@ -35,7 +35,7 @@ def build_user_venv(args: argparse.Namespace) -> None: python_version = os.environ['R8_PYTHON_VERSION'] torch_version = os.environ['R8_TORCH_VERSION'] - cuda_version = os.environ['R8_CUDA_VERSION'] + cuda_version = os.environ.get('R8_CUDA_VERSION', 'cpu') uv = os.path.join(args.prefix, 'bin', 'uv')