Ray executor can request resources for workers (#766)
* add necessary changes

* some light renaming

* fix tests and mypy

* also add the remote_kwargs parameter to `parallelize_with_ray`

* add docstring
zigaLuksic authored Oct 27, 2023
1 parent 498aeea commit a0d6b28
Showing 3 changed files with 50 additions and 12 deletions.
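In short, the change threads a new `ray_remote_kwargs` argument from `RayExecutor` (and from `parallelize_with_ray`) down to the point where Ray tasks are created, so each worker task can request CPUs, GPUs, or custom resources. A minimal usage sketch, not part of the commit itself: `workflow` and `execution_kwargs` are placeholders for an `EOWorkflow` and its per-execution arguments, and the resource values are illustrative.

```python
import ray

from eolearn.core.extra.ray import RayExecutor

ray.init()  # RayExecutor.run requires an initialized Ray cluster

# `workflow` and `execution_kwargs` are placeholders for an existing EOWorkflow
# and its list of per-node execution arguments (not defined in this sketch).
executor = RayExecutor(
    workflow,
    execution_kwargs,
    # New in this commit: forwarded to Ray when the per-workflow tasks are created,
    # so every worker task requests these resources (values here are illustrative).
    ray_remote_kwargs={"num_cpus": 2, "resources": {"resourceA": 0.5}},
)
results = executor.run()
ray.shutdown()
```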
5 changes: 2 additions & 3 deletions eolearn/core/eoexecution.py
@@ -204,13 +204,12 @@ def run(self, workers: int | None = 1, multiprocess: bool = True, **tqdm_kwargs:
 
         return full_execution_results
 
-    @classmethod
     def _run_execution(
-        cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
+        self, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
     ) -> list[WorkflowResults]:
         """Parallelizes the execution for each item of processing_args list."""
         return parallelize(
-            cls._execute_workflow,
+            self._execute_workflow,
             processing_args,
             workers=run_params.workers,
             multiprocess=run_params.multiprocess,
54 changes: 46 additions & 8 deletions eolearn/core/extra/ray.py
@@ -10,15 +10,20 @@
 """
 from __future__ import annotations
 
-from typing import Any, Callable, Collection, Generator, Iterable, List, TypeVar, cast
+from logging import FileHandler, Filter
+from typing import Any, Callable, Collection, Generator, Iterable, List, Sequence, TypeVar, cast
+
+from fs.base import FS
+
+from eolearn.core.eonode import EONode
 
 try:
     import ray
 except ImportError as exception:
     raise ImportError("This module requires an installation of Ray Python package") from exception
 
-from ..eoexecution import EOExecutor, _ExecutionRunParams, _ProcessingData
-from ..eoworkflow import WorkflowResults
+from ..eoexecution import EOExecutor, _ExecutionRunParams, _HandlerFactoryType, _ProcessingData
+from ..eoworkflow import EOWorkflow, WorkflowResults
 from ..utils.parallelize import _base_join_futures_iter
 
 # pylint: disable=invalid-name
@@ -29,6 +34,33 @@
 class RayExecutor(EOExecutor):
     """A special type of `EOExecutor` that works with Ray framework"""
 
+    def __init__(
+        self,
+        workflow: EOWorkflow,
+        execution_kwargs: Sequence[dict[EONode, dict[str, object]]],
+        *,
+        execution_names: list[str] | None = None,
+        save_logs: bool = False,
+        logs_folder: str = ".",
+        filesystem: FS | None = None,
+        logs_filter: Filter | None = None,
+        logs_handler_factory: _HandlerFactoryType = FileHandler,
+        raise_on_temporal_mismatch: bool = False,
+        ray_remote_kwargs: dict[str, Any] | None = None,
+    ):
+        super().__init__(
+            workflow,
+            execution_kwargs,
+            execution_names=execution_names,
+            save_logs=save_logs,
+            logs_folder=logs_folder,
+            filesystem=filesystem,
+            logs_filter=logs_filter,
+            logs_handler_factory=logs_handler_factory,
+            raise_on_temporal_mismatch=raise_on_temporal_mismatch,
+        )
+        self.ray_remote_kwargs = ray_remote_kwargs
+
     def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]:  # type: ignore[override]
         """Runs the executor using a Ray cluster
@@ -43,12 +75,13 @@ def run(self, **tqdm_kwargs: Any) -> list[WorkflowResults]:  # type: ignore[override]
         workers = ray.available_resources().get("CPU")
         return super().run(workers=workers, multiprocess=True, **tqdm_kwargs)
 
-    @classmethod
     def _run_execution(
-        cls, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
+        self, processing_args: list[_ProcessingData], run_params: _ExecutionRunParams
     ) -> list[WorkflowResults]:
         """Runs ray execution"""
-        futures = [_ray_workflow_executor.remote(workflow_args) for workflow_args in processing_args]
+        remote_kwargs = self.ray_remote_kwargs or {}
+        exec_func = _ray_workflow_executor.options(**remote_kwargs)  # type: ignore[attr-defined]
+        futures = [exec_func.remote(workflow_args) for workflow_args in processing_args]
         return join_ray_futures(futures, **run_params.tqdm_kwargs)
@@ -60,7 +93,10 @@ def _ray_workflow_executor(workflow_args: _ProcessingData) -> WorkflowResults:
 
 
 def parallelize_with_ray(
-    function: Callable[[InputType], OutputType], *params: Iterable[InputType], **tqdm_kwargs: Any
+    function: Callable[[InputType], OutputType],
+    *params: Iterable[InputType],
+    ray_remote_kwargs: dict[str, Any] | None = None,
+    **tqdm_kwargs: Any,
 ) -> list[OutputType]:
     """Parallelizes function execution with Ray.
@@ -69,13 +105,15 @@ def parallelize_with_ray(
     :param function: A normal function that is not yet decorated by `ray.remote`.
     :param params: Iterables of parameters that will be used with given function.
+    :param ray_remote_kwargs: Keyword arguments passed to `ray.remote`.
     :param tqdm_kwargs: Keyword arguments that will be propagated to `tqdm` progress bar.
     :return: A list of results in the order that corresponds with the order of the given input `params`.
     """
+    ray_remote_kwargs = ray_remote_kwargs or {}
     if not ray.is_initialized():
         raise RuntimeError("Please initialize a Ray cluster before calling this method")
 
-    ray_function = ray.remote(function)
+    ray_function = ray.remote(function, **ray_remote_kwargs)
     futures = [ray_function.remote(*function_params) for function_params in zip(*params)]
     return join_ray_futures(futures, **tqdm_kwargs)
 
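The executor-side change relies on Ray's standard `.options()` call, which returns a copy of a remote function with additional scheduling requirements. A small, self-contained sketch of that mechanism using only the public Ray API (independent of eo-learn; the task body and values are illustrative):

```python
import ray

# Advertise a custom resource on the local cluster, as the updated test fixture does.
ray.init(resources={"resourceA": 1})


@ray.remote
def process(value: int) -> int:
    """A stand-in for the per-workflow task submitted by RayExecutor."""
    return value * value


# .options(**kwargs) attaches extra scheduling requirements without re-declaring the
# function; this mirrors `_ray_workflow_executor.options(**remote_kwargs)` above.
constrained = process.options(num_cpus=1, resources={"resourceA": 0.5})
futures = [constrained.remote(i) for i in range(4)]
print(ray.get(futures))  # [0, 1, 4, 9]
ray.shutdown()
```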
3 changes: 2 additions & 1 deletion tests/core/test_extra/test_ray.py
@@ -56,7 +56,7 @@ def filter(self, record):
 
 @pytest.fixture(name="_simple_cluster", scope="module")
 def _simple_cluster_fixture():
-    ray.init(log_to_driver=False)
+    ray.init(log_to_driver=False, resources={"resourceA": 1})
     yield
     ray.shutdown()
 
@@ -103,6 +103,7 @@ def test_read_logs(filter_logs, execution_names, workflow, execution_kwargs):
             logs_folder=tmp_dir_name,
             logs_filter=CustomLogFilter() if filter_logs else None,
             execution_names=execution_names,
+            ray_remote_kwargs={"resources": {"resourceA": 0.5}},
         )
         executor.run()
 
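For the function-level helper, the new keyword works the same way: per the docstring added above, `ray_remote_kwargs` is handed to `ray.remote` when the plain function is wrapped. A hedged usage sketch; the custom resource name `resourceA` mirrors the updated test fixture, and `double` is an illustrative function.

```python
import ray

from eolearn.core.extra.ray import parallelize_with_ray

ray.init(resources={"resourceA": 1})


def double(value: int) -> int:
    # A plain function; parallelize_with_ray wraps it with `ray.remote` itself.
    return 2 * value


# New in this commit: per the docstring, these kwargs are passed to `ray.remote`;
# requesting 0.5 of "resourceA" here would cap concurrency at two tasks.
results = parallelize_with_ray(double, range(6), ray_remote_kwargs={"resources": {"resourceA": 0.5}})
print(results)  # [0, 2, 4, 6, 8, 10]
ray.shutdown()
```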
