Skip to content

Commit

Permalink
Fix stack provider value bug (#110)
Browse files Browse the repository at this point in the history
* fix stack provider value bug

* add homepage and documentation URLs

* mypy fixes

* don't deploy remote state when using k3d provider

* restructure

* ruff fixes
  • Loading branch information
strickvl authored Nov 6, 2023
1 parent 482dc57 commit 97adc5c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 38 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ packages = [{ include = "mlstacks", from = "src" }]
description = "MLStacks MLStacks."
authors = ["ZenML GmbH <[email protected]>"]
readme = "README.md"
homepage = ""
documentation = ""
homepage = "https://mlstacks.zenml.io"
documentation = "https://mlstacks.zenml.io/"
repository = "https://github.com/zenml-io/mlstacks"
license = "Apache-2.0"
keywords = ["machine learning", "production", "pipeline", "mlops", "devops"]
Expand Down
42 changes: 24 additions & 18 deletions src/mlstacks/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import shutil
import string
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING, Optional

import click
import pkg_resources
Expand Down Expand Up @@ -45,7 +45,10 @@
get_stack_outputs,
infracost_breakdown_stack,
)
from mlstacks.utils.yaml_utils import load_yaml_as_dict
from mlstacks.utils.yaml_utils import load_stack_yaml, load_yaml_as_dict

if TYPE_CHECKING:
from mlstacks.models.stack import Stack


@click.group()
Expand Down Expand Up @@ -91,7 +94,13 @@ def deploy(
debug (bool): Flag to enable debug mode to view raw Terraform logging
"""
with analytics_client.EventHandler(AnalyticsEventsEnum.MLSTACKS_DEPLOY):
if not remote_state_bucket_name:
stack: Stack = load_stack_yaml(file)
if stack.provider.value == "k3d":
deployed_bucket_url = None
elif remote_state_bucket_name:
deployed_bucket_url = remote_state_bucket_name
declare(f"Using '{deployed_bucket_url}' for remote state...")
else:
# generate random bucket name
letters = string.ascii_lowercase + string.digits
random_bucket_suffix = "".join(
Expand All @@ -112,10 +121,6 @@ def deploy(
debug_mode=debug,
)
declare("Remote state successfully deployed!")
else:
deployed_bucket_url = remote_state_bucket_name
declare(f"Using '{deployed_bucket_url}' for remote state...")

# Stack deployment
declare(f"Deploying stack from '{file}'...")
deploy_stack(
Expand Down Expand Up @@ -203,17 +208,18 @@ def destroy(file: str, debug: bool = False, yes: bool = False) -> None:
shutil.rmtree(tf_files_dir)
declare(f"Stack '{stack_name}' has been destroyed.")

remote_state_dir = _get_remote_state_dir_path(provider)
if (
yes
or confirmation(
f"Would you like to destroy the Terraform remote state used "
f"for this stack on {provider}?",
)
) and Path(remote_state_dir).exists():
destroy_remote_state(provider)
shutil.rmtree(remote_state_dir)
declare(f"Remote state for {provider} has been destroyed.")
if provider != "k3d":
remote_state_dir = _get_remote_state_dir_path(provider)
if (
yes
or confirmation(
f"Would you like to destroy the Terraform remote state "
f"used for this stack on {provider}?",
)
) and Path(remote_state_dir).exists():
destroy_remote_state(provider)
shutil.rmtree(remote_state_dir)
declare(f"Remote state for {provider} has been destroyed.")


@click.command()
Expand Down
36 changes: 18 additions & 18 deletions src/mlstacks/utils/terraform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def parse_and_extract_tf_vars(stack: Stack) -> Dict[str, Any]:


def tf_definitions_present(
provider: ProviderEnum,
provider: str,
base_config_dir: str = CONFIG_DIR,
) -> bool:
"""Check if Terraform definitions are present.
Expand Down Expand Up @@ -257,7 +257,7 @@ def include_files(


def populate_tf_definitions(
provider: ProviderEnum,
provider: str,
region: str,
force: bool = False,
remote_state_bucket: Optional[str] = None,
Expand Down Expand Up @@ -375,7 +375,7 @@ def get_recipe_metadata(
return load_yaml_as_dict(recipe_metadata)


def check_tf_definitions_version(provider: ProviderEnum) -> None:
def check_tf_definitions_version(provider: str) -> None:
"""Checks for a Terraform version mismatch.
Args:
Expand Down Expand Up @@ -668,7 +668,7 @@ def write_remote_state_tf_variables(
Returns:
The remote state variables as a dictionary
"""
provider = stack.provider
provider = stack.provider.value
remote_state_tf_definitions = os.path.join(
CONFIG_DIR,
"terraform",
Expand Down Expand Up @@ -731,7 +731,7 @@ def deploy_remote_state(
remote_state_tf_definitions_path = os.path.join(
CONFIG_DIR,
"terraform",
f"{stack.provider}-remote-state",
f"{stack.provider.value}-remote-state",
)

# check whether remote state files already exist locally
Expand All @@ -742,7 +742,7 @@ def deploy_remote_state(

# copy remote state TF definitions
populate_remote_state_tf_definitions(
provider=stack.provider,
provider=stack.provider.value,
definitions_destination_path=remote_state_tf_definitions_path,
)

Expand All @@ -757,7 +757,7 @@ def deploy_remote_state(
if not tf_previously_initialized(remote_state_tf_definitions_path):
ret_code, _, _stderr = _tf_client_init(
tfr.client,
provider=stack.provider,
provider=stack.provider.value,
region=stack.default_region or "",
debug=debug_mode,
)
Expand Down Expand Up @@ -810,22 +810,22 @@ def deploy_stack(
RuntimeError: when Terraform raises an error.
"""
stack = load_stack_yaml(stack_path)
tf_recipe_path = _get_tf_recipe_path(stack.provider)
if not tf_definitions_present(stack.provider):
tf_recipe_path = _get_tf_recipe_path(stack.provider.value)
if not tf_definitions_present(stack.provider.value):
populate_tf_definitions(
stack.provider,
stack.provider.value,
region=stack.default_region or "",
force=True,
remote_state_bucket=remote_state_bucket,
)
tf_vars = parse_and_extract_tf_vars(stack)
check_tf_definitions_version(stack.provider)
check_tf_definitions_version(stack.provider.value)

tfr = TerraformRunner(tf_recipe_path)
if not tf_previously_initialized(tf_recipe_path):
ret_code, _, _stderr = _tf_client_init(
tfr.client,
provider=stack.provider,
provider=stack.provider.value,
region=stack.default_region or "",
debug=debug_mode,
remote_state_bucket=remote_state_bucket,
Expand Down Expand Up @@ -864,14 +864,14 @@ def destroy_stack(
stack = load_stack_yaml(stack_path)
tf_vars = parse_and_extract_tf_vars(stack)

tf_recipe_path = _get_tf_recipe_path(stack.provider)
tf_recipe_path = _get_tf_recipe_path(stack.provider.value)

tfr = TerraformRunner(tf_recipe_path)

if not tf_previously_initialized(tf_recipe_path):
ret_code, _, _stderr = _tf_client_init(
tfr.client,
provider=stack.provider,
provider=stack.provider.value,
region=stack.default_region or "",
debug=debug_mode,
remote_state_bucket=remote_state_bucket,
Expand Down Expand Up @@ -949,7 +949,7 @@ def get_remote_state_bucket(stack_path: str) -> str:
FileNotFoundError: when file does not exist
"""
stack = load_stack_yaml(stack_path)
tf_recipe_path = _get_tf_recipe_path(stack.provider)
tf_recipe_path = _get_tf_recipe_path(stack.provider.value)
bucket_url_file = os.path.join(
tf_recipe_path,
REMOTE_STATE_BUCKET_URL_FILE_NAME,
Expand Down Expand Up @@ -982,7 +982,7 @@ def get_stack_outputs(
RuntimeError: If Terraform has not been initialized.
"""
stack = load_stack_yaml(stack_path)
tf_recipe_path = _get_tf_recipe_path(stack.provider)
tf_recipe_path = _get_tf_recipe_path(stack.provider.value)
state_tf_path = f"{tf_recipe_path}/terraform.tfstate"

tfr = TerraformRunner(tf_recipe_path)
Expand Down Expand Up @@ -1056,15 +1056,15 @@ def infracost_breakdown_stack(
stack = load_stack_yaml(stack_path)
infracost_vars = _get_infracost_vars(parse_and_extract_tf_vars(stack))

tf_recipe_path = _get_tf_recipe_path(stack.provider)
tf_recipe_path = _get_tf_recipe_path(stack.provider.value)

tfr = TerraformRunner(tf_recipe_path)
if not tf_previously_initialized(tf_recipe_path):
# write a file with name `IGNORE_ME` to the Terraform recipe directory
# to prevent Terraform from initializing the recipe
ret_code, _, _stderr = _tf_client_init(
tfr.client,
provider=stack.provider,
provider=stack.provider.value,
region=stack.default_region or "",
debug=debug_mode,
)
Expand Down

0 comments on commit 97adc5c

Please sign in to comment.