Skip to content

Commit

Permalink
Add support for experiment and lab directory nesting in packages, cle…
Browse files Browse the repository at this point in the history
…an up package validation
  • Loading branch information
aangelos28 committed Sep 17, 2024
1 parent 414726a commit 7eaf9ac
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 197 deletions.
14 changes: 4 additions & 10 deletions eos/configuration/configuration_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
from typing import TYPE_CHECKING


from eos.configuration.exceptions import (
EosConfigurationError,
)
Expand Down Expand Up @@ -57,10 +55,8 @@ def get_lab_loaded_statuses(self) -> dict[str, bool]:
all_labs = set()

for package in self._package_manager.get_all_packages():
labs_dir = package.labs_dir
if labs_dir.is_dir():
package_labs = [d for d in os.listdir(labs_dir) if (labs_dir / d).is_dir()]
all_labs.update(package_labs)
package_labs = self._package_manager.get_labs_in_package(package.name)
all_labs.update(package_labs)

return {lab: lab in self.labs for lab in all_labs}

Expand Down Expand Up @@ -132,10 +128,8 @@ def get_experiment_loaded_statuses(self) -> dict[str, bool]:
all_experiments = set()

for package in self._package_manager.get_all_packages():
experiments_dir = package.experiments_dir
if experiments_dir.is_dir():
package_experiments = [d for d in os.listdir(experiments_dir) if (experiments_dir / d).is_dir()]
all_experiments.update(package_experiments)
package_experiments = self._package_manager.get_experiments_in_package(package.name)
all_experiments.update(package_experiments)

return {exp: exp in self.experiments for exp in all_experiments}

Expand Down
79 changes: 65 additions & 14 deletions eos/configuration/package_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ class EntityConfigReader(Generic[T]):
"""

@staticmethod
def read_entity_config(file_path: str, entity_type: EntityType) -> ConfigType:
def read_entity_config(user_dir: str, file_path: str, entity_type: EntityType) -> ConfigType:
entity_info = ENTITY_INFO[entity_type]
return EntityConfigReader._read_config(file_path, entity_info.config_type, f"{entity_type.name}")
return EntityConfigReader._read_config(user_dir, file_path, entity_info.config_type, f"{entity_type.name}")

@staticmethod
def read_all_entity_configs(base_dir: str, entity_type: EntityType) -> tuple[dict[str, ConfigType], dict[str, str]]:
def read_all_entity_configs(user_dir: str, base_dir: str, entity_type: EntityType) -> tuple[
dict[str, ConfigType], dict[str, str]]:
entity_info = ENTITY_INFO[entity_type]
configs = {}
dirs_to_types = {}
Expand All @@ -87,7 +88,7 @@ def read_all_entity_configs(base_dir: str, entity_type: EntityType) -> tuple[dic
config_file_path = Path(root) / entity_info.config_file_name

try:
structured_config = EntityConfigReader.read_entity_config(str(config_file_path), entity_type)
structured_config = EntityConfigReader.read_entity_config(user_dir, str(config_file_path), entity_type)
entity_type_name = structured_config.type
configs[entity_type_name] = structured_config
dirs_to_types[entity_subdir] = entity_type_name
Expand All @@ -104,9 +105,9 @@ def read_all_entity_configs(base_dir: str, entity_type: EntityType) -> tuple[dic
return configs, dirs_to_types

@staticmethod
def _read_config(file_path: str, config_type: type[ConfigType], config_name: str) -> ConfigType:
def _read_config(user_dir: str, file_path: str, config_type: type[ConfigType], config_name: str) -> ConfigType:
try:
config_data = EntityConfigReader._process_jinja_yaml(file_path)
config_data = EntityConfigReader._process_jinja_yaml(user_dir, file_path)

structured_config = OmegaConf.merge(OmegaConf.structured(config_type), OmegaConf.create(config_data))
_ = OmegaConf.to_object(structured_config)
Expand All @@ -122,7 +123,7 @@ def _read_config(file_path: str, config_type: type[ConfigType], config_name: str
raise EosConfigurationError(f"Error processing {config_name} configuration: {e!s}") from e

@staticmethod
def _process_jinja_yaml(file_path: str) -> dict[str, Any]:
def _process_jinja_yaml(user_dir: str, file_path: str) -> dict[str, Any]:
"""
Process a YAML file with Jinja2 templating, without passing any variables.
Expand All @@ -139,7 +140,7 @@ def _process_jinja_yaml(file_path: str) -> dict[str, Any]:

try:
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(Path(file_path).parents[3]), # user directory
loader=jinja2.FileSystemLoader(Path(user_dir)), # user directory
undefined=jinja2.StrictUndefined,
autoescape=True,
)
Expand Down Expand Up @@ -171,7 +172,7 @@ def discover_packages(self) -> dict[str, Package]:
package_path = self.user_dir / item

if package_path.is_dir():
packages[item] = Package(item, package_path)
packages[item] = Package(item, str(package_path))

return packages

Expand Down Expand Up @@ -201,12 +202,12 @@ def __init__(self, user_dir: str):
def read_lab_config(self, lab_name: str) -> LabConfig:
entity_location = self._get_entity_location(lab_name, EntityType.LAB)
config_file_path = self._get_config_file_path(entity_location, EntityType.LAB)
return EntityConfigReader.read_entity_config(config_file_path, EntityType.LAB)
return EntityConfigReader.read_entity_config(self.user_dir, config_file_path, EntityType.LAB)

def read_experiment_config(self, experiment_name: str) -> ExperimentConfig:
entity_location = self._get_entity_location(experiment_name, EntityType.EXPERIMENT)
config_file_path = self._get_config_file_path(entity_location, EntityType.EXPERIMENT)
return EntityConfigReader.read_entity_config(config_file_path, EntityType.EXPERIMENT)
return EntityConfigReader.read_entity_config(self.user_dir, config_file_path, EntityType.EXPERIMENT)

def read_task_configs(self) -> tuple[dict[str, TaskSpecification], dict[str, str]]:
return self._read_all_entity_configs(EntityType.TASK)
Expand Down Expand Up @@ -238,12 +239,61 @@ def find_package_for_task(self, task_name: str) -> Package | None:
def find_package_for_device(self, device_name: str) -> Package | None:
return self._find_package_for_entity(device_name, EntityType.DEVICE)

def get_entity_dir(self, entity_name: str, entity_type: EntityType) -> Path:
entity_location = self._get_entity_location(entity_name, entity_type)
package = self.packages[entity_location.package_name]
return Path(getattr(package, f"{ENTITY_INFO[entity_type].dir_name}_dir") / entity_location.entity_path)

def get_experiments_in_package(self, package_name: str) -> list[str]:
package = self.get_package(package_name)
if not package:
raise EosMissingConfigurationError(f"Package '{package_name}' not found")

return [
entity_name
for entity_name, location in self.entity_indices[EntityType.EXPERIMENT].items()
if location.package_name == package_name
]

def get_labs_in_package(self, package_name: str) -> list[str]:
package = self.get_package(package_name)
if not package:
raise EosMissingConfigurationError(f"Package '{package_name}' not found")

return [
entity_name
for entity_name, location in self.entity_indices[EntityType.LAB].items()
if location.package_name == package_name
]

def get_tasks_in_package(self, package_name: str) -> list[str]:
package = self.get_package(package_name)
if not package:
raise EosMissingConfigurationError(f"Package '{package_name}' not found")

return [
entity_name
for entity_name, location in self.entity_indices[EntityType.TASK].items()
if location.package_name == package_name
]

def get_devices_in_package(self, package_name: str) -> list[str]:
package = self.get_package(package_name)
if not package:
raise EosMissingConfigurationError(f"Package '{package_name}' not found")

return [
entity_name
for entity_name, location in self.entity_indices[EntityType.DEVICE].items()
if location.package_name == package_name
]

def add_package(self, package_name: str) -> None:
package_path = Path(self.user_dir) / package_name
if not package_path.is_dir():
raise EosMissingConfigurationError(f"Package directory '{package_path}' does not exist")

new_package = Package(package_name, package_path)
new_package = Package(package_name, str(package_path))
PackageValidator(self.user_dir, {package_name: new_package}).validate()

self.packages[package_name] = new_package
Expand Down Expand Up @@ -317,7 +367,7 @@ def _get_config_file_path(self, entity_location: EntityLocationInfo, entity_type
EosMissingConfigurationError,
)

return config_file_path
return str(config_file_path)

def _read_all_entity_configs(self, entity_type: EntityType) -> tuple[dict[str, T], dict[str, str]]:
all_configs = {}
Expand All @@ -326,7 +376,8 @@ def _read_all_entity_configs(self, entity_type: EntityType) -> tuple[dict[str, T
entity_dir = Path(getattr(package, f"{ENTITY_INFO[entity_type].dir_name}_dir"))
if not entity_dir.is_dir():
continue
configs, dirs_to_types = EntityConfigReader.read_all_entity_configs(entity_dir, entity_type)
configs, dirs_to_types = EntityConfigReader.read_all_entity_configs(self.user_dir, str(entity_dir),
entity_type)
all_configs.update(configs)
all_dirs_to_types.update({Path(package.name) / k: v for k, v in dirs_to_types.items()})
return all_configs, all_dirs_to_types
159 changes: 1 addition & 158 deletions eos/configuration/package_validator.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,5 @@
import os
from pathlib import Path

from eos.configuration.constants import (
COMMON_DIR,
EXPERIMENTS_DIR,
LABS_DIR,
DEVICES_DIR,
TASKS_DIR,
LAB_CONFIG_FILE_NAME,
EXPERIMENT_CONFIG_FILE_NAME,
TASK_CONFIG_FILE_NAME,
TASK_IMPLEMENTATION_FILE_NAME,
DEVICE_CONFIG_FILE_NAME,
DEVICE_IMPLEMENTATION_FILE_NAME,
)
from eos.configuration.exceptions import EosMissingConfigurationError, EosConfigurationError
from eos.configuration.exceptions import EosMissingConfigurationError
from eos.configuration.package import Package
from eos.logging.logger import log


class PackageValidator:
Expand All @@ -31,143 +14,3 @@ def __init__(self, user_dir: str, packages: dict[str, Package]):
def validate(self) -> None:
if not self.packages:
raise EosMissingConfigurationError(f"No valid packages found in the user directory '{self.user_dir}'")

for package in self.packages.values():
self._validate_package_structure(package)

def _validate_package_structure(self, package: Package) -> None:
"""
Validate the structure of a single package.
"""
if not any(
[
package.common_dir.is_dir(),
package.experiments_dir.is_dir(),
package.labs_dir.is_dir(),
package.devices_dir.is_dir(),
package.tasks_dir.is_dir(),
]
):
raise EosMissingConfigurationError(
f"Package '{package.name}' does not contain any of the directories: "
f"{COMMON_DIR}, {EXPERIMENTS_DIR}, {LABS_DIR}, {DEVICES_DIR}, {TASKS_DIR}"
)

if package.labs_dir.is_dir():
self._validate_labs_dir(package)

if package.experiments_dir.is_dir():
self._validate_experiments_dir(package)

if package.devices_dir.is_dir():
self._validate_devices_dir(package)

if package.tasks_dir.is_dir():
self._validate_tasks_dir(package)

@staticmethod
def _validate_labs_dir(package: Package) -> None:
"""
Validate the structure of the labs directory.
"""
for file in os.listdir(package.labs_dir):
file_path = package.labs_dir / file
if not file_path.is_dir():
raise EosConfigurationError(
f"Non-directory file found in '{package.labs_dir}'. Only lab directories are allowed."
)

for lab in os.listdir(package.labs_dir):
lab_file_path = package.labs_dir / lab / LAB_CONFIG_FILE_NAME
if not lab_file_path.is_file():
raise EosMissingConfigurationError(f"Lab file '{LAB_CONFIG_FILE_NAME}' does not exist for lab '{lab}'")

log.debug(f"Detected lab '{lab}' in package '{package.name}'")

@staticmethod
def _validate_experiments_dir(package: Package) -> None:
"""
Validate the structure of the experiments directory.
"""
for file in os.listdir(package.experiments_dir):
file_path = package.experiments_dir / file
if not file_path.is_dir():
raise EosConfigurationError(
f"Non-directory file found in '{package.experiments_dir}'. Only experiment directories "
f"are allowed."
)

experiment_config_file = file_path / EXPERIMENT_CONFIG_FILE_NAME
if not experiment_config_file.is_file():
raise EosMissingConfigurationError(
f"Experiment configuration file '{EXPERIMENT_CONFIG_FILE_NAME}' does not exist for "
f"experiment '{file}'"
)

log.debug(f"Detected experiment '{file}' in package '{package.name}'")

@staticmethod
def _validate_tasks_dir(package: Package) -> None:
"""
Validate the structure of the tasks directory.
Ensure each subdirectory represents a task and contains the necessary files.
"""
task_types = []
for current_dir, _, files in os.walk(package.tasks_dir):
if TASK_CONFIG_FILE_NAME not in files:
continue

task_dir = Path(current_dir)
task_name = task_dir.relative_to(package.tasks_dir)

config_file = task_dir / TASK_CONFIG_FILE_NAME
implementation_file = task_dir / TASK_IMPLEMENTATION_FILE_NAME

if not config_file.is_file():
raise EosMissingConfigurationError(
f"Task configuration file '{TASK_CONFIG_FILE_NAME}' not found for task '{task_name}' "
f"in package '{package.name}'."
)

if not implementation_file.is_file():
raise EosMissingConfigurationError(
f"Task implementation file '{TASK_IMPLEMENTATION_FILE_NAME}' not found for task "
f"'{task_name}' in package '{package.name}'."
)

task_types.append(task_dir)

log.debug(f"Detected tasks '{task_types}' in package '{package.name}'")

@staticmethod
def _validate_devices_dir(package: Package) -> None:
"""
Validate the structure of the devices directory.
Ensure each subdirectory represents a device and contains the necessary files.
"""
device_types = []
for current_dir, _, files in os.walk(package.devices_dir):
if DEVICE_CONFIG_FILE_NAME not in files:
continue

device_dir = Path(current_dir)
device_name = device_dir.relative_to(package.devices_dir)

config_file = device_dir / DEVICE_CONFIG_FILE_NAME
implementation_file = device_dir / DEVICE_IMPLEMENTATION_FILE_NAME

if not config_file.is_file():
raise EosMissingConfigurationError(
f"Device configuration file '{DEVICE_CONFIG_FILE_NAME}' not found for device "
f"'{device_name}' in package '{package.name}'."
)

if not implementation_file.is_file():
raise EosMissingConfigurationError(
f"Device implementation file '{DEVICE_IMPLEMENTATION_FILE_NAME}' not found for device "
f"'{device_name}' in package '{package.name}'."
)

device_types.append(device_dir)

log.debug("Detected devices '%S' in package '%s'", device_types, package.name)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
CAMPAIGN_OPTIMIZER_FILE_NAME,
CAMPAIGN_OPTIMIZER_CREATION_FUNCTION_NAME,
)
from eos.configuration.package_manager import PackageManager
from eos.configuration.package_manager import PackageManager, EntityType
from eos.configuration.plugin_registries.plugin_registry import PluginRegistry, PluginRegistryConfig
from eos.logging.logger import log
from eos.optimization.abstract_sequential_optimizer import AbstractSequentialOptimizer
Expand Down Expand Up @@ -76,9 +76,8 @@ def load_campaign_optimizer(self, experiment_type: str) -> None:
log.warning(f"No package found for experiment '{experiment_type}'.")
return

optimizer_file = (
Path(experiment_package.experiments_dir) / experiment_type / self._config.implementation_file_name
)
optimizer_file = self._package_manager.get_entity_dir(experiment_type,
EntityType.EXPERIMENT) / CAMPAIGN_OPTIMIZER_FILE_NAME

if not Path(optimizer_file).exists():
log.warning(
Expand Down
Loading

0 comments on commit 7eaf9ac

Please sign in to comment.