Skip to content

Commit

Permalink
Support no CUDA version specification (#43)
Browse files Browse the repository at this point in the history
* Allow monobases to be created with no defined
cuda version
  • Loading branch information
8W9aG authored Dec 18, 2024
1 parent 4bbd34e commit 6bd28e5
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
24 changes: 17 additions & 7 deletions src/monobase/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,11 @@ def assert_env(e: str) -> None:
assert os.environ.get(e) is not None, f'{e} is required for mini mono'

assert_env('R8_COG_VERSION')
assert_env('R8_CUDA_VERSION')
assert_env('R8_CUDNN_VERSION')
assert_env('R8_PYTHON_VERSION')
assert_env('R8_TORCH_VERSION')
if not args.skip_cuda:
assert_env('R8_CUDA_VERSION')
assert_env('R8_CUDNN_VERSION')

assert (
args.cog_versions is None
Expand All @@ -196,15 +197,24 @@ def assert_env(e: str) -> None:
args.cog_versions = [os.environ['R8_COG_VERSION']]
args.default_cog_version = os.environ['R8_COG_VERSION']

def pick(d: dict[str, str], env: str) -> dict[str, str]:
key = os.environ[env]
return {key: d[key]}
def pick(
d: dict[str, str], env: str, fail_on_empty: bool = True
) -> dict[str, str]:
try:
key = os.environ[env]
return {key: d[key]}
except KeyError:
if fail_on_empty:
raise
return {}

monogens = [
MonoGen(
id=mg.id,
cuda=pick(mg.cuda, 'R8_CUDA_VERSION'),
cudnn=pick(mg.cudnn, 'R8_CUDNN_VERSION'),
cuda=pick(mg.cuda, 'R8_CUDA_VERSION', fail_on_empty=not args.skip_cuda),
cudnn=pick(
mg.cudnn, 'R8_CUDNN_VERSION', fail_on_empty=not args.skip_cuda
),
python=pick(mg.python, 'R8_PYTHON_VERSION'),
torch=[os.environ['R8_TORCH_VERSION']],
pip_pkgs=mg.pip_pkgs,
Expand Down
2 changes: 1 addition & 1 deletion src/monobase/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def build_user_venv(args: argparse.Namespace) -> None:

python_version = os.environ['R8_PYTHON_VERSION']
torch_version = os.environ['R8_TORCH_VERSION']
cuda_version = os.environ['R8_CUDA_VERSION']
cuda_version = os.environ.get('R8_CUDA_VERSION', 'cpu')

uv = os.path.join(args.prefix, 'bin', 'uv')

Expand Down

0 comments on commit 6bd28e5

Please sign in to comment.