Skip to content

Commit

Permalink
Merge pull request #62 from mlinfra-io/adding-tests
Browse files Browse the repository at this point in the history
Adding tests
  • Loading branch information
aliabbasjaffri authored Feb 11, 2024
2 parents b2787d6 + fc76882 commit 25461ad
Show file tree
Hide file tree
Showing 20 changed files with 678 additions and 101 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/on_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,22 @@ jobs:
key: ${{ runner.os }}-${{ hashFiles('requirements-dev.txt', '.pre-commit-config.yaml') }}-linter-cache
- name: Run pre-commit checks
run: pre-commit run --all-files --verbose --show-diff-on-failure
test:
name: Test mlinfra on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python and dependencies
uses: ./.github/actions/setup-python
with:
pythonVersion: "3.10"
dependencyType: "dev"
- name: Run Tests
run: |
pytest
docs:
name: Build documentation for mlinfra
runs-on: ubuntu-latest
Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dependencies = [
"boto3",
"pyyaml",
"GitPython",
"pytest",
"mypy",
"getmac",
"requests"
Expand All @@ -51,7 +50,9 @@ Documentation = "https://mlinfra.io/"

[project.optional-dependencies]
dev = [
"pre-commit>=3.3.3"
"pre-commit>=3.3.3",
"pytest",
"pytest-mock",
]
docs = [
"mkdocs-material",
Expand Down Expand Up @@ -99,3 +100,10 @@ skip_glob = ['**/venv/**']

[project.scripts]
mlinfra = "mlinfra.cli.cli:cli"

[tool.pytest.ini_options]
log_cli = true
log_cli_level = "WARNING"

[tool.bandit]
exclude_dirs = ["tests/", "docs/"]
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ pluggy==1.3.0
pre-commit==3.6.0
# via mlinfra (pyproject.toml)
pytest==7.4.4
# via
# mlinfra (pyproject.toml)
# pytest-mock
pytest-mock==3.12.0
# via mlinfra (pyproject.toml)
python-dateutil==2.8.2
# via botocore
Expand Down
16 changes: 2 additions & 14 deletions requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ click==8.1.7
# via mkdocs
colorama==0.4.6
# via mkdocs-material
exceptiongroup==1.2.0
# via pytest
getmac==0.9.4
# via mlinfra (pyproject.toml)
ghp-import==2.1.0
Expand All @@ -40,8 +38,6 @@ importlib-metadata==7.0.1
# via mike
importlib-resources==6.1.1
# via mike
iniconfig==2.0.0
# via pytest
invoke==2.2.0
# via mlinfra (pyproject.toml)
jinja2==3.1.3
Expand Down Expand Up @@ -85,25 +81,19 @@ mypy==1.8.0
mypy-extensions==1.0.0
# via mypy
packaging==23.2
# via
# mkdocs
# pytest
# via mkdocs
paginate==0.5.6
# via mkdocs-material
pathspec==0.12.1
# via mkdocs
platformdirs==4.1.0
# via mkdocs
pluggy==1.3.0
# via pytest
pygments==2.17.2
# via mkdocs-material
pymdown-extensions==10.7
# via mkdocs-material
pyparsing==3.1.1
# via mike
pytest==7.4.4
# via mlinfra (pyproject.toml)
python-dateutil==2.8.2
# via
# botocore
Expand Down Expand Up @@ -132,9 +122,7 @@ six==1.16.0
smmap==5.0.1
# via gitdb
tomli==2.0.1
# via
# mypy
# pytest
# via mypy
typing-extensions==4.9.0
# via mypy
urllib3==2.0.7
Expand Down
16 changes: 0 additions & 16 deletions src/mlinfra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pathlib import Path


def absolute_project_root() -> Path:
"""
Returns the absolute path to the project root.
"""
return Path(__file__).absolute().parent


def relative_project_root() -> Path:
"""
Returns the relative path to the project root.
"""
return Path(__file__).relative_to(Path(__file__).parent.parent).parent
11 changes: 11 additions & 0 deletions src/mlinfra/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) mlinfra 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
# permissions and limitations under the License.

import json
from importlib import resources
from typing import Any, Dict

import yaml
from mlinfra import absolute_project_root
from mlinfra import modules
from mlinfra.enums.cloud_provider import CloudProvider
from mlinfra.stack_processor.deployment_processor.deployment import (
AbstractDeployment,
Expand All @@ -22,30 +23,40 @@


class CloudInfraDeployment(AbstractDeployment):
"""
A class that configures the deployment of cloud infrastructure resources based on the specified provider.
Args:
stack_name (str): The name of the stack.
provider (CloudProvider): The cloud provider (AWS, GCP, or Azure).
region (str): The region where the resources will be deployed.
deployment_config (Dict[str, Any]): The deployment configuration in dictionary format.
"""

def __init__(
self,
stack_name: str,
provider: CloudProvider,
region: str,
deployment_config: yaml,
deployment_config: Dict[str, Any],
):
super(CloudInfraDeployment, self).__init__(
stack_name=stack_name,
provider=provider,
region=region,
deployment_config=deployment_config,
)
super().__init__(stack_name, provider, region, deployment_config)

def configure_required_provider_config(self):
"""
Configures the required provider configuration for the deployment.
Updates the Terraform JSON file with the necessary provider information.
"""

with open(
absolute_project_root() / f"modules/cloud/{self.provider.value}/terraform.tf.json",
resources.files(modules) / f"cloud/{self.provider.value}/terraform.tf.json",
"r",
) as data_json:
data = json.load(data_json)

# add random provider
with open(
absolute_project_root() / "modules/terraform_providers/random/terraform.tf.json",
resources.files(modules) / "terraform_providers/random/terraform.tf.json",
"r",
) as random_tf:
random_tf_json = json.load(random_tf)
Expand All @@ -58,17 +69,21 @@ def configure_required_provider_config(self):
generate_tf_json(module_name="terraform", json_module=data)

def configure_deployment_config(self):
# inject vpc module
"""
Configures the deployment configuration based on the provider.
Generates a Terraform JSON file for the specific provider.
"""
if self.provider == CloudProvider.AWS:
json_module = {"module": {"vpc": {}}}
json_module["module"]["vpc"]["name"] = f"{self.stack_name}-vpc"
json_module["module"]["vpc"]["source"] = "./modules/cloud/aws/vpc"

if "config" in self.deployment_config and "vpc" in self.deployment_config["config"]:
for vpc_config in self.deployment_config["config"]["vpc"]:
json_module["module"]["vpc"][vpc_config] = self.deployment_config["config"][
"vpc"
].get(vpc_config, None)
json_module = {
"module": {
"vpc": {
"name": f"{self.stack_name}-vpc",
"source": "./modules/cloud/aws/vpc",
}
}
}
vpc_config = self.deployment_config.get("config", {}).get("vpc", {})
json_module["module"]["vpc"].update(vpc_config)

generate_tf_json(module_name="vpc", json_module=json_module)
elif self.provider == CloudProvider.GCP:
Expand All @@ -79,6 +94,9 @@ def configure_deployment_config(self):
raise ValueError(f"Provider {self.provider} is not supported")

def configure_deployment(self):
"""
Configures the deployment by calling the `configure_required_provider_config()` and `configure_deployment_config()` methods.
"""
self.configure_required_provider_config()
self.configure_deployment_config()

Expand Down
38 changes: 33 additions & 5 deletions src/mlinfra/stack_processor/deployment_processor/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,34 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

import json
from abc import ABC, abstractmethod

import yaml
from mlinfra.enums.cloud_provider import CloudProvider


class AbstractDeployment(ABC):
"""
Abstract class for deployment
Abstract class for deployment.
Args:
stack_name (str): The name of the deployment stack.
provider (CloudProvider): The cloud provider for the deployment.
region (str): The region for the deployment.
deployment_config (dict): The deployment configuration.
Attributes:
stack_name (str): The name of the deployment stack.
provider (CloudProvider): The cloud provider for the deployment.
region (str): The region for the deployment.
deployment_config (dict): The deployment configuration.
"""

def __init__(
self,
stack_name: str,
provider: CloudProvider,
region: str,
deployment_config: yaml,
deployment_config: dict,
):
self.stack_name = stack_name
self.provider = provider
Expand All @@ -36,13 +46,31 @@ def __init__(

@abstractmethod
def configure_deployment(self):
"""
Abstract method that must be implemented by subclasses to configure the deployment.
"""
pass

# TODO: refactor statefile name
def get_statefile_name(self) -> str:
"""
Get the name of the statefile for the deployment.
Returns:
str: The name of the statefile.
"""
return f"tfstate-{self.stack_name}-{self.region}"

def get_provider_backend(self, provider: CloudProvider) -> json:
def get_provider_backend(self, provider: CloudProvider) -> dict:
"""
Get the backend configuration for the specified provider.
Args:
provider (CloudProvider): The cloud provider.
Returns:
json: The backend configuration.
"""
if provider == CloudProvider.AWS:
return {
"backend": {
Expand Down
Loading

0 comments on commit 25461ad

Please sign in to comment.