diff --git a/leverage/checker/__init__.py b/leverage/checker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/leverage/checker/checks.py b/leverage/checker/checks.py new file mode 100644 index 0000000..4aaad85 --- /dev/null +++ b/leverage/checker/checks.py @@ -0,0 +1,71 @@ +import os +from abc import ABC, abstractmethod +from typing import List + +import hcl2 +import yaml + + +class VersionCheck(ABC): + def __init__(self, name: str, version_rule: str): + self.name = name + self.version_rule = version_rule + + @abstractmethod + def run(self) -> None: + """Implement this method to check version compatibility""" + pass + + +class CommandVersionCheck(VersionCheck): + def __init__(self, name: str, version_rule: str): + super().__init__(name, version_rule) + self.modules: List[ModuleVersionCheck] = [] + + def add_module(self, module: "ModuleVersionCheck") -> None: + self.modules.append(module) + + def run(self) -> None: + print(f"Checking command {self.name} with version rule {self.version_rule}") + for module in self.modules: + module.run() + + +class ModuleVersionCheck(VersionCheck): + def run(self) -> None: + print(f"Checking module {self.name} with version rule {self.version_rule}") + + +class CommandGroupCheck: + def __init__(self): + self.commands: List[CommandVersionCheck] = [] + + def add_command(self, command: CommandVersionCheck) -> None: + self.commands.append(command) + + def run(self) -> None: + print("Running group checks for all commands...") + for command in self.commands: + command.run() + print("All group checks completed successfully.") + + +def load_config(filename: str = "commands.yml") -> dict: + with open(filename, "r") as file: + data = yaml.safe_load(file) + return data + + +def setup_version_check_hierarchy(config: dict) -> CommandGroupCheck: + root_group = CommandGroupCheck() # This will hold all top-level commands + for cmd_info in config["commands"]: + command = CommandVersionCheck(name=cmd_info["name"], version_rule=cmd_info["version_rule"]) + root_group.add_command(command) + for mod_info in cmd_info.get("modules", []): + module = ModuleVersionCheck(name=mod_info["name"], version_rule=mod_info["version_rule"]) + command.add_module(module) + return root_group + + +def run_checks(check: CommandGroupCheck) -> None: + check.run() diff --git a/leverage/checker/utils.py b/leverage/checker/utils.py new file mode 100644 index 0000000..d0d16ed --- /dev/null +++ b/leverage/checker/utils.py @@ -0,0 +1,30 @@ +import logging +import time + + +class TimeIt: + """ + Context manager to measure and log the execution time of a block of code. + It uses the logger provided, or defaults to the root logger if none is provided. + + Args: + - task_name (str): A name for the task to help identify it in the logs. + - logger (logging.Logger): Optional. A logger object to use for logging the time. + """ + + def __init__(self, task_name="Unnamed Task", logger=None): + self.task_name = task_name + self.logger = logger if logger is not None else logging.getLogger(__name__) + self.start_time = None + + def __enter__(self): + self.start_time = time.time() + return self # You can return anything that might be useful, but here we don't need to + + def __exit__(self, exc_type, exc_val, exc_tb): + elapsed_time = time.time() - self.start_time + hours, remainder = divmod(elapsed_time, 3600) + minutes, seconds = divmod(remainder, 60) + milliseconds = (elapsed_time - int(elapsed_time)) * 1000 + human_readable = "{:02}:{:02}:{:02}.{:03}".format(int(hours), int(minutes), int(seconds), int(milliseconds)) + self.logger.debug(f"{self.task_name} took {human_readable} (hh:mm:ss.mmm)") diff --git a/leverage/checker/version_parser.py b/leverage/checker/version_parser.py new file mode 100644 index 0000000..b085c6e --- /dev/null +++ b/leverage/checker/version_parser.py @@ -0,0 +1,84 @@ +import os +import re +from typing import Dict, Any, Union, List + +import hcl2 + + +class VersionExtractor: + """Extracts versions from parsed Terraform configurations using best practices with type tagging.""" + + @staticmethod + def extract_versions(tf_config: Dict[str, Any]) -> Dict[str, Dict[str, str]]: + versions = {} + VersionExtractor.extract_core_and_providers(tf_config, versions) + VersionExtractor.extract_module_versions(tf_config, versions) + return versions + + @staticmethod + def extract_core_and_providers(tf_config: Dict[str, Any], versions: Dict[str, Dict[str, str]]): + for terraform_block in tf_config.get("terraform", []): + if isinstance(terraform_block, dict): + if "required_version" in terraform_block: + versions["terraform"] = {"type": "core", "version": terraform_block["required_version"]} + if "required_providers" in terraform_block: + VersionExtractor.process_providers(terraform_block["required_providers"], versions) + + @staticmethod + def process_providers(providers: Union[Dict[str, Any], List[Dict[str, Any]]], versions: Dict[str, Dict[str, str]]): + if isinstance(providers, dict): + VersionExtractor.extract_provider_versions(providers, versions) + elif isinstance(providers, list): + for provider_dict in providers: + VersionExtractor.extract_provider_versions(provider_dict, versions) + else: + print(f"Error: Providers data structure not recognized: {providers}") + + @staticmethod + def extract_provider_versions(providers: Dict[str, Any], versions: Dict[str, Dict[str, str]]): + for provider, details in providers.items(): + if isinstance(details, dict) and "version" in details: + versions[provider] = {"type": "provider", "version": details["version"]} + elif isinstance(details, str): + versions[provider] = {"type": "provider", "version": details} + + @staticmethod + def extract_module_versions(tf_config: Dict[str, Any], versions: Dict[str, Dict[str, str]]): + module_version_pattern = re.compile(r"\?ref=v([\d\.]+)$") + for module in tf_config.get("module", []): + if isinstance(module, dict) and "source" in module: + source = module["source"] + match = module_version_pattern.search(source) + if match: + versions[f"Module: {module['source']}"] = {"type": "module", "version": match.group(1)} + elif "version" in module: + versions[f"Module: {module['source']}"] = {"type": "module", "version": module["version"]} + + +class TerraformFileParser: + @staticmethod + def load(file_path: str) -> Dict[str, Any]: + with open(file_path, "r") as tf_file: + return hcl2.load(tf_file) + + +class VersionManager: + def __init__(self, directory: str): + self.directory = directory + + def find_versions(self) -> Dict[str, Dict[str, str]]: + versions = {} + for root, _, files in os.walk(self.directory): + for file in files: + if file.endswith(".tf"): + file_path = os.path.join(root, file) + try: + tf_config = TerraformFileParser.load(file_path) + versions.update(VersionExtractor.extract_versions(tf_config)) + except Exception as e: + self.handle_error(e, file_path) + return versions + + @staticmethod + def handle_error(e: Exception, file_path: str): + print(f"Error processing {file_path}: {e}") diff --git a/leverage/checker/versions.yaml b/leverage/checker/versions.yaml new file mode 100644 index 0000000..e229aa7 --- /dev/null +++ b/leverage/checker/versions.yaml @@ -0,0 +1,8 @@ +terraform: + core: ">=1.0.0" + providers: + aws: ">=3.40.0" + google: ">=3.50.0" + modules: + networking: ">=2.1.0" + security: ">=1.5.0" diff --git a/leverage/modules/terraform.py b/leverage/modules/terraform.py index 7e769c4..bcfe8b9 100644 --- a/leverage/modules/terraform.py +++ b/leverage/modules/terraform.py @@ -1,3 +1,4 @@ +import json import os import re @@ -10,6 +11,8 @@ from leverage._internals import pass_container from leverage._internals import pass_state from leverage._utils import tar_directory, AwsCredsContainer, LiveContainer, ExitError +from leverage.checker.utils import TimeIt +from leverage.checker.version_parser import VersionManager from leverage.container import TerraformContainer from leverage.container import get_docker_client from leverage.modules.utils import env_var_option, mount_option, auth_mfa, auth_sso @@ -549,3 +552,18 @@ def _validate_layout(tf: TerraformContainer): valid_layout = False return valid_layout + + +@terraform.command(context_settings=CONTEXT_SETTINGS) +@click.argument("args", nargs=-1) +@pass_container +@click.pass_context +def checks(context, tf, args): + """Run pre-flight checks for terraform""" + with TimeIt("Check layer location"): + tf.paths.check_for_layer_location() # We want this to be run at layer level! + with TimeIt("Parse current path tf files"): + version_manager = VersionManager(tf.paths.cwd) + found_versions = version_manager.find_versions() + logger.info(f"Current path: {tf.paths.cwd}") + logger.info(f"Versions found: {json.dumps(found_versions, indent=4, sort_keys=True)}")