Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ray executor can request resources for workers #766

Merged
merged 6 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions eolearn/core/eoexecution.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,12 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs:

return full_execution_results

@classmethod
def _run_execution(
cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
self, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
) -> list[WorkflowResults]:
"""Parallelizes the execution for each item of processing_args list."""
return parallelize(
cls._execute_workflow,
self._execute_workflow,
processing_args,
workers=run_params.workers,
multiprocess=run_params.multiprocess,
Expand Down
54 changes: 46 additions & 8 deletions eolearn/core/extra/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@
"""
from __future__ import annotations

from typing import Any, Callable, Collection, Generator, Iterable, List, TypeVar, cast
from logging import FileHandler, Filter
from typing import Any, Callable, Collection, Generator, Iterable, List, Sequence, TypeVar, cast

from fs.base import FS

from eolearn.core.eonode import EONode

try:
import ray
except ImportError as exception:
raise ImportError("This module requires an installation of Ray Python package") from exception

from ..eoexecution import EOExecutor, _ExecutionRunParams, _ProcessingData
from ..eoworkflow import WorkflowResults
from ..eoexecution import EOExecutor, _ExecutionRunParams, _HandlerFactoryType, _ProcessingData
from ..eoworkflow import EOWorkflow, WorkflowResults
from ..utils.parallelize import _base_join_futures_iter

# pylint: disable=invalid-name
Expand All @@ -29,6 +34,33 @@
class RayExecutor(EOExecutor):
"""A special type of `EOExecutor` that works with Ray framework"""

def __init__(
self,
workflow: EOWorkflow,
execution_kwargs: Sequence[dict[EONode, dict[str, object]]],
*,
execution_names: list[str] | None = None,
save_logs: bool = False,
logs_folder: str = ".",
filesystem: FS | None = None,
logs_filter: Filter | None = None,
logs_handler_factory: _HandlerFactoryType = FileHandler,
raise_on_temporal_mismatch: bool = False,
ray_remote_kwargs: dict[str, Any] | None = None,
):
super().__init__(
workflow,
execution_kwargs,
execution_names=execution_names,
save_logs=save_logs,
logs_folder=logs_folder,
filesystem=filesystem,
logs_filter=logs_filter,
logs_handler_factory=logs_handler_factory,
raise_on_temporal_mismatch=raise_on_temporal_mismatch,
)
self.ray_remote_kwargs = ray_remote_kwargs

def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]: # type: ignore[override]
"""Runs the executor using a Ray cluster

Expand All @@ -43,12 +75,13 @@ def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]: # type: ignore[over
workers = ray.available_resources().get("CPU")
return super().run(workers=workers, multiprocess=True, **tqdm_kwargs)

@classmethod
def _run_execution(
cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
self, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
) -> list[WorkflowResults]:
"""Runs ray execution"""
futures = [_ray_workflow_executor.remote(workflow_args) for workflow_args in processing_args]
remote_kwargs = self.ray_remote_kwargs or {}
exec_func = _ray_workflow_executor.options(**remote_kwargs) # type: ignore[attr-defined]
futures = [exec_func.remote(workflow_args) for workflow_args in processing_args]
return join_ray_futures(futures, **run_params.tqdm_kwargs)


Expand All @@ -60,7 +93,10 @@ def _ray_workflow_executor(workflow_args: _ProcessingData) -> WorkflowResults:


def parallelize_with_ray(
function: Callable[[InputType], OutputType], *params: Iterable[InputType], **tqdm_kwargs: Any
function: Callable[[InputType], OutputType],
*params: Iterable[InputType],
ray_remote_kwargs: dict[str, Any] | None = None,
**tqdm_kwargs: Any,
) -> list[OutputType]:
"""Parallelizes function execution with Ray.

Expand All @@ -69,13 +105,15 @@ def parallelize_with_ray(

:param function: A normal function that is not yet decorated by `ray.remote`.
:param params: Iterables of parameters that will be used with given function.
:param ray_remote_kwargs: Keyword arguments passed to `ray.remote`.
:param tqdm_kwargs: Keyword arguments that will be propagated to `tqdm` progress bar.
:return: A list of results in the order that corresponds with the order of the given input `params`.
"""
ray_remote_kwargs = ray_remote_kwargs or {}
if not ray.is_initialized():
raise RuntimeError("Please initialize a Ray cluster before calling this method")

ray_function = ray.remote(function)
ray_function = ray.remote(function, **ray_remote_kwargs)
futures = [ray_function.remote(*function_params) for function_params in zip(*params)]
return join_ray_futures(futures, **tqdm_kwargs)

Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_extra/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def filter(self, record):

@pytest.fixture(name="_simple_cluster", scope="module")
def _simple_cluster_fixture():
ray.init(log_to_driver=False)
ray.init(log_to_driver=False, resources={"resourceA": 1})
yield
ray.shutdown()

Expand Down Expand Up @@ -103,6 +103,7 @@ def test_read_logs(filter_logs, execution_names, workflow, execution_kwargs):
logs_folder=tmp_dir_name,
logs_filter=CustomLogFilter() if filter_logs else None,
execution_names=execution_names,
ray_remote_kwargs={"resources": {"resourceA": 0.5}},
)
executor.run()

Expand Down