Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor eoexecutor #764

Merged
merged 8 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 22 additions & 43 deletions eolearn/core/eoexecution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,25 @@
import concurrent.futures
import datetime as dt
import inspect
import itertools as it
import logging
import threading
import warnings
from dataclasses import dataclass
from logging import FileHandler, Filter, Handler, Logger
from typing import Any, Callable, Protocol, Sequence, Union
from typing import Any, Callable, Iterable, Protocol, Union

import fs
from fs.base import FS

from sentinelhub.exceptions import deprecated_function

from .eonode import EONode
from .eoworkflow import EOWorkflow, WorkflowResults
from .exceptions import EORuntimeWarning, TemporalDimensionWarning
from .exceptions import EODeprecationWarning, EORuntimeWarning, TemporalDimensionWarning
from .utils.fs import get_base_filesystem_and_path, get_full_path, pickle_fs, unpickle_fs
from .utils.logging import LogFileFilter
from .utils.parallelize import _decide_processing_type, _ProcessingType, parallelize
from .utils.parallelize import parallelize


class _HandlerWithFsFactoryType(Protocol):
Expand Down Expand Up @@ -79,9 +82,9 @@ class EOExecutor:
def __init__(
self,
workflow: EOWorkflow,
execution_kwargs: Sequence[dict[EONode, dict[str, object]]],
execution_kwargs: Iterable[dict[EONode, dict[str, object]]],
*,
execution_names: list[str] | None = None,
execution_names: Iterable[str] | None = None,
save_logs: bool = False,
logs_folder: str = ".",
filesystem: FS | None = None,
Expand Down Expand Up @@ -128,27 +131,23 @@ def __init__(

@staticmethod
def _parse_and_validate_execution_kwargs(
execution_kwargs: Sequence[dict[EONode, dict[str, object]]]
execution_kwargs: Iterable[dict[EONode, dict[str, object]]]
) -> list[dict[EONode, dict[str, object]]]:
"""Parses and validates execution arguments provided by user and raises an error if something is wrong."""
if not isinstance(execution_kwargs, (list, tuple)):
mlubej marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Parameter 'execution_kwargs' should be a list.")

for input_kwargs in execution_kwargs:
EOWorkflow.validate_input_kwargs(input_kwargs)

return [input_kwargs or {} for input_kwargs in execution_kwargs]
return list(execution_kwargs)

@staticmethod
def _parse_execution_names(execution_names: list[str] | None, execution_kwargs: Sequence) -> list[str]:
def _parse_execution_names(execution_names: Iterable[str] | None, execution_kwargs: list) -> list[str]:
"""Parses a list of execution names."""
if execution_names is None:
return [str(num) for num in range(1, len(execution_kwargs) + 1)]

if not isinstance(execution_names, (list, tuple)) or len(execution_names) != len(execution_kwargs):
raise ValueError(
"Parameter 'execution_names' has to be a list of the same size as the list of execution arguments."
)
execution_names = list(execution_names)
mlubej marked this conversation as resolved.
Show resolved Hide resolved
if len(execution_names) != len(execution_kwargs):
raise ValueError("Parameter 'execution_names' has to be of the same size as `execution_kwargs`.")
return execution_names

@staticmethod
Expand Down Expand Up @@ -181,11 +180,7 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs:
if self.save_logs:
self.filesystem.makedirs(self.report_folder, recreate=True)

log_paths: Sequence[str | None]
if self.save_logs:
log_paths = self.get_log_paths(full_path=False)
else:
log_paths = [None] * len(self.execution_kwargs)
log_paths = self.get_log_paths(full_path=False) if self.save_logs else it.repeat(None)
mlubej marked this conversation as resolved.
Show resolved Hide resolved

filter_logs_by_thread = not multiprocess and workers is not None and workers > 1
processing_args = [
Expand All @@ -205,8 +200,7 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs:
full_execution_results = self._run_execution(processing_args, run_params)

self.execution_results = [results.drop_outputs() for results in full_execution_results]
processing_type = self._get_processing_type(workers=workers, multiprocess=multiprocess)
self.general_stats = self._prepare_general_stats(workers, processing_type)
self.general_stats = self._prepare_general_stats(workers)

return full_execution_results

Expand Down Expand Up @@ -306,52 +300,39 @@ def _build_log_handler(

return handler

@staticmethod
def _get_processing_type(workers: int | None, multiprocess: bool) -> _ProcessingType:
"""Provides a type of processing according to given parameters."""
return _decide_processing_type(workers=workers, multiprocess=multiprocess)

def _prepare_general_stats(self, workers: int | None, processing_type: _ProcessingType) -> dict[str, object]:
def _prepare_general_stats(self, workers: int | None) -> dict[str, object]:
"""Prepares a dictionary with a general statistics about executions."""
failed_count = sum(results.workflow_failed() for results in self.execution_results)
return {
self.STATS_START_TIME: self.start_time,
self.STATS_END_TIME: dt.datetime.now(),
"finished": len(self.execution_results) - failed_count,
"failed": failed_count,
"processing_type": processing_type.value,
"workers": workers,
}

def get_successful_executions(self) -> list[int]:
    """Returns a list of IDs of successful executions. The IDs are integers from interval
    `[0, len(execution_kwargs) - 1]`, sorted in increasing order.

    :return: List of successful execution IDs
    """
    successful_ids = []
    for execution_id, workflow_results in enumerate(self.execution_results):
        if not workflow_results.workflow_failed():
            successful_ids.append(execution_id)
    return successful_ids

def get_failed_executions(self) -> list[int]:
    """Returns a list of IDs of failed executions. The IDs are integers from interval
    `[0, len(execution_kwargs) - 1]`, sorted in increasing order.

    :return: List of failed execution IDs
    """
    failed_ids = []
    for execution_id, workflow_results in enumerate(self.execution_results):
        if workflow_results.workflow_failed():
            failed_ids.append(execution_id)
    return failed_ids

def get_report_path(self, full_path: bool = True) -> str:
"""Returns the filename and file path of the report.

:param full_path: A flag to specify if it should return full absolute paths or paths relative to the
filesystem object.
:param full_path: Whether to return full absolute paths or paths relative to the filesystem object.
:return: Report filename
"""
if self.report_folder is None:
raise RuntimeError("Executor has to be run before the report path is created.")
report_path = fs.path.combine(self.report_folder, self.REPORT_FILENAME)
if full_path:
return get_full_path(self.filesystem, report_path)
return report_path
return get_full_path(self.filesystem, report_path) if full_path else report_path

def make_report(self, include_logs: bool = True) -> None:
"""Makes a html report and saves it into the same folder where logs are stored.
Expand All @@ -373,17 +354,15 @@ def make_report(self, include_logs: bool = True) -> None:
def get_log_paths(self, full_path: bool = True) -> list[str]:
"""Returns a list of file paths containing logs.

:param full_path: A flag to specify if it should return full absolute paths or paths relative to the
filesystem object.
:param full_path: Whether to return full absolute paths or paths relative to the filesystem object.
:return: A list of paths to log files.
"""
if self.report_folder is None:
raise RuntimeError("Executor has to be run before log paths are created.")
log_paths = [fs.path.combine(self.report_folder, f"eoexecution-{name}.log") for name in self.execution_names]
if full_path:
return [get_full_path(self.filesystem, path) for path in log_paths]
return log_paths
return [get_full_path(self.filesystem, path) for path in log_paths] if full_path else log_paths

@deprecated_function(EODeprecationWarning)
def read_logs(self) -> list[str | None]:
"""Loads the content of log files if logs have been saved."""
if not self.save_logs:
Expand Down
7 changes: 1 addition & 6 deletions eolearn/core/extra/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..eoexecution import EOExecutor, _ExecutionRunParams, _ProcessingData
from ..eoworkflow import WorkflowResults
from ..utils.parallelize import _base_join_futures_iter, _ProcessingType
from ..utils.parallelize import _base_join_futures_iter

# pylint: disable=invalid-name
InputType = TypeVar("InputType")
Expand Down Expand Up @@ -51,11 +51,6 @@ def _run_execution(
futures = [_ray_workflow_executor.remote(workflow_args) for workflow_args in processing_args]
return join_ray_futures(futures, **run_params.tqdm_kwargs)

@staticmethod
def _get_processing_type(*_: Any, **__: Any) -> _ProcessingType:
"""Provides a type of processing for later references."""
return _ProcessingType.RAY


@ray.remote
def _ray_workflow_executor(workflow_args: _ProcessingData) -> WorkflowResults:
Expand Down
9 changes: 7 additions & 2 deletions eolearn/visualization/eoexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ def make_report(self, include_logs: bool = True) -> None:
template = self._get_template()

execution_log_filenames = [fs.path.basename(log_path) for log_path in self.eoexecutor.get_log_paths()]
if self.eoexecutor.save_logs:
execution_logs = self.eoexecutor.read_logs() if include_logs else None
if not include_logs:
execution_logs = None
elif self.eoexecutor.save_logs:
execution_logs = []
for log_path in self.eoexecutor.get_log_paths():
mlubej marked this conversation as resolved.
Show resolved Hide resolved
with self.eoexecutor.filesystem.open(log_path, "r") as file_handle:
execution_logs.append(file_handle.read())
else:
execution_logs = ["No logs saved"] * len(self.eoexecutor.execution_kwargs)

Expand Down
3 changes: 0 additions & 3 deletions eolearn/visualization/report_templates/report.html
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ <h2> Execution status </h2>
<li>
Number of failed executions: {{ general_stats['failed'] }}
</li>
<li>
Processing type: {{ general_stats['processing_type'] }}
</li>
<li>
Number of workers: {{ general_stats['workers'] }}
</li>
Expand Down
34 changes: 17 additions & 17 deletions tests/core/test_eoexecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
OutputTask,
WorkflowResults,
execute_with_mp_lock,
linearly_connect_tasks,
)
from eolearn.core.utils.fs import get_full_path

Expand Down Expand Up @@ -129,7 +128,7 @@ def __init__(self, path: str, filesystem: FS):
)
@pytest.mark.parametrize("execution_names", [None, [4, "x", "y", "z"]])
@pytest.mark.parametrize("logs_handler_factory", [FileHandler, DummyFilesystemFileHandler])
def test_read_logs(test_args, execution_names, workflow, execution_kwargs, logs_handler_factory):
def test_logs(test_args, execution_names, workflow, execution_kwargs, logs_handler_factory):
workers, multiprocess, filter_logs = test_args
with tempfile.TemporaryDirectory() as tmp_dir_name:
executor = EOExecutor(
Expand All @@ -143,7 +142,11 @@ def test_read_logs(test_args, execution_names, workflow, execution_kwargs, logs_
)
executor.run(workers=workers, multiprocess=multiprocess)

execution_logs = executor.read_logs()
execution_logs = []
for log_path in executor.get_log_paths():
with open(log_path) as f:
execution_logs.append(f.read())

mlubej marked this conversation as resolved.
Show resolved Hide resolved
assert len(execution_logs) == 4
for log in execution_logs:
assert len(log.split()) >= 3
Expand Down Expand Up @@ -201,11 +204,9 @@ def test_execution_results2(workflow, execution_kwargs):
assert workflow_results.outputs["output"] == 42


def test_exceptions(workflow, execution_kwargs):
with pytest.raises(ValueError):
EOExecutor(workflow, {})
def test_exception_wrong_length_execution_names(workflow, execution_kwargs):
with pytest.raises(ValueError):
EOExecutor(workflow, execution_kwargs, execution_names={1, 2, 3, 4})
EOExecutor(workflow, execution_kwargs, execution_names={1, 2, 3, 4, 5})
with pytest.raises(ValueError):
EOExecutor(workflow, execution_kwargs, execution_names=["a", "b"])

Expand Down Expand Up @@ -269,17 +270,16 @@ def test_without_lock(num_workers):

@pytest.mark.parametrize("multiprocess", [True, False])
def test_temporal_dim_error(multiprocess):
workflow = EOWorkflow(
linearly_connect_tasks(
CreateEOPatchTask(bbox=BBox((0, 0, 1, 1), CRS.POP_WEB)),
InitializeFeatureTask([FeatureType.DATA, "data"], (2, 5, 5, 1)),
)
)

executor = EOExecutor(workflow, [{}, {}])
for result in executor.run(workers=2, multiprocess=multiprocess):
create_node = EONode(CreateEOPatchTask())
init_node = EONode(InitializeFeatureTask((FeatureType.DATA, "data"), (2, 5, 5, 1)), inputs=[create_node])
workflow = EOWorkflow([create_node, init_node])
exec_kwargs = [{create_node: {"bbox": BBox((0, 0, 1, 1), CRS.POP_WEB)}}] * 2

executor = EOExecutor(workflow, exec_kwargs)
results = executor.run(workers=2, multiprocess=multiprocess)
for result in results:
assert result.error_node_uid is None

executor = EOExecutor(workflow, [{}, {}], raise_on_temporal_mismatch=True)
executor = EOExecutor(workflow, exec_kwargs, raise_on_temporal_mismatch=True)
for result in executor.run(workers=2, multiprocess=multiprocess):
assert result.error_node_uid is not None
6 changes: 5 additions & 1 deletion tests/core/test_extra/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def test_read_logs(filter_logs, execution_names, workflow, execution_kwargs):
)
executor.run()

execution_logs = executor.read_logs()
execution_logs = []
for log_path in executor.get_log_paths():
with open(log_path) as f:
execution_logs.append(f.read())

assert len(execution_logs) == 4
for log in execution_logs:
assert len(log.split()) >= 3
Expand Down