Skip to content

Commit

Permalink
Add types to waitable.py (#1328)
Browse files Browse the repository at this point in the history
* add types

Signed-off-by: Michael Carlstrom <[email protected]>

* move typing into string

Signed-off-by: Michael Carlstrom <[email protected]>

* move Future type into string

Signed-off-by: Michael Carlstrom <[email protected]>

* flake8 fixes

Signed-off-by: Michael Carlstrom <[email protected]>

* move typedicts to outside TYPE_CHECKING

Signed-off-by: Michael Carlstrom <[email protected]>

* rerun stuck ci

Signed-off-by: Michael Carlstrom <[email protected]>

* undo accidental removal

Signed-off-by: Michael Carlstrom <[email protected]>

* add functions

Signed-off-by: Michael Carlstrom <[email protected]>

---------

Signed-off-by: Michael Carlstrom <[email protected]>
Co-authored-by: Shane Loretz <[email protected]>
  • Loading branch information
InvincibleRMC and sloretz authored Aug 21, 2024
1 parent 1eb4208 commit 7e3005a
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 37 deletions.
18 changes: 14 additions & 4 deletions rclpy/rclpy/action/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import threading
import time
from typing import Any
from typing import TypedDict
import uuid
import weakref

Expand All @@ -32,6 +34,14 @@
from unique_identifier_msgs.msg import UUID


class ClientGoalHandleDict(TypedDict, total=False):
goal: Any
cancel: Any
result: Any
feedback: Any
status: Any


class ClientGoalHandle():
"""Goal handle for working with Action Clients."""

Expand Down Expand Up @@ -108,7 +118,7 @@ def get_result_async(self):
return self._action_client._get_result_async(self)


class ActionClient(Waitable):
class ActionClient(Waitable[ClientGoalHandleDict]):
"""ROS Action client."""

def __init__(
Expand Down Expand Up @@ -237,9 +247,9 @@ def is_ready(self, wait_set):
self._is_result_response_ready = ready_entities[4]
return any(ready_entities)

def take_data(self):
def take_data(self) -> ClientGoalHandleDict:
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
data = {}
data: ClientGoalHandleDict = {}
if self._is_goal_response_ready:
taken_data = self._client_handle.take_goal_response(
self._action_type.Impl.SendGoalService.Response)
Expand Down Expand Up @@ -277,7 +287,7 @@ def take_data(self):

return data

async def execute(self, taken_data):
async def execute(self, taken_data: ClientGoalHandleDict) -> None:
"""
Execute work after data has been taken from a ready wait set.
Expand Down
17 changes: 13 additions & 4 deletions rclpy/rclpy/action/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import threading
import traceback

from typing import Any, TypedDict

from action_msgs.msg import GoalInfo, GoalStatus

from rclpy.executors import await_or_execute
Expand Down Expand Up @@ -49,6 +51,13 @@ class CancelResponse(Enum):
GoalEvent = _rclpy.GoalEvent


class ServerGoalHandleDict(TypedDict, total=False):
goal: Any
cancel: Any
result: Any
expired: Any


class ServerGoalHandle:
"""Goal handle for working with Action Servers."""

Expand Down Expand Up @@ -178,7 +187,7 @@ def default_cancel_callback(cancel_request):
return CancelResponse.REJECT


class ActionServer(Waitable):
class ActionServer(Waitable[ServerGoalHandleDict]):
"""ROS Action server."""

def __init__(
Expand Down Expand Up @@ -446,9 +455,9 @@ def is_ready(self, wait_set):
self._is_goal_expired = ready_entities[3]
return any(ready_entities)

def take_data(self):
def take_data(self) -> ServerGoalHandleDict:
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
data = {}
data: ServerGoalHandleDict = {}
if self._is_goal_request_ready:
with self._lock:
taken_data = self._handle.take_goal_request(
Expand Down Expand Up @@ -482,7 +491,7 @@ def take_data(self):

return data

async def execute(self, taken_data):
async def execute(self, taken_data: ServerGoalHandleDict) -> None:
"""
Execute work after data has been taken from a ready wait set.
Expand Down
4 changes: 2 additions & 2 deletions rclpy/rclpy/callback_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from threading import Lock
from typing import Literal, Optional, TYPE_CHECKING, Union
from typing import Any, Literal, Optional, TYPE_CHECKING, Union
import weakref


Expand All @@ -23,7 +23,7 @@
from rclpy.client import Client
from rclpy.service import Service
from rclpy.waitable import Waitable
Entity = Union[Subscription, Timer, Client, Service, Waitable]
Entity = Union[Subscription, Timer, Client, Service, Waitable[Any]]


class CallbackGroup:
Expand Down
13 changes: 10 additions & 3 deletions rclpy/rclpy/event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from enum import IntEnum
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
Expand All @@ -27,6 +28,9 @@
from rclpy.waitable import NumberOfEntities
from rclpy.waitable import Waitable

if TYPE_CHECKING:
from typing import TypeAlias


if TYPE_CHECKING:
from rclpy.subscription import SubscriptionHandle
Expand Down Expand Up @@ -75,7 +79,10 @@
UnsupportedEventTypeError = _rclpy.UnsupportedEventTypeError


class EventHandler(Waitable):
EventHandlerData: 'TypeAlias' = Optional[Any]


class EventHandler(Waitable[EventHandlerData]):
"""Waitable type to handle QoS events."""

def __init__(
Expand Down Expand Up @@ -106,15 +113,15 @@ def is_ready(self, wait_set):
self._ready_to_take_data = True
return self._ready_to_take_data

def take_data(self):
def take_data(self) -> EventHandlerData:
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
if self._ready_to_take_data:
self._ready_to_take_data = False
with self.__event:
return self.__event.take_event()
return None

async def execute(self, taken_data):
async def execute(self, taken_data: EventHandlerData) -> None:
"""Execute work after data has been taken from a ready wait set."""
if not taken_data:
return
Expand Down
2 changes: 1 addition & 1 deletion rclpy/rclpy/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def _wait_for_ready_callbacks(
timers: List[Timer] = []
clients: List[Client] = []
services: List[Service] = []
waitables: List[Waitable] = []
waitables: List[Waitable[Any]] = []
for node in nodes_to_use:
subscriptions.extend(filter(self.can_execute, node.subscriptions))
timers.extend(filter(self.can_execute, node.timers))
Expand Down
9 changes: 5 additions & 4 deletions rclpy/rclpy/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import time

from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
Expand Down Expand Up @@ -181,7 +182,7 @@ def __init__(
self._services: List[Service] = []
self._timers: List[Timer] = []
self._guards: List[GuardCondition] = []
self.__waitables: List[Waitable] = []
self.__waitables: List[Waitable[Any]] = []
self._default_callback_group = MutuallyExclusiveCallbackGroup()
self._pre_set_parameters_callbacks: List[Callable[[List[Parameter]], List[Parameter]]] = []
self._on_set_parameters_callbacks: \
Expand Down Expand Up @@ -290,7 +291,7 @@ def guards(self) -> Iterator[GuardCondition]:
yield from self._guards

@property
def waitables(self) -> Iterator[Waitable]:
def waitables(self) -> Iterator[Waitable[Any]]:
"""Get waitables that have been created on this node."""
yield from self.__waitables

Expand Down Expand Up @@ -1485,7 +1486,7 @@ def _validate_qos_or_depth_parameter(self, qos_or_depth) -> QoSProfile:
raise TypeError(
'Expected QoSProfile or int, but received {!r}'.format(type(qos_or_depth)))

def add_waitable(self, waitable: Waitable) -> None:
def add_waitable(self, waitable: Waitable[Any]) -> None:
"""
Add a class that is capable of adding things to the wait set.
Expand All @@ -1494,7 +1495,7 @@ def add_waitable(self, waitable: Waitable) -> None:
self.__waitables.append(waitable)
self._wake_executor()

def remove_waitable(self, waitable: Waitable) -> None:
def remove_waitable(self, waitable: Waitable[Any]) -> None:
"""
Remove a Waitable that was previously added to the node.
Expand Down
58 changes: 39 additions & 19 deletions rclpy/rclpy/waitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from types import TracebackType
from typing import Any, Generic, List, Optional, Type, TYPE_CHECKING, TypeVar


from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy

T = TypeVar('T')


if TYPE_CHECKING:
from typing_extensions import Self

from rclpy.callback_groups import CallbackGroup
from rclpy.task import Future


class NumberOfEntities:

Expand All @@ -24,8 +39,8 @@ class NumberOfEntities:
'num_events']

def __init__(
self, num_subs=0, num_gcs=0, num_timers=0,
num_clients=0, num_services=0, num_events=0
self, num_subs: int = 0, num_gcs: int = 0, num_timers: int = 0,
num_clients: int = 0, num_services: int = 0, num_events: int = 0
):
self.num_subscriptions = num_subs
self.num_guard_conditions = num_gcs
Expand All @@ -34,7 +49,7 @@ def __init__(
self.num_services = num_services
self.num_events = num_events

def __add__(self, other):
def __add__(self, other: 'NumberOfEntities') -> 'NumberOfEntities':
result = self.__class__()
result.num_subscriptions = self.num_subscriptions + other.num_subscriptions
result.num_guard_conditions = self.num_guard_conditions + other.num_guard_conditions
Expand All @@ -44,7 +59,7 @@ def __add__(self, other):
result.num_events = self.num_events + other.num_events
return result

def __iadd__(self, other):
def __iadd__(self, other: 'NumberOfEntities') -> 'NumberOfEntities':
self.num_subscriptions += other.num_subscriptions
self.num_guard_conditions += other.num_guard_conditions
self.num_timers += other.num_timers
Expand All @@ -53,59 +68,64 @@ def __iadd__(self, other):
self.num_events += other.num_events
return self

def __repr__(self):
def __repr__(self) -> str:
return '<{0}({1}, {2}, {3}, {4}, {5}, {6})>'.format(
self.__class__.__name__, self.num_subscriptions,
self.num_guard_conditions, self.num_timers, self.num_clients,
self.num_services, self.num_events)


class Waitable:
class Waitable(Generic[T]):
"""
Add something to a wait set and execute it.
This class wraps a collection of entities which can be added to a wait set.
"""

def __init__(self, callback_group):
def __init__(self, callback_group: 'CallbackGroup'):
# A callback group to control when this entity can execute (used by Executor)
self.callback_group = callback_group
self.callback_group.add_entity(self)
# Flag set by executor when a handler has been created but not executed (used by Executor)
self._executor_event = False
# List of Futures that have callbacks needing execution
self._futures = []
self._futures: List[Future[Any]] = []

def __enter__(self):
def __enter__(self) -> 'Self':
"""Implement to mark entities as in-use to prevent destruction while waiting on them."""
pass
raise NotImplementedError('Must be implemented by subclass')

def __exit__(self, t, v, tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Implement to mark entities as not-in-use to allow destruction after waiting on them."""
pass
raise NotImplementedError('Must be implemented by subclass')

def add_future(self, future):
def add_future(self, future: 'Future[Any]') -> None:
self._futures.append(future)

def remove_future(self, future):
def remove_future(self, future: 'Future[Any]') -> None:
self._futures.remove(future)

def is_ready(self, wait_set):
def is_ready(self, wait_set: _rclpy.WaitSet) -> bool:
"""Return True if entities are ready in the wait set."""
raise NotImplementedError('Must be implemented by subclass')

def take_data(self):
def take_data(self) -> T:
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
raise NotImplementedError('Must be implemented by subclass')

async def execute(self, taken_data):
async def execute(self, taken_data: T) -> None:
"""Execute work after data has been taken from a ready wait set."""
raise NotImplementedError('Must be implemented by subclass')

def get_num_entities(self):
def get_num_entities(self) -> NumberOfEntities:
"""Return number of each type of entity used."""
raise NotImplementedError('Must be implemented by subclass')

def add_to_wait_set(self, wait_set):
def add_to_wait_set(self, wait_set: _rclpy.WaitSet) -> None:
"""Add entities to wait set."""
raise NotImplementedError('Must be implemented by subclass')
6 changes: 6 additions & 0 deletions rclpy/test/test_create_while_spinning.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ class DummyWaitable(Waitable):
def __init__(self):
super().__init__(ReentrantCallbackGroup())

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass

def is_ready(self, wait_set):
return False

Expand Down
Loading

0 comments on commit 7e3005a

Please sign in to comment.