diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml
index 5b3eaacd..c672df34 100644
--- a/.github/workflows/on_pr.yml
+++ b/.github/workflows/on_pr.yml
@@ -37,6 +37,22 @@ jobs:
           key: ${{ runner.os }}-${{ hashFiles('requirements-dev.txt', '.pre-commit-config.yaml') }}-linter-cache
       - name: Run pre-commit checks
         run: pre-commit run --all-files --verbose --show-diff-on-failure
+  test:
+    name: Test mlinfra on ubuntu-latest
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Set up Python and dependencies
+        uses: ./.github/actions/setup-python
+        with:
+          pythonVersion: "3.10"
+          dependencyType: "dev"
+      - name: Run Tests
+        run: |
+          pytest
+
   docs:
     name: Build documentation for mlinfra
     runs-on: ubuntu-latest
diff --git a/pyproject.toml b/pyproject.toml
index 81bb2c80..e1c5eb43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,6 @@ dependencies = [
     "boto3",
     "pyyaml",
     "GitPython",
-    "pytest",
     "mypy",
     "getmac",
     "requests"
@@ -51,7 +50,9 @@ Documentation = "https://mlinfra.io/"
 
 [project.optional-dependencies]
 dev = [
-    "pre-commit>=3.3.3"
+    "pre-commit>=3.3.3",
+    "pytest",
+    "pytest-mock",
 ]
 docs = [
     "mkdocs-material",
@@ -99,3 +100,10 @@ skip_glob = ['**/venv/**']
 
 [project.scripts]
 mlinfra = "mlinfra.cli.cli:cli"
+
+[tool.pytest.ini_options]
+log_cli = true
+log_cli_level = "WARNING"
+
+[tool.bandit]
+exclude_dirs = ["tests/", "docs/"]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 76ff5914..9eeb58a5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -55,6 +55,10 @@ pluggy==1.3.0
 pre-commit==3.6.0
     # via mlinfra (pyproject.toml)
 pytest==7.4.4
+    # via
+    #   mlinfra (pyproject.toml)
+    #   pytest-mock
+pytest-mock==3.12.0
     # via mlinfra (pyproject.toml)
 python-dateutil==2.8.2
     # via botocore
diff --git a/requirements-docs.txt b/requirements-docs.txt
index 8fdae5da..14be2476 100644
--- a/requirements-docs.txt
+++ b/requirements-docs.txt
@@ -22,8 +22,6 @@ click==8.1.7
     # via mkdocs
 colorama==0.4.6
     # via mkdocs-material
-exceptiongroup==1.2.0
-    # via pytest
 getmac==0.9.4
     # via mlinfra (pyproject.toml)
 ghp-import==2.1.0
@@ -40,8 +38,6 @@ importlib-metadata==7.0.1
     # via mike
 importlib-resources==6.1.1
     # via mike
-iniconfig==2.0.0
-    # via pytest
 invoke==2.2.0
     # via mlinfra (pyproject.toml)
 jinja2==3.1.3
@@ -85,25 +81,19 @@ mypy==1.8.0
 mypy-extensions==1.0.0
     # via mypy
 packaging==23.2
-    # via
-    #   mkdocs
-    #   pytest
+    # via mkdocs
 paginate==0.5.6
     # via mkdocs-material
 pathspec==0.12.1
     # via mkdocs
 platformdirs==4.1.0
     # via mkdocs
-pluggy==1.3.0
-    # via pytest
 pygments==2.17.2
     # via mkdocs-material
 pymdown-extensions==10.7
     # via mkdocs-material
 pyparsing==3.1.1
     # via mike
-pytest==7.4.4
-    # via mlinfra (pyproject.toml)
 python-dateutil==2.8.2
     # via
     #   botocore
@@ -132,9 +122,7 @@ six==1.16.0
 smmap==5.0.1
     # via gitdb
 tomli==2.0.1
-    # via
-    #   mypy
-    #   pytest
+    # via mypy
 typing-extensions==4.9.0
     # via mypy
 urllib3==2.0.7
diff --git a/src/mlinfra/__init__.py b/src/mlinfra/__init__.py
index 91ebf8f6..76cf67a3 100644
--- a/src/mlinfra/__init__.py
+++ b/src/mlinfra/__init__.py
@@ -9,19 +9,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied. See the License for the specific language governing
 # permissions and limitations under the License.
-
-from pathlib import Path
-
-
-def absolute_project_root() -> Path:
-    """
-    Returns the absolute path to the project root.
- """ - return Path(__file__).absolute().parent - - -def relative_project_root() -> Path: - """ - Returns the relative path to the project root. - """ - return Path(__file__).relative_to(Path(__file__).parent.parent).parent diff --git a/src/mlinfra/modules/__init__.py b/src/mlinfra/modules/__init__.py new file mode 100644 index 00000000..76cf67a3 --- /dev/null +++ b/src/mlinfra/modules/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/src/mlinfra/stack_processor/deployment_processor/cloud_infra_deployment.py b/src/mlinfra/stack_processor/deployment_processor/cloud_infra_deployment.py index 58a4bc2c..bfa004b4 100644 --- a/src/mlinfra/stack_processor/deployment_processor/cloud_infra_deployment.py +++ b/src/mlinfra/stack_processor/deployment_processor/cloud_infra_deployment.py @@ -11,9 +11,10 @@ # permissions and limitations under the License. import json +from importlib import resources +from typing import Any, Dict -import yaml -from mlinfra import absolute_project_root +from mlinfra import modules from mlinfra.enums.cloud_provider import CloudProvider from mlinfra.stack_processor.deployment_processor.deployment import ( AbstractDeployment, @@ -22,30 +23,40 @@ class CloudInfraDeployment(AbstractDeployment): + """ + A class that configures the deployment of cloud infrastructure resources based on the specified provider. + + Args: + stack_name (str): The name of the stack. + provider (CloudProvider): The cloud provider (AWS, GCP, or Azure). + region (str): The region where the resources will be deployed. + deployment_config (Dict[str, Any]): The deployment configuration in dictionary format. + """ + def __init__( self, stack_name: str, provider: CloudProvider, region: str, - deployment_config: yaml, + deployment_config: Dict[str, Any], ): - super(CloudInfraDeployment, self).__init__( - stack_name=stack_name, - provider=provider, - region=region, - deployment_config=deployment_config, - ) + super().__init__(stack_name, provider, region, deployment_config) def configure_required_provider_config(self): + """ + Configures the required provider configuration for the deployment. + Updates the Terraform JSON file with the necessary provider information. + """ + with open( - absolute_project_root() / f"modules/cloud/{self.provider.value}/terraform.tf.json", + resources.files(modules) / f"cloud/{self.provider.value}/terraform.tf.json", "r", ) as data_json: data = json.load(data_json) # add random provider with open( - absolute_project_root() / "modules/terraform_providers/random/terraform.tf.json", + resources.files(modules) / "terraform_providers/random/terraform.tf.json", "r", ) as random_tf: random_tf_json = json.load(random_tf) @@ -58,17 +69,21 @@ def configure_required_provider_config(self): generate_tf_json(module_name="terraform", json_module=data) def configure_deployment_config(self): - # inject vpc module + """ + Configures the deployment configuration based on the provider. 
+ Generates a Terraform JSON file for the specific provider. + """ if self.provider == CloudProvider.AWS: - json_module = {"module": {"vpc": {}}} - json_module["module"]["vpc"]["name"] = f"{self.stack_name}-vpc" - json_module["module"]["vpc"]["source"] = "./modules/cloud/aws/vpc" - - if "config" in self.deployment_config and "vpc" in self.deployment_config["config"]: - for vpc_config in self.deployment_config["config"]["vpc"]: - json_module["module"]["vpc"][vpc_config] = self.deployment_config["config"][ - "vpc" - ].get(vpc_config, None) + json_module = { + "module": { + "vpc": { + "name": f"{self.stack_name}-vpc", + "source": "./modules/cloud/aws/vpc", + } + } + } + vpc_config = self.deployment_config.get("config", {}).get("vpc", {}) + json_module["module"]["vpc"].update(vpc_config) generate_tf_json(module_name="vpc", json_module=json_module) elif self.provider == CloudProvider.GCP: @@ -79,6 +94,9 @@ def configure_deployment_config(self): raise ValueError(f"Provider {self.provider} is not supported") def configure_deployment(self): + """ + Configures the deployment by calling the `configure_required_provider_config()` and `configure_deployment_config()` methods. + """ self.configure_required_provider_config() self.configure_deployment_config() diff --git a/src/mlinfra/stack_processor/deployment_processor/deployment.py b/src/mlinfra/stack_processor/deployment_processor/deployment.py index 9e4e8dc7..7cdf7a20 100644 --- a/src/mlinfra/stack_processor/deployment_processor/deployment.py +++ b/src/mlinfra/stack_processor/deployment_processor/deployment.py @@ -10,16 +10,26 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. -import json from abc import ABC, abstractmethod -import yaml from mlinfra.enums.cloud_provider import CloudProvider class AbstractDeployment(ABC): """ - Abstract class for deployment + Abstract class for deployment. + + Args: + stack_name (str): The name of the deployment stack. + provider (CloudProvider): The cloud provider for the deployment. + region (str): The region for the deployment. + deployment_config (dict): The deployment configuration. + + Attributes: + stack_name (str): The name of the deployment stack. + provider (CloudProvider): The cloud provider for the deployment. + region (str): The region for the deployment. + deployment_config (dict): The deployment configuration. """ def __init__( @@ -27,7 +37,7 @@ def __init__( stack_name: str, provider: CloudProvider, region: str, - deployment_config: yaml, + deployment_config: dict, ): self.stack_name = stack_name self.provider = provider @@ -36,13 +46,31 @@ def __init__( @abstractmethod def configure_deployment(self): + """ + Abstract method that must be implemented by subclasses to configure the deployment. + """ pass # TODO: refactor statefile name def get_statefile_name(self) -> str: + """ + Get the name of the statefile for the deployment. + + Returns: + str: The name of the statefile. + """ return f"tfstate-{self.stack_name}-{self.region}" - def get_provider_backend(self, provider: CloudProvider) -> json: + def get_provider_backend(self, provider: CloudProvider) -> dict: + """ + Get the backend configuration for the specified provider. + + Args: + provider (CloudProvider): The cloud provider. + + Returns: + json: The backend configuration. 
+ """ if provider == CloudProvider.AWS: return { "backend": { diff --git a/src/mlinfra/stack_processor/deployment_processor/kubernetes_deployment.py b/src/mlinfra/stack_processor/deployment_processor/kubernetes_deployment.py index 079dae7d..3eb8d791 100644 --- a/src/mlinfra/stack_processor/deployment_processor/kubernetes_deployment.py +++ b/src/mlinfra/stack_processor/deployment_processor/kubernetes_deployment.py @@ -11,9 +11,10 @@ # permissions and limitations under the License. import json +from importlib import resources import yaml -from mlinfra import absolute_project_root +from mlinfra import modules from mlinfra.enums.cloud_provider import CloudProvider from mlinfra.stack_processor.deployment_processor.deployment import ( AbstractDeployment, @@ -29,6 +30,15 @@ def __init__( region: str, deployment_config: yaml, ): + """ + Initialize a new instance of the KubernetesDeployment class. + + Parameters: + - stack_name (str): The name of the stack. + - provider (CloudProvider): The cloud provider for the deployment. + - region (str): The region for the deployment. + - deployment_config (yaml): The deployment configuration. + """ super(KubernetesDeployment, self).__init__( stack_name=stack_name, provider=provider, @@ -37,8 +47,20 @@ def __init__( ) def generate_required_provider_config(self): + """ + Generate the required provider configuration for the Kubernetes deployment. + + This method reads the necessary provider configuration files and generates the required provider configuration for the Kubernetes deployment. + It updates the 'terraform.tf.json' file with the required provider information. + + Parameters: + - None + + Returns: + - None + """ with open( - absolute_project_root() / f"modules/cloud/{self.provider.value}/terraform.tf.json", + resources.files(modules) / f"cloud/{self.provider.value}/terraform.tf.json", "r", ) as data_json: data = json.load(data_json) @@ -48,8 +70,8 @@ def generate_required_provider_config(self): for required_provider in required_providers: with open( - absolute_project_root() - / f"modules/terraform_providers/{required_provider}/terraform.tf.json", + resources.files(modules) + / f"terraform_providers/{required_provider}/terraform.tf.json", "r", ) as provider_tf: provider_tf_json = json.load(provider_tf) @@ -69,8 +91,7 @@ def generate_k8s_helm_provider_config(self): for provider in providers: with open( - absolute_project_root() - / f"modules/terraform_providers/{provider}/provider.tf.json", + resources.files(modules) / f"terraform_providers/{provider}/provider.tf.json", "r", ) as provider_tf: provider_tf_json = json.load(provider_tf) @@ -79,6 +100,21 @@ def generate_k8s_helm_provider_config(self): generate_tf_json(module_name="k8s_provider", json_module=data) def generate_deployment_config(self): + """ + Generate the deployment configuration for the Kubernetes deployment. + + This method generates the deployment configuration for the Kubernetes deployment based on the specified provider. + It injects the necessary modules and configurations into the Terraform configuration files. + + Parameters: + - None + + Returns: + - None + + Raises: + - ValueError: If the specified provider is not supported. 
+ """ if self.provider == CloudProvider.AWS: # TODO: Make these blocks generic # inject vpc module @@ -109,7 +145,7 @@ def generate_deployment_config(self): ): # read values from the yaml config file with open( - absolute_project_root() / f"./modules/cloud/{self.provider.value}/eks/eks.yaml", + resources.files(modules) / f"cloud/{self.provider.value}/eks/eks.yaml", "r", encoding="utf-8", ) as tf_config: @@ -180,9 +216,9 @@ def generate_deployment_config(self): generate_tf_json(module_name="nodegroups", json_module=nodegroups_json_module) elif self.provider == CloudProvider.GCP: - pass + raise ValueError(f"Provider {self.provider} is not yet supported") elif self.provider == CloudProvider.AZURE: - pass + raise ValueError(f"Provider {self.provider} is not yet supported") else: raise ValueError(f"Provider {self.provider} is not supported") diff --git a/src/mlinfra/stack_processor/provider_processor/aws_provider.py b/src/mlinfra/stack_processor/provider_processor/aws_provider.py index 088667d3..63fd4c48 100644 --- a/src/mlinfra/stack_processor/provider_processor/aws_provider.py +++ b/src/mlinfra/stack_processor/provider_processor/aws_provider.py @@ -11,9 +11,9 @@ # permissions and limitations under the License. import json +from importlib import resources -import yaml -from mlinfra import absolute_project_root +from mlinfra import modules from mlinfra.stack_processor.provider_processor.provider import ( AbstractProvider, ) @@ -21,7 +21,22 @@ class AWSProvider(AbstractProvider): - def __init__(self, stack_name: str, config: yaml): + """ + Represents a provider for the AWS infrastructure. + Args: + stack_name (str): The name of the stack. + config (dict): The configuration object containing the provider settings. + Attributes: + stack_name (str): The name of the stack. + config (dict): The configuration object containing the provider settings. + account_id (str): The AWS account ID. + region (str): The AWS region. + access_key (str): The AWS access key (optional). + secret_key (str): The AWS secret key (optional). + role_arn (str): The AWS role ARN (optional). + """ + + def __init__(self, stack_name: str, config: dict): super().__init__(stack_name=stack_name, config=config) # required self.account_id = config.get("account_id") @@ -34,10 +49,20 @@ def __init__(self, stack_name: str, config: yaml): # TODO: refactor statefile name def get_statefile_name(self) -> str: + """ + Returns the name of the statefile for the current stack and region. + Returns: + str: The name of the statefile. + """ return f"tfstate-{self.stack_name}-{self.region}" def configure_provider(self): - with open(absolute_project_root() / "modules/cloud/aws/provider.tf.json", "r") as data_json: + """ + Configures the provider by updating the provider configuration file. + It sets the AWS region, allowed account IDs, and default tags. + It also adds a random provider. 
+ """ + with open(resources.files(modules) / "cloud/aws/provider.tf.json", "r") as data_json: with open(f"./{TF_PATH}/provider.tf.json", "w", encoding="utf-8") as tf_json: data = json.load(data_json) data["provider"]["aws"]["region"] = self.region @@ -51,7 +76,7 @@ def configure_provider(self): # add random provider with open( - absolute_project_root() / "modules/terraform_providers/random/provider.tf.json", + resources.files(modules) / "terraform_providers/random/provider.tf.json", "r", ) as random_provider: random_provider_json = json.load(random_provider) diff --git a/src/mlinfra/stack_processor/provider_processor/provider.py b/src/mlinfra/stack_processor/provider_processor/provider.py index a382aef1..b75787f2 100644 --- a/src/mlinfra/stack_processor/provider_processor/provider.py +++ b/src/mlinfra/stack_processor/provider_processor/provider.py @@ -12,19 +12,26 @@ from abc import ABC, abstractmethod -import yaml - class AbstractProvider(ABC): """ Abstract class for providers """ - @abstractmethod - def __init__(self, stack_name: str, config: yaml): + def __init__(self, stack_name: str, config: dict): + """ + Initializes the AbstractProvider object with the provided stack name and configuration object. + + Args: + stack_name (str): The name of the stack. + config (dict): The configuration for the provider. + """ self.stack_name = stack_name self.config = config @abstractmethod def configure_provider(self): + """ + Abstract method that needs to be implemented by subclasses to configure the provider. + """ pass diff --git a/src/mlinfra/stack_processor/stack_generator.py b/src/mlinfra/stack_processor/stack_generator.py index 981e2ca6..dea7022b 100644 --- a/src/mlinfra/stack_processor/stack_generator.py +++ b/src/mlinfra/stack_processor/stack_generator.py @@ -30,7 +30,30 @@ class StackGenerator: + """ + A class that generates and configures infrastructure stacks based on the provided stack configuration. + + Attributes: + stack_config (dict): The stack configuration provided to the StackGenerator object. + stack_name (str): The name of the stack. + account_id (str): The account ID associated with the stack. + provider (str): The cloud provider for the stack. + deployment_type (str): The type of deployment (cloud infrastructure or Kubernetes). + state_file_name (str): The name of the state file. + is_stack_component (bool): A flag indicating if the stack is a component of a larger stack. + output (dict): The output configuration for the stack. + """ + def __init__(self, stack_config): + """ + Initializes the StackGenerator object with the provided stack configuration. + + Args: + stack_config (dict): The stack configuration provided to the StackGenerator object. + + Raises: + Exception: If the stack configuration is missing any required components. + """ self.stack_config = stack_config self.stack_name = "" self.account_id = "" @@ -58,13 +81,28 @@ def __init__(self, stack_config): # TODO: refactor statefile name def get_state_file_name(self): + """ + Generates the state file name based on the stack name and region. + + Returns: + str: The state file name. + """ self.state_file_name = f"tfstate-{self.stack_name}-{self.region}" return self.state_file_name def get_region(self): + """ + Returns the region specified in the stack configuration. + + Returns: + str: The region. + """ return self.region def generate(self): + """ + Generates and configures the infrastructure stacks based on the deployment type. 
+ """ if DeploymentType(self.stack_config["deployment"]["type"]) == DeploymentType.CLOUD_INFRA: CloudInfraDeployment( stack_name=self.stack_name, @@ -99,9 +137,20 @@ def generate(self): ).generate() def configure_provider(self) -> CloudProvider: + """ + Configures the provider details based on the stack configuration. + + Returns: + CloudProvider: The cloud provider. + + Raises: + NotImplementedError: If the cloud provider is not supported. + """ if CloudProvider(self.stack_config["provider"]["name"]) == CloudProvider.AWS: aws_provider = AWSProvider( stack_name=self.stack_name, config=self.stack_config["provider"] ) aws_provider.configure_provider() return CloudProvider.AWS + else: + raise NotImplementedError("Cloud provider not supported") diff --git a/src/mlinfra/stack_processor/stack_processor/stack.py b/src/mlinfra/stack_processor/stack_processor/stack.py index c87cc428..d22c1eae 100644 --- a/src/mlinfra/stack_processor/stack_processor/stack.py +++ b/src/mlinfra/stack_processor/stack_processor/stack.py @@ -12,9 +12,11 @@ import json from abc import ABC, abstractmethod +from importlib import resources +from typing import Any, Dict, Union import yaml -from mlinfra import absolute_project_root +from mlinfra import modules from mlinfra.enums.cloud_provider import CloudProvider from mlinfra.enums.deployment_type import DeploymentType @@ -37,8 +39,12 @@ def __init__( Initializes the stack. Args: - provider (Provider): The cloud provider. - stacks (yaml): The stack config. + state_file_name (str): The name of the state file for the stack. + region (str): The region where the stack will be deployed. + account_id (str): The ID of the account associated with the stack. + provider (CloudProvider): The cloud provider for the stack. + deployment_type (DeploymentType): The type of deployment for the stack. + stacks (yaml): The stack configuration in YAML format. """ self.state_file_name = state_file_name self.region = region @@ -49,59 +55,60 @@ def __init__( def _read_config_file( self, stack_type: str, application_name: str, extension: str = "yaml" - ) -> json: + ) -> Union[Dict, Any]: """ Reads the config file for the application and returns the config - as a json object for application config and yaml for stack config. + as a JSON object for application config and YAML for stack config. Args: stack_type (str): The type of the stack. application_name (str): The name of the application. extension (str, optional): The extension of the config file. + + Returns: + Union[Dict, Any]: The configuration as a JSON object or YAML. """ - with open( - absolute_project_root() - / f"modules/applications/{self.deployment_type.value}/{stack_type}/{application_name}/{application_name}_{self.deployment_type.value}.{extension}", - "r", - encoding="utf-8", - ) as tf_config: - return ( - yaml.safe_load(tf_config.read()) - if extension == "yaml" - else json.loads(tf_config.read()) - ) + file_path = ( + resources.files(modules) + / f"applications/{self.deployment_type.value}/{stack_type}/{application_name}/{application_name}_{self.deployment_type.value}.{extension}" + ) + with open(file_path, "r", encoding="utf-8") as config_file: + if extension == "yaml": + return yaml.safe_load(config_file.read()) + else: + return json.loads(config_file.read()) @abstractmethod - def process_stack_config(): + def process_stack_config(self): """ Process the stack configuration and validates the config. 
""" pass @abstractmethod - def process_stack_inputs(): + def process_stack_inputs(self): """ Process the stack inputs and validates the inputs. """ pass @abstractmethod - def process_stack_modules(): + def process_stack_modules(self): """ - Generates the stack modules terraform json configuration + Generates the stack modules Terraform JSON configuration and validates the modules against the stack config and parameters. """ pass @abstractmethod - def process_stack_outputs(): + def process_stack_outputs(self): """ Generates the required stack outputs. """ pass @abstractmethod - def generate(): + def generate(self): """ Generates the stack configuration. """ diff --git a/src/mlinfra/terraform/terraform.py b/src/mlinfra/terraform/terraform.py index 3b20ccf5..77111996 100644 --- a/src/mlinfra/terraform/terraform.py +++ b/src/mlinfra/terraform/terraform.py @@ -16,12 +16,13 @@ # TODO: Update this section to run it more secure and # remove the comment import subprocess # nosec +from importlib import resources # import hashlib import boto3 import yaml from botocore.config import Config -from mlinfra import absolute_project_root +from mlinfra import modules from mlinfra.stack_processor.stack_generator import StackGenerator from mlinfra.terraform.state_helper import StateHelper from mlinfra.utils.constants import TF_PATH @@ -105,7 +106,7 @@ def read_stack_config(self) -> yaml: # create the stack folder os.makedirs(TF_PATH, mode=0o777) - create_symlinks(absolute_project_root() / "modules", TF_PATH + "/modules") + create_symlinks(resources.files(modules), TF_PATH + "/modules") # TODO: generate hash of the stack config to not generate config all the time # sha256_hash = hashlib.sha256() @@ -191,7 +192,7 @@ def generate_modules_list(self): modules_list.append(f"module.{key}") return modules_list - def generate_terraform_config(self) -> (str, str): + def generate_terraform_config(self) -> tuple[str, str]: """This function is responsible for generating the terraform config file""" self.check_terraform_installed() # TODO: perform this after the cli package has been released diff --git a/tests/test_stack_processor/test_deployment_processor/test_abs_deployment.py b/tests/test_stack_processor/test_deployment_processor/test_abs_deployment.py new file mode 100644 index 00000000..3c72c0dc --- /dev/null +++ b/tests/test_stack_processor/test_deployment_processor/test_abs_deployment.py @@ -0,0 +1,58 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+# Generated by CodiumAI +import pytest +import yaml +from mlinfra.enums.cloud_provider import CloudProvider +from mlinfra.stack_processor.deployment_processor.deployment import AbstractDeployment + + +class TestAbstractDeployment: + # stack_name argument is None + def test_stack_name_argument_is_none(self): + stack_name = None + provider = CloudProvider.AWS + region = "us-west-2" + + yaml_config = """ + deployment: + type: cloud_infra + """ + deployment_config = yaml.safe_load(yaml_config) + + with pytest.raises(TypeError): + AbstractDeployment(stack_name, provider, region, deployment_config) + + # region argument is None + def test_region_argument_is_none(self): + stack_name = "test_stack" + provider = CloudProvider.AWS + region = None + + yaml_config = """ + deployment: + type: cloud_infra + """ + deployment_config = yaml.safe_load(yaml_config) + + with pytest.raises(TypeError): + AbstractDeployment(stack_name, provider, region, deployment_config) + + # deployment_config argument is None + def test_deployment_config_argument_is_none(self): + stack_name = "test_stack" + provider = CloudProvider.AWS + region = "us-west-2" + deployment_config = None + + with pytest.raises(TypeError): + AbstractDeployment(stack_name, provider, region, deployment_config) diff --git a/tests/test_stack_processor/test_deployment_processor/test_cloud_infra_deployment.py b/tests/test_stack_processor/test_deployment_processor/test_cloud_infra_deployment.py new file mode 100644 index 00000000..387d23cf --- /dev/null +++ b/tests/test_stack_processor/test_deployment_processor/test_cloud_infra_deployment.py @@ -0,0 +1,47 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+# Generated by CodiumAI +from mlinfra.enums.cloud_provider import CloudProvider +from mlinfra.stack_processor.deployment_processor.cloud_infra_deployment import CloudInfraDeployment + + +class TestCloudInfraDeployment: + # can be instantiated with required parameters + def test_instantiation(self): + stack_name = "test_stack" + provider = CloudProvider.AWS + region = "us-west-2" + deployment_config = {"config": {"vpc": {"cidr_block": "10.0.0.0/16"}}} + + deployment = CloudInfraDeployment(stack_name, provider, region, deployment_config) + + assert deployment.stack_name == stack_name + assert deployment.provider == provider + assert deployment.region == region + assert deployment.deployment_config == deployment_config + + # can configure deployment for AWS provider + def test_configure_aws_deployment(self, mocker): + stack_name = "test_stack" + provider = CloudProvider.AWS + region = "us-west-2" + deployment_config = {"config": {"vpc": {"cidr_block": "10.0.0.0/16"}}} + + deployment = CloudInfraDeployment(stack_name, provider, region, deployment_config) + + mocker.patch.object(deployment, "configure_required_provider_config") + mocker.patch.object(deployment, "configure_deployment_config") + + deployment.configure_deployment() + + deployment.configure_required_provider_config.assert_called_once() + deployment.configure_deployment_config.assert_called_once() diff --git a/tests/test_stack_processor/test_deployment_processor/test_kubernetes_deployment.py b/tests/test_stack_processor/test_deployment_processor/test_kubernetes_deployment.py new file mode 100644 index 00000000..c2cfba32 --- /dev/null +++ b/tests/test_stack_processor/test_deployment_processor/test_kubernetes_deployment.py @@ -0,0 +1,99 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+# Generated by CodiumAI +import pytest +from mlinfra.enums.cloud_provider import CloudProvider +from mlinfra.stack_processor.deployment_processor.kubernetes_deployment import KubernetesDeployment + + +class TestKubernetesDeployment: + # can be instantiated with valid arguments + def test_instantiation_with_valid_arguments(self): + stack_name = "test-stack" + provider = CloudProvider.AWS + region = "us-west-2" + deployment_config = {} + + deployment = KubernetesDeployment(stack_name, provider, region, deployment_config) + + assert deployment.stack_name == stack_name + assert deployment.provider == provider + assert deployment.region == region + assert deployment.deployment_config == deployment_config + + def test_specified_provider_not_supported(self): + provider = CloudProvider.GCP + region = "us-west-2" + stack_name = "my-stack" + deployment_config = { + "config": { + "vpc": { + "cidr_block": "10.0.0.0/16", + "subnet_cidr_blocks": ["10.0.1.0/24", "10.0.2.0/24"], + }, + "kubernetes": { + "cluster_version": "1.28", + "node_groups": [ + { + "name": "worker-group", + "instance_type": "t3.medium", + "desired_capacity": 2, + } + ], + }, + } + } + + deployment = KubernetesDeployment( + stack_name=stack_name, + provider=provider, + region=region, + deployment_config=deployment_config, + ) + + # Assert that a FileNotFoundError is raised when the specified provider is not supported + with pytest.raises(FileNotFoundError): + deployment.generate_required_provider_config() + + def test_required_provider_not_supported(self): + provider = CloudProvider.AWS + region = "us-west-2" + stack_name = "my-stack" + deployment_config = { + "config": { + "vpc": { + "cidr_block": "10.0.0.0/16", + "subnet_cidr_blocks": ["10.0.1.0/24", "10.0.2.0/24"], + }, + "kubernetes": { + "cluster_version": "1.28", + "node_groups": [ + { + "name": "worker-group", + "instance_type": "t3.medium", + "desired_capacity": 2, + } + ], + }, + } + } + + deployment = KubernetesDeployment( + stack_name=stack_name, + provider=provider, + region=region, + deployment_config=deployment_config, + ) + + # Assert that a FileNotFoundError is raised when one of the required providers is not supported + with pytest.raises(FileNotFoundError): + deployment.generate_required_provider_config() diff --git a/tests/test_stack_processor/test_provider_processor/test_abs_provider.py b/tests/test_stack_processor/test_provider_processor/test_abs_provider.py new file mode 100644 index 00000000..683aa69c --- /dev/null +++ b/tests/test_stack_processor/test_provider_processor/test_abs_provider.py @@ -0,0 +1,66 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +# Generated by CodiumAI +import pytest +from mlinfra.stack_processor.provider_processor.provider import AbstractProvider + + +class TestAbstractProvider: + # AbstractProvider object can be initialized with stack name and configuration object. 
+ def test_initialize_with_stack_name_and_config(self): + stack_name = "test_stack" + config = {"key": "value"} + + class ConcreteProvider(AbstractProvider): + def configure_provider(self): + pass + + provider = ConcreteProvider(stack_name, config) + + assert provider.stack_name == stack_name + assert provider.config == config + + # Subclasses can implement configure_provider method to configure the provider. + def test_configure_provider_method(self): + class TestProvider(AbstractProvider): + def configure_provider(self): + return "Provider configured" + + provider = TestProvider("test_stack", {"key": "value"}) + + assert provider.configure_provider() == "Provider configured" + + # stack_name parameter is not provided during initialization. + def test_missing_stack_name_parameter(self): + with pytest.raises(TypeError): + AbstractProvider(config={"key": "value"}) + + # config parameter is not provided during initialization. + def test_missing_config_parameter(self): + with pytest.raises(TypeError): + AbstractProvider(stack_name="test_stack") + + # Subclasses can override __init__ method to add additional parameters. + def test_subclass_override_init_method(self): + class TestProvider(AbstractProvider): + def __init__(self, stack_name, config, additional_param): + super().__init__(stack_name, config) + self.additional_param = additional_param + + def configure_provider(self): + pass + + provider = TestProvider("test_stack", {"key": "value"}, "additional") + + assert provider.stack_name == "test_stack" + assert provider.config == {"key": "value"} + assert provider.additional_param == "additional" diff --git a/tests/test_stack_processor/test_provider_processor/test_aws_provider.py b/tests/test_stack_processor/test_provider_processor/test_aws_provider.py new file mode 100644 index 00000000..2886e43d --- /dev/null +++ b/tests/test_stack_processor/test_provider_processor/test_aws_provider.py @@ -0,0 +1,67 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+# Generated by CodiumAI +from mlinfra.stack_processor.provider_processor.aws_provider import AWSProvider + + +class TestAWSProvider: + # AWSProvider can be instantiated with a stack name and a configuration object + def test_instantiation_with_stack_name_and_config(self): + stack_name = "test_stack" + config = {"account_id": "123456789", "region": "us-west-2"} + provider = AWSProvider(stack_name, config) + + assert provider.stack_name == stack_name + assert provider.config == config + assert provider.account_id == config["account_id"] + assert provider.region == config["region"] + assert provider.access_key is None + assert provider.secret_key is None + assert provider.role_arn is None + + # AWSProvider can retrieve the name of the statefile for a given stack and region + def test_get_statefile_name(self): + stack_name = "test_stack" + config = {"account_id": "123456789", "region": "us-west-2"} + provider = AWSProvider(stack_name, config) + + expected_statefile_name = f"tfstate-{stack_name}-{config['region']}" + + assert provider.get_statefile_name() == expected_statefile_name + + # AWSProvider can be instantiated with a configuration object that does not contain an access key, a secret key, or a role ARN + def test_instantiation_without_access_key_secret_key_role_arn(self): + stack_name = "test_stack" + config = {"account_id": "123456789", "region": "us-west-2"} + provider = AWSProvider(stack_name, config) + + assert provider.stack_name == stack_name + assert provider.config == config + assert provider.account_id == config["account_id"] + assert provider.region == config["region"] + assert provider.access_key is None + assert provider.secret_key is None + assert provider.role_arn is None + + # AWSProvider can be instantiated with a configuration object that does not contain an account ID or a region + def test_instantiation_without_account_id_region(self): + stack_name = "test_stack" + config = {"access_key": "access_key", "secret_key": "secret_key", "role_arn": "role_arn"} + provider = AWSProvider(stack_name, config) + + assert provider.stack_name == stack_name + assert provider.config == config + assert provider.account_id is None + assert provider.region is None + assert provider.access_key == config["access_key"] + assert provider.secret_key == config["secret_key"] + assert provider.role_arn == config["role_arn"] diff --git a/tests/test_stack_processor/test_stack_generator.py b/tests/test_stack_processor/test_stack_generator.py new file mode 100644 index 00000000..c58e4c1f --- /dev/null +++ b/tests/test_stack_processor/test_stack_generator.py @@ -0,0 +1,58 @@ +# Copyright (c) mlinfra 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# https://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +import pytest +from mlinfra.stack_processor.stack_generator import StackGenerator + + +class TestStackGenerator: + # Raises an exception if 'name', 'provider', 'deployment' or 'stack' keys are missing in 'stack_config'. 
+ def test_raises_exception_if_missing_keys(self): + stack_config = { + "provider": {"name": "aws", "region": "us-west-2"}, + "deployment": {"type": "CLOUD_INFRA"}, + "stack": [], + } + with pytest.raises(Exception): + StackGenerator(stack_config) + + # Raises an exception if the length of 'stack_name' attribute is greater than 37 characters. + def test_raises_exception_if_stack_name_exceeds_37_characters(self): + stack_config = { + "name": "a" * 38, + "provider": {"name": "aws", "region": "us-west-2"}, + "deployment": {"type": "CLOUD_INFRA"}, + "stack": [], + } + with pytest.raises(Exception): + StackGenerator(stack_config) + + # Raises an exception if 'provider' key in 'stack_config' is not a valid CloudProvider enum value. + def test_raises_exception_if_provider_key_is_invalid(self): + stack_config = { + "name": "test_stack", + "provider": {"name": "invalid_provider", "region": "us-west-2"}, + "deployment": {"type": "CLOUD_INFRA"}, + "stack": [], + } + with pytest.raises(Exception): + StackGenerator(stack_config) + + # Raises an exception if 'deployment' key in 'stack_config' is not a valid DeploymentType enum value. + def test_raises_exception_if_deployment_key_is_invalid(self): + stack_config = { + "name": "test_stack", + "provider": {"name": "aws", "region": "us-west-2"}, + "deployment": {"type": "invalid_deployment"}, + "stack": [], + } + with pytest.raises(Exception): + StackGenerator(stack_config)
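The refactor in this PR swaps absolute_project_root() for importlib.resources lookups against the new mlinfra.modules package. A minimal sketch of that pattern, assuming mlinfra.modules ships data files such as cloud/aws/terraform.tf.json; the helper name below is illustrative and not part of the PR:

import json
from importlib import resources

from mlinfra import modules


def load_packaged_json(relative_path: str) -> dict:
    # resources.files() returns a Traversable rooted at the mlinfra.modules package,
    # so bundled data files resolve correctly even when mlinfra is installed as a wheel.
    return json.loads((resources.files(modules) / relative_path).read_text(encoding="utf-8"))


# Hypothetical usage, mirroring the calls introduced in this PR:
# provider_config = load_packaged_json("cloud/aws/terraform.tf.json")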