Config Cleanup #79

Merged: 13 commits, Nov 29, 2024
6 changes: 3 additions & 3 deletions .github/workflows/CI-test.yaml
@@ -43,17 +43,17 @@ jobs:
run: |
export HDF5_DEBUG=1
export NETCDF_DEBUG=1
export XARRAY_BACKEND=h5netcdf
export XARRAY_ENGINE=h5netcdf
export PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=300
pytest -vvv -s --cov tests/meta/*.py
- name: Test with pytest (Unit)
run: |
export XARRAY_BACKEND=h5netcdf
export XARRAY_ENGINE=h5netcdf
export PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=300
pytest -vvv -s --cov tests/unit/*.py
- name: Test with pytest (Integration)
run: |
export XARRAY_BACKEND=h5netcdf
export XARRAY_ENGINE=h5netcdf
export PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=300
pytest -vvv -s --cov tests/integration/*.py
- name: Test with doctest
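The change above switches the environment variable from XARRAY_BACKEND to XARRAY_ENGINE, matching xarray's own terminology (open_dataset takes an engine argument). A minimal sketch of how such a variable could be picked up when opening test data; the helper name and its default are assumptions for illustration, not code from this PR:

import os

import xarray as xr


def open_test_dataset(path):
    # Hypothetical helper: choose the backend engine from the environment,
    # falling back to the value exported in the CI workflow above.
    engine = os.environ.get("XARRAY_ENGINE", "h5netcdf")
    return xr.open_dataset(path, engine=engine)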
3 changes: 3 additions & 0 deletions examples/.gitignore
@@ -0,0 +1,3 @@
*.nc
slurm*.out
pymorize_report.log
43 changes: 22 additions & 21 deletions examples/cleanup.py
@@ -6,6 +6,22 @@
from pathlib import Path


def rm_file(fname):
try:
fname.unlink()
print(f"Removed file: {fname}")
except Exception as e:
print(f"Error removing file {fname}: {e}")


def rm_dir(dirname):
try:
shutil.rmtree(dirname)
print(f"Removed directory: {dirname}")
except Exception as e:
print(f"Error removing directory {dirname}: {e}")


def cleanup():
current_dir = Path.cwd()

@@ -15,34 +31,19 @@ def cleanup():
and item.name.startswith("slurm")
and item.name.endswith("out")
):
try:
item.unlink()
print(f"Removed file: {item}")
except Exception as e:
print(f"Error removing file {item}: {e}")
rm_file(item)
if (
item.is_file()
and item.name.startswith("pymorize")
and item.name.endswith("json")
):
try:
item.unlink()
print(f"Removed file: {item}")
except Exception as e:
print(f"Error removing file {item}: {e}")
rm_file(item)
if item.is_file() and item.name.endswith("nc"):
try:
item.unlink()
print(f"Removed file: {item}")
except Exception as e:
print(f"Error removing file {item}: {e}")

rm_file(item)
if item.name == "pymorize_report.log":
rm_file(item)
elif item.is_dir() and item.name == "logs":
try:
shutil.rmtree(item)
print(f"Removed directory: {item}")
except Exception as e:
print(f"Error removing directory {item}: {e}")
rm_dir(item)
print("Cleanup completed.")


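The refactor above moves the repeated try/except blocks into rm_file and rm_dir. A possible further simplification, not part of this PR, is to lean on the standard library's own tolerance for missing paths; the trade-off is that failures are no longer reported. The sketch assumes Python 3.8+ for missing_ok:

import shutil
from pathlib import Path


def rm_file(fname: Path) -> None:
    # missing_ok=True silently skips files that are already gone.
    fname.unlink(missing_ok=True)
    print(f"Removed file: {fname}")


def rm_dir(dirname: Path) -> None:
    # ignore_errors=True swallows problems such as a non-existent directory.
    shutil.rmtree(dirname, ignore_errors=True)
    print(f"Removed directory: {dirname}")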
7 changes: 4 additions & 3 deletions examples/pymorize.slurm
@@ -1,9 +1,10 @@
#!/bin/bash -l
#SBATCH --account=ab0246
#SBATCH --job-name=pymorize-controller # <<< This is the main job; it will launch subjobs if you have Dask enabled.
#SBATCH --account=ab0246 # <<< Adapt this to your computing account!
#SBATCH --partition=compute
#SBATCH --nodes=1
#SBATCH --time=00:30:00
# export PREFECT_SERVER_ALLOW_EPHEMERAL_MODE=False
#SBATCH --time=00:30:00 # <<< You may need more time, adapt as needed!
export PREFECT_SERVER_ALLOW_EPHEMERAL_MODE=True
export PREFECT_SERVER_API_HOST=0.0.0.0
conda activate pymorize
prefect server start &
7 changes: 5 additions & 2 deletions examples/sample.yaml
@@ -10,11 +10,14 @@ pymorize:
# parallel: True
warn_on_no_rule: False
use_flox: True
cluster_mode: fixed
dask_cluster: "slurm"
dask_cluster_scaling_mode: fixed
fixed_jobs: 12
# minimum_jobs: 8
# maximum_jobs: 30
dimensionless_mapping_table: ../data/dimensionless_mappings.yaml
# You can add your own path to the dimensionless mapping table
# If nothing is specified here, it will use the built-in one.
# dimensionless_mapping_table: ../data/dimensionless_mappings.yaml
rules:
- name: paul_example_rule
description: "You can put some text here"
9 changes: 9 additions & 0 deletions src/pymorize/cluster.py
@@ -3,9 +3,18 @@
"""

import dask
from dask.distributed import LocalCluster
from dask_jobqueue import SLURMCluster

from .logging import logger

CLUSTER_MAPPINGS = {
"local": LocalCluster,
"slurm": SLURMCluster,
}
CLUSTER_SCALE_SUPPORT = {"local": False, "slurm": True}
CLUSTER_ADAPT_SUPPORT = {"local": False, "slurm": True}


def set_dashboard_link(cluster):
"""
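The new module-level tables map the dask_cluster config string to a cluster class and record which cluster types support fixed and adaptive scaling. A minimal usage sketch, using the "local" entry so it runs without a batch system (the import path follows the cmorizer.py diff below):

from pymorize.cluster import (CLUSTER_ADAPT_SUPPORT, CLUSTER_MAPPINGS,
                              CLUSTER_SCALE_SUPPORT)

cluster_name = "local"  # "slurm" would require a SLURM batch system
cluster = CLUSTER_MAPPINGS[cluster_name]()  # instantiates a dask LocalCluster
print("dashboard:", cluster.dashboard_link)
print("supports scale():", CLUSTER_SCALE_SUPPORT[cluster_name])
print("supports adapt():", CLUSTER_ADAPT_SUPPORT[cluster_name])
cluster.close()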
86 changes: 58 additions & 28 deletions src/pymorize/cmorizer.py
@@ -10,13 +10,13 @@
import xarray as xr # noqa: F401
import yaml
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from everett.manager import generate_uppercase_key, get_runtime_config
from prefect import flow, task
from prefect.futures import wait
from rich.progress import track

from .cluster import set_dashboard_link
from .cluster import (CLUSTER_ADAPT_SUPPORT, CLUSTER_MAPPINGS,
CLUSTER_SCALE_SUPPORT, set_dashboard_link)
from .config import PymorizeConfig, PymorizeConfigManager
from .data_request import (DataRequest, DataRequestTable, DataRequestVariable,
IgnoreTableFiles)
@@ -88,17 +88,21 @@ def __init__(

################################################################################
# Post_Init:
if self._pymorize_cfg("parallel"):
if self._pymorize_cfg("parallel_backend") == "dask":
self._post_init_configure_dask()
self._post_init_create_dask_cluster()
if self._pymorize_cfg("enable_dask"):
logger.debug("Setting up dask configuration...")
self._post_init_configure_dask()
logger.debug("...done!")
logger.debug("Creating dask cluster...")
self._post_init_create_dask_cluster()
logger.debug("...done!")
self._post_init_create_pipelines()
self._post_init_create_rules()
self._post_init_read_bare_tables()
self._post_init_create_data_request()
self._post_init_populate_rules_with_tables()
self._post_init_read_dimensionless_unit_mappings()
self._post_init_data_request_variables()
logger.debug("...post-init done!")
################################################################################

def _post_init_configure_dask(self):
@@ -120,29 +124,42 @@ def _post_init_configure_dask(self):

def _post_init_create_dask_cluster(self):
# FIXME: In the future, we can support PBS, too.
logger.info("Setting up SLURMCluster...")
self._cluster = SLURMCluster()
logger.info("Setting up dask cluster...")
cluster_name = self._pymorize_cfg("dask_cluster")
ClusterClass = CLUSTER_MAPPINGS[cluster_name]
self._cluster = ClusterClass()
set_dashboard_link(self._cluster)
cluster_mode = self._pymorize_cfg.get("cluster_mode", "adapt")
if cluster_mode == "adapt":
min_jobs = self._pymorize_cfg.get("minimum_jobs", 1)
max_jobs = self._pymorize_cfg.get("maximum_jobs", 10)
self._cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)
elif cluster_mode == "fixed":
jobs = self._pymorize_cfg.get("fixed_jobs", 5)
self._cluster.scale(jobs=jobs)
cluster_scaling_mode = self._pymorize_cfg.get(
"dask_cluster_scaling_mode", "adapt"
)
if cluster_scaling_mode == "adapt":
if CLUSTER_ADAPT_SUPPORT[cluster_name]:
min_jobs = self._pymorize_cfg.get(
"dask_cluster_scaling_minimum_jobs", 1
)
max_jobs = self._pymorize_cfg.get(
"dask_cluster_scaling_maximum_jobs", 10
)
self._cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)
else:
logger.warning(f"{self._cluster} does not support adaptive scaling!")
elif cluster_scaling_mode == "fixed":
if CLUSTER_SCALE_SUPPORT[cluster_name]:
jobs = self._pymorize_cfg.get("dask_cluster_scaling_fixed_jobs", 5)
self._cluster.scale(jobs=jobs)
else:
logger.warning(f"{self._cluster} does not support fixed scaing")
else:
raise ValueError(
"You need to specify adapt or fixed for pymorize.cluster_mode"
"You need to specify adapt or fixed for pymorize.dask_cluster_scaling_mode"
)
# Wait for at least min_jobs to be available...
# FIXME: Client needs to be available here?
logger.info(f"SLURMCluster can be found at: {self._cluster=}")
# FIXME: Include the gateway option if possible
# FIXME: Does ``Client`` need to be available here?
logger.info(f"Cluster can be found at: {self._cluster=}")
logger.info(f"Dashboard {self._cluster.dashboard_link}")
# NOTE(PG): In CI context, os.getlogin and nodename may not be available (???)

username = getpass.getuser()
nodename = getattr(os.uname(), "nodename", "UNKNOWN")
# FIXME: Include the gateway option if possible
logger.info(
"To see the dashboards run the following command in your computer's "
"terminal:\n"
@@ -152,7 +169,7 @@ def _post_init_create_dask_cluster(self):

dask_extras = 0
logger.info("Importing Dask Extras...")
if self._pymorize_cfg.get("use_flox", True):
if self._pymorize_cfg.get("enable_flox", True):
dask_extras += 1
logger.info("...flox...")
import flox # noqa: F401
@@ -337,7 +354,9 @@ def validate(self):
# self._check_rules_for_output_dir()
# FIXME(PS): Turn off this check, see GH #59 (https://tinyurl.com/3z7d8uuy)
# self._check_is_subperiod()
logger.debug("Starting validate....")
self._check_units()
logger.debug("...done!")

def _check_is_subperiod(self):
logger.info("checking frequency in netcdf file and in table...")
@@ -443,6 +462,7 @@ def from_dict(cls, data):
instance._post_init_create_data_request()
instance._post_init_data_request_variables()
instance._post_init_read_dimensionless_unit_mappings()
logger.debug("Object creation done!")
return instance

def add_rule(self, rule):
@@ -509,16 +529,23 @@ def check_rules_for_output_dir(self, output_dir):
logger.warning(filepath)

def process(self, parallel=None):
logger.debug("Process start!")
if parallel is None:
parallel = self._pymorize_cfg.get("parallel", True)
if parallel:
parallel_backend = self._pymorize_cfg.get("parallel_backend", "prefect")
return self.parallel_process(backend=parallel_backend)
logger.debug("Parallel processing...")
# FIXME(PG): This is mixed up, hard-coding to prefect for now...
workflow_backend = self._pymorize_cfg.get(
"pipeline_orchestrator", "prefect"
)
logger.debug(f"...with {workflow_backend}...")
return self.parallel_process(backend=workflow_backend)
else:
return self.serial_process()

def parallel_process(self, backend="prefect"):
if backend == "prefect":
logger.debug("About to submit _parallel_process_prefect()")
return self._parallel_process_prefect()
elif backend == "dask":
return self._parallel_process_dask()
@@ -529,6 +556,8 @@ def _parallel_process_prefect(self):
# prefect_logger = get_run_logger()
# logger = prefect_logger
# @flow(task_runner=DaskTaskRunner(address=self._cluster.scheduler_address))
logger.debug("Defining dynamically generated prefect workflow...")

@flow
def dynamic_flow():
rule_results = []
@@ -537,6 +566,9 @@ def dynamic_flow():
wait(rule_results)
return rule_results

logger.debug("...done!")

logger.debug("About to return dynamic_flow()...")
return dynamic_flow()

def _parallel_process_dask(self, external_client=None):
@@ -567,13 +599,11 @@ def _process_rule(self, rule):
# FIXME(PG): This might also be a place we need to consider copies...
rule.match_pipelines(self.pipelines)
data = None
# NOTE(PG): Send in a COPY of the rule, not the original rule
local_rule_copy = copy.deepcopy(rule)
if not len(rule.pipelines) > 0:
logger.error("No pipeline defined, something is wrong!")
for pipeline in rule.pipelines:
logger.info(f"Running {str(pipeline)}")
data = pipeline.run(data, local_rule_copy)
data = pipeline.run(data, rule)
return data

@task
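_parallel_process_prefect builds its flow at call time: one task is submitted per rule and all resulting futures are awaited together. A self-contained sketch of that pattern; the task body here is a stand-in, not pymorize's real _process_rule:

from prefect import flow, task
from prefect.futures import wait


@task
def process_rule(rule):
    # Stand-in for the real per-rule pipeline work.
    return f"processed {rule}"


@flow
def dynamic_flow(rules):
    futures = [process_rule.submit(rule) for rule in rules]
    wait(futures)  # block until every submitted task has finished
    return [future.result() for future in futures]


if __name__ == "__main__":
    print(dynamic_flow(["rule_a", "rule_b"]))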