From 533d50c396ad2eb611193f099560f70835c9c7e6 Mon Sep 17 00:00:00 2001 From: Michael Carlstrom Date: Wed, 11 Sep 2024 12:33:40 -0400 Subject: [PATCH] Executors types (#1345) Signed-off-by: Michael Carlstrom Co-authored-by: Tomoya Fujita --- rclpy/rclpy/callback_groups.py | 3 +- rclpy/rclpy/executors.py | 181 +++++++++++++++++++----------- rclpy/rclpy/task.py | 26 +++-- rclpy/rclpy/timer.py | 2 +- rclpy/src/rclpy/action_client.hpp | 2 +- rclpy/src/rclpy/action_server.cpp | 2 +- rclpy/src/rclpy/action_server.hpp | 2 +- 7 files changed, 140 insertions(+), 78 deletions(-) diff --git a/rclpy/rclpy/callback_groups.py b/rclpy/rclpy/callback_groups.py index bee08d611..f0d964a56 100644 --- a/rclpy/rclpy/callback_groups.py +++ b/rclpy/rclpy/callback_groups.py @@ -24,7 +24,8 @@ from rclpy.service import Service from rclpy.waitable import Waitable from rclpy.guard_condition import GuardCondition - Entity = Union[Subscription, Timer, Client, Service, Waitable[Any], GuardCondition] + Entity = Union[Subscription[Any], Timer, Client[Any, Any], Service[Any, Any], + GuardCondition, Waitable[Any]] class CallbackGroup: diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 17169063d..8bf5a1a3b 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -24,8 +24,10 @@ from types import TracebackType from typing import Any from typing import Callable +from typing import cast from typing import ContextManager from typing import Coroutine +from typing import Dict from typing import Generator from typing import List from typing import Optional @@ -58,33 +60,48 @@ # For documentation purposes # TODO(jacobperron): Make all entities implement the 'Waitable' interface for better type checking -WaitableEntityType = TypeVar('WaitableEntityType') + +T = TypeVar('T') # Avoid import cycle if TYPE_CHECKING: + from typing import TypeAlias + from rclpy.node import Node # noqa: F401 + from .callback_groups import Entity + EntityT = TypeVar('EntityT', bound=Entity) + + +FunctionOrCoroutineFunction: 'TypeAlias' = Union[Callable[..., T], + Callable[..., Coroutine[None, None, T]]] + + +YieldedCallback: 'TypeAlias' = Generator[Tuple[Task[None], + 'Optional[Entity]', + 'Optional[Node]'], None, None] class _WorkTracker: """Track the amount of work that is in progress.""" - def __init__(self): + def __init__(self) -> None: # Number of tasks that are being executed self._num_work_executing = 0 self._work_condition = Condition() - def __enter__(self): + def __enter__(self) -> None: """Increment the amount of executing work by 1.""" with self._work_condition: self._num_work_executing += 1 - def __exit__(self, t, v, tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], exctb: Optional[TracebackType]) -> None: """Decrement the amount of work executing by 1.""" with self._work_condition: self._num_work_executing -= 1 self._work_condition.notify_all() - def wait(self, timeout_sec: Optional[float] = None): + def wait(self, timeout_sec: Optional[float] = None) -> bool: """ Wait until all work completes. @@ -102,12 +119,14 @@ def wait(self, timeout_sec: Optional[float] = None): return True -async def await_or_execute(callback: Union[Callable, Coroutine], *args) -> Any: +async def await_or_execute(callback: FunctionOrCoroutineFunction[T], *args: Any) -> T: """Await a callback if it is a coroutine, else execute it.""" if inspect.iscoroutinefunction(callback): # Await a coroutine + callback = cast(Callable[..., Coroutine[None, None, T]], callback) return await callback(*args) else: + callback = cast(Callable[..., T], callback) # Call a normal function return callback(*args) @@ -139,15 +158,15 @@ class ConditionReachedException(Exception): class TimeoutObject: """Use timeout object to save timeout.""" - def __init__(self, timeout: float): + def __init__(self, timeout: float) -> None: self._timeout = timeout @property - def timeout(self): + def timeout(self) -> float: return self._timeout @timeout.setter - def timeout(self, timeout): + def timeout(self, timeout: float) -> None: self._timeout = timeout @@ -181,10 +200,10 @@ def __init__(self, *, context: Optional[Context] = None) -> None: self._nodes: Set[Node] = set() self._nodes_lock = RLock() # Tasks to be executed (oldest first) 3-tuple Task, Entity, Node - self._tasks: List[Tuple[Task, Optional[WaitableEntityType], Optional[Node]]] = [] + self._tasks: List[Tuple[Task[Any], 'Optional[Entity]', Optional[Node]]] = [] self._tasks_lock = Lock() # This is triggered when wait_for_ready_callbacks should rebuild the wait list - self._guard = GuardCondition( + self._guard: Optional[GuardCondition] = GuardCondition( callback=None, callback_group=None, context=self._context) # True if shutdown has been called self._is_shutdown = False @@ -192,12 +211,13 @@ def __init__(self, *, context: Optional[Context] = None) -> None: # Protect against shutdown() being called in parallel in two threads self._shutdown_lock = Lock() # State for wait_for_ready_callbacks to reuse generator - self._cb_iter = None - self._last_args = None - self._last_kwargs = None + self._cb_iter: Optional[YieldedCallback] = None + self._last_args: Optional[tuple[object, ...]] = None + self._last_kwargs: Optional[Dict[str, object]] = None # Executor cannot use ROS clock because that requires a node self._clock = Clock(clock_type=ClockType.STEADY_TIME) - self._sigint_gc = SignalHandlerGuardCondition(context) + self._sigint_gc: Optional[SignalHandlerGuardCondition] = \ + SignalHandlerGuardCondition(context) self._context.on_shutdown(self.wake) @property @@ -205,7 +225,8 @@ def context(self) -> Context: """Get the context associated with the executor.""" return self._context - def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> Task: + def create_task(self, callback: FunctionOrCoroutineFunction[T], *args: Any, **kwargs: Any + ) -> Task[T]: """ Add a callback or coroutine to be executed during :meth:`spin` and return a Future. @@ -219,7 +240,8 @@ def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> task = Task(callback, args, kwargs, executor=self) with self._tasks_lock: self._tasks.append((task, None, None)) - self._guard.trigger() + if self._guard: + self._guard.trigger() # Task inherits from Future return task @@ -236,7 +258,8 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool: if not self._is_shutdown: self._is_shutdown = True # Tell executor it's been shut down - self._guard.trigger() + if self._guard: + self._guard.trigger() if not self._is_shutdown: if not self._work_tracker.wait(timeout_sec): return False @@ -257,7 +280,7 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool: self._last_kwargs = None return True - def __del__(self): + def __del__(self) -> None: if self._sigint_gc is not None: self._sigint_gc.destroy() @@ -273,7 +296,8 @@ def add_node(self, node: 'Node') -> bool: self._nodes.add(node) node.executor = self # Rebuild the wait set so it includes this new node - self._guard.trigger() + if self._guard: + self._guard.trigger() return True return False @@ -290,7 +314,8 @@ def remove_node(self, node: 'Node') -> None: pass else: # Rebuild the wait set so it doesn't include this node - self._guard.trigger() + if self._guard: + self._guard.trigger() def wake(self) -> None: """ @@ -313,7 +338,7 @@ def spin(self) -> None: def spin_until_future_complete( self, - future: Future, + future: Future[Any], timeout_sec: Optional[float] = None ) -> None: """Execute callbacks until a given future is done or a timeout occurs.""" @@ -352,7 +377,7 @@ def spin_once(self, timeout_sec: Optional[float] = None) -> None: def spin_once_until_future_complete( self, - future: Future, + future: Future[Any], timeout_sec: Optional[Union[float, TimeoutObject]] = None ) -> None: """ @@ -367,7 +392,7 @@ def spin_once_until_future_complete( """ raise NotImplementedError() - def _take_timer(self, tmr): + def _take_timer(self, tmr: Timer) -> Optional[Callable[[], Coroutine[None, None, None]]]: try: with tmr.handle: info = tmr.handle.call_timer_with_info() @@ -376,7 +401,9 @@ def _take_timer(self, tmr): actual_call_time=info['actual_call_time'], clock_type=tmr.clock.clock_type) - def check_argument_type(callback_func, target_type): + def check_argument_type(callback_func: Union[Callable[[], None], + Callable[[TimerInfo], None]], + target_type: Type[TimerInfo]) -> Optional[str]: sig = inspect.signature(callback_func) for param in sig.parameters.values(): if param.annotation == target_type: @@ -387,15 +414,19 @@ def check_argument_type(callback_func, target_type): # User might change the Timer.callback function signature at runtime, # so it needs to check the signature every time. - arg_name = check_argument_type(tmr.callback, target_type=TimerInfo) - prefilled_arg = {arg_name: timer_info} + if tmr.callback: + arg_name = check_argument_type(tmr.callback, target_type=TimerInfo) if arg_name is not None: - async def _execute(): - await await_or_execute(partial(tmr.callback, **prefilled_arg)) + prefilled_arg = {arg_name: timer_info} + + async def _execute() -> None: + if tmr.callback: + await await_or_execute(partial(tmr.callback, **prefilled_arg)) return _execute else: - async def _execute(): - await await_or_execute(tmr.callback) + async def _execute() -> None: + if tmr.callback: + await await_or_execute(tmr.callback) return _execute except InvalidHandle: # Timer is a Destroyable, which means that on __enter__ it can throw an @@ -406,7 +437,8 @@ async def _execute(): return None - def _take_subscription(self, sub): + def _take_subscription(self, sub: Subscription[Any] + ) -> Optional[Callable[[], Coroutine[None, None, None]]]: try: with sub.handle: msg_info = sub.handle.take_message(sub.msg_type, sub.raw) @@ -418,7 +450,7 @@ def _take_subscription(self, sub): else: msg_tuple = msg_info - async def _execute(): + async def _execute() -> None: await await_or_execute(sub.callback, *msg_tuple) return _execute @@ -431,12 +463,13 @@ async def _execute(): return None - def _take_client(self, client): + def _take_client(self, client: Client[Any, Any] + ) -> Optional[Callable[[], Coroutine[None, None, None]]]: try: with client.handle: header_and_response = client.handle.take_response(client.srv_type.Response) - async def _execute(): + async def _execute() -> None: header, response = header_and_response if header is None: return @@ -460,12 +493,13 @@ async def _execute(): return None - def _take_service(self, srv): + def _take_service(self, srv: Service[Any, Any] + ) -> Optional[Callable[[], Coroutine[None, None, None]]]: try: with srv.handle: request_and_header = srv.handle.service_take_request(srv.srv_type.Request) - async def _execute(): + async def _execute() -> None: (request, header) = request_and_header if header is None: return @@ -482,17 +516,19 @@ async def _execute(): return None - def _take_guard_condition(self, gc): + def _take_guard_condition(self, gc: GuardCondition + ) -> Callable[[], Coroutine[None, None, None]]: gc._executor_triggered = False - async def _execute(): - await await_or_execute(gc.callback) + async def _execute() -> None: + if gc.callback: + await await_or_execute(gc.callback) return _execute - def _take_waitable(self, waitable): + def _take_waitable(self, waitable: Waitable[Any]) -> Callable[[], Coroutine[None, None, None]]: data = waitable.take_data() - async def _execute(): + async def _execute() -> None: for future in waitable._futures: future._set_executor(self) await waitable.execute(data) @@ -500,10 +536,11 @@ async def _execute(): def _make_handler( self, - entity: WaitableEntityType, + entity: 'EntityT', node: 'Node', - take_from_wait_list: Callable, - ) -> Task: + take_from_wait_list: Callable[['EntityT'], + Optional[Callable[[], Coroutine[None, None, None]]]], + ) -> Task[None]: """ Make a handler that performs work on an entity. @@ -514,8 +551,10 @@ def _make_handler( # Mark this so it doesn't get added back to the wait list entity._executor_event = True - async def handler(entity, gc, is_shutdown, work_tracker): - if is_shutdown or not entity.callback_group.beginning_execution(entity): + async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool, + work_tracker: _WorkTracker) -> None: + if is_shutdown or entity.callback_group is not None and \ + not entity.callback_group.beginning_execution(entity): # Didn't get the callback, or the executor has been ordered to stop entity._executor_event = False gc.trigger() @@ -533,7 +572,8 @@ async def handler(entity, gc, is_shutdown, work_tracker): if call_coroutine is not None: await call_coroutine() finally: - entity.callback_group.ending_execution(entity) + if entity.callback_group: + entity.callback_group.ending_execution(entity) # Signal that work has been done so the next callback in a mutually exclusive # callback group can get executed @@ -550,21 +590,22 @@ async def handler(entity, gc, is_shutdown, work_tracker): self._tasks.append((task, entity, node)) return task - def can_execute(self, entity: WaitableEntityType) -> bool: + def can_execute(self, entity: 'Entity') -> bool: """ Determine if a callback for an entity can be executed. :param entity: Subscription, Timer, Guard condition, etc :returns: ``True`` if the entity callback can be executed, ``False`` otherwise. """ - return not entity._executor_event and entity.callback_group.can_execute(entity) + return not entity._executor_event and entity.callback_group is not None \ + and entity.callback_group.can_execute(entity) def _wait_for_ready_callbacks( self, timeout_sec: Optional[Union[float, TimeoutObject]] = None, nodes: Optional[List['Node']] = None, condition: Callable[[], bool] = lambda: False, - ) -> Generator[Tuple[Task, WaitableEntityType, 'Node'], None, None]: + ) -> YieldedCallback: """ Yield callbacks that are ready to be executed. @@ -587,7 +628,7 @@ def _wait_for_ready_callbacks( while not yielded_work and not self._is_shutdown and not condition(): # Refresh "all" nodes in case executor was woken by a node being added or removed nodes_to_use = nodes - if nodes is None: + if nodes_to_use is None: nodes_to_use = self.get_nodes() # Yield tasks in-progress before waiting for new work @@ -605,11 +646,11 @@ def _wait_for_ready_callbacks( self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) # Gather entities that can be waited on - subscriptions: List[Subscription] = [] + subscriptions: List[Subscription[Any, ]] = [] guards: List[GuardCondition] = [] timers: List[Timer] = [] - clients: List[Client] = [] - services: List[Service] = [] + clients: List[Client[Any, Any]] = [] + services: List[Service[Any, Any]] = [] waitables: List[Waitable[Any]] = [] for node in nodes_to_use: subscriptions.extend(filter(self.can_execute, node.subscriptions)) @@ -626,8 +667,10 @@ def _wait_for_ready_callbacks( if timeout_timer is not None: timers.append(timeout_timer) - guards.append(self._guard) - guards.append(self._sigint_gc) + if self._guard: + guards.append(self._guard) + if self._sigint_gc: + guards.append(self._sigint_gc) entity_count = NumberOfEntities( len(subscriptions), len(guards), len(timers), len(clients), len(services)) @@ -682,6 +725,9 @@ def _wait_for_ready_callbacks( except InvalidHandle: pass + if self._context.handle is None: + raise RuntimeError('Cannot enter context if context is None') + context_stack.enter_context(self._context.handle) wait_set = _rclpy.WaitSet( @@ -742,7 +788,7 @@ def _wait_for_ready_callbacks( if tmr.handle.pointer in timers_ready: # Check timer is ready to workaround rcl issue with cancelled timers if tmr.handle.is_timer_ready(): - if tmr.callback_group.can_execute(tmr): + if tmr.callback_group and tmr.callback_group.can_execute(tmr): handler = self._make_handler(tmr, node, self._take_timer) yielded_work = True yield handler, tmr, node @@ -756,7 +802,7 @@ def _wait_for_ready_callbacks( for gc in node.guards: if gc._executor_triggered: - if gc.callback_group.can_execute(gc): + if gc.callback_group and gc.callback_group.can_execute(gc): handler = self._make_handler(gc, node, self._take_guard_condition) yielded_work = True yield handler, gc, node @@ -786,7 +832,9 @@ def _wait_for_ready_callbacks( if condition(): raise ConditionReachedException() - def wait_for_ready_callbacks(self, *args, **kwargs) -> Tuple[Task, WaitableEntityType, 'Node']: + def wait_for_ready_callbacks(self, *args: Any, **kwargs: Any) -> Tuple[Task[None], + 'Optional[Entity]', + 'Optional[Node]']: """ Return callbacks that are ready to be executed. @@ -844,8 +892,9 @@ def _spin_once_impl( pass else: handler() - if handler.exception() is not None: - raise handler.exception() + exception = handler.exception() + if exception is not None: + raise exception handler.result() # raise any exceptions @@ -854,7 +903,7 @@ def spin_once(self, timeout_sec: Optional[float] = None) -> None: def spin_once_until_future_complete( self, - future: Future, + future: Future[Any], timeout_sec: Optional[Union[float, TimeoutObject]] = None ) -> None: future.add_done_callback(lambda x: self.wake()) @@ -892,7 +941,7 @@ def __init__( warnings.warn( 'MultiThreadedExecutor is used with a single thread.\n' 'Use the SingleThreadedExecutor instead.') - self._futures = [] + self._futures: List[Future[Any]] = [] self._executor = ThreadPoolExecutor(num_threads) def _spin_once_impl( @@ -926,7 +975,7 @@ def spin_once(self, timeout_sec: Optional[float] = None) -> None: def spin_once_until_future_complete( self, - future: Future, + future: Future[Any], timeout_sec: Optional[Union[float, TimeoutObject]] = None ) -> None: future.add_done_callback(lambda x: self.wake()) @@ -934,7 +983,7 @@ def spin_once_until_future_complete( def shutdown( self, - timeout_sec: float = None, + timeout_sec: Optional[float] = None, *, wait_for_threads: bool = True ) -> bool: diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index 81a56ab5b..10fae2742 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -15,16 +15,21 @@ import inspect import sys import threading -from typing import (Callable, cast, Coroutine, Dict, Generator, Generic, List, +from typing import (Callable, cast, Coroutine, Dict, Generator, Generic, Iterable, List, Optional, TYPE_CHECKING, TypeVar, Union) import warnings import weakref if TYPE_CHECKING: + from typing import TypeAlias + from rclpy.executors import Executor T = TypeVar('T') +FunctionOrCoroutineFunction: 'TypeAlias' = Union[Callable[[], T], + Callable[..., Coroutine[None, None, T]]] + def _fake_weakref() -> None: """Return None when called to simulate a weak reference that has been garbage collected.""" @@ -207,13 +212,11 @@ class Task(Future[T]): """ def __init__(self, - handler: Union[Callable[[], T], Coroutine[None, None, T], None], - args: Optional[List[object]] = None, + handler: FunctionOrCoroutineFunction[T], + args: Optional[Iterable[object]] = None, kwargs: Optional[Dict[str, object]] = None, executor: Optional['Executor'] = None) -> None: super().__init__(executor=executor) - # _handler is either a normal function or a coroutine - self._handler = handler # Arguments passed into the function if args is None: args = [] @@ -221,10 +224,19 @@ def __init__(self, if kwargs is None: kwargs = {} self._kwargs: Optional[Dict[str, object]] = kwargs + + # _handler is either a normal function or a coroutine if inspect.iscoroutinefunction(handler): - self._handler = handler(*args, **kwargs) + self._handler: Union[ + Coroutine[None, None, T], + Callable[[], T], + None + ] = handler(*args, **kwargs) self._args = None self._kwargs = None + else: + handler = cast(Callable[[], T], handler) + self._handler = handler # True while the task is being executed self._executing = False # Lock acquired to prevent task from executing in parallel with itself @@ -248,7 +260,7 @@ def __call__(self) -> None: if inspect.iscoroutine(self._handler): # Execute a coroutine - handler = cast(Coroutine[None, None, T], self._handler) + handler = self._handler try: handler.send(None) except StopIteration as e: diff --git a/rclpy/rclpy/timer.py b/rclpy/rclpy/timer.py index c5b577053..706234aff 100644 --- a/rclpy/rclpy/timer.py +++ b/rclpy/rclpy/timer.py @@ -67,7 +67,7 @@ class Timer: def __init__( self, callback: Union[Callable[[], None], Callable[[TimerInfo], None], None], - callback_group: CallbackGroup, + callback_group: Optional[CallbackGroup], timer_period_ns: int, clock: Clock, *, diff --git a/rclpy/src/rclpy/action_client.hpp b/rclpy/src/rclpy/action_client.hpp index 5dcf04b90..49d616e4c 100644 --- a/rclpy/src/rclpy/action_client.hpp +++ b/rclpy/src/rclpy/action_client.hpp @@ -207,7 +207,7 @@ class ActionClient : public Destroyable, public std::enable_shared_from_this