From a7aea8ed73f2faa0b109d3fe06a63f982bd80905 Mon Sep 17 00:00:00 2001 From: Michael Carlstrom Date: Wed, 31 Jul 2024 07:22:43 -0400 Subject: [PATCH] Add types to callback_groups.py (#1251) Signed-off-by: Michael Carlstrom Co-authored-by: Chris Lalancette --- rclpy/rclpy/callback_groups.py | 38 +++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/rclpy/rclpy/callback_groups.py b/rclpy/rclpy/callback_groups.py index ff98c2f5f..37412fea7 100644 --- a/rclpy/rclpy/callback_groups.py +++ b/rclpy/rclpy/callback_groups.py @@ -13,9 +13,19 @@ # limitations under the License. from threading import Lock +from typing import Literal, Optional, TYPE_CHECKING, Union import weakref +if TYPE_CHECKING: + from rclpy.subscription import Subscription + from rclpy.timer import Timer + from rclpy.client import Client + from rclpy.service import Service + from rclpy.waitable import Waitable + Entity = Union[Subscription, Timer, Client, Service, Waitable] + + class CallbackGroup: """ The base class for a callback group. @@ -29,9 +39,9 @@ class CallbackGroup: def __init__(self) -> None: super().__init__() - self.entities: set = set() + self.entities: set[weakref.ReferenceType['Entity']] = set() - def add_entity(self, entity) -> None: + def add_entity(self, entity: 'Entity') -> None: """ Add an entity to the callback group. @@ -39,7 +49,7 @@ def add_entity(self, entity) -> None: """ self.entities.add(weakref.ref(entity)) - def has_entity(self, entity) -> bool: + def has_entity(self, entity: 'Entity') -> bool: """ Determine if an entity has been added to this group. @@ -47,7 +57,7 @@ def has_entity(self, entity) -> bool: """ return weakref.ref(entity) in self.entities - def can_execute(self, entity) -> bool: + def can_execute(self, entity: 'Entity') -> bool: """ Determine if an entity can be executed. @@ -56,7 +66,7 @@ def can_execute(self, entity) -> bool: """ raise NotImplementedError() - def beginning_execution(self, entity) -> bool: + def beginning_execution(self, entity: 'Entity') -> bool: """ Get permission for the callback from the group to begin executing an entity. @@ -68,7 +78,7 @@ def beginning_execution(self, entity) -> bool: """ raise NotImplementedError() - def ending_execution(self, entity) -> None: + def ending_execution(self, entity: 'Entity') -> None: """ Notify group that a callback has finished executing. @@ -80,30 +90,30 @@ def ending_execution(self, entity) -> None: class ReentrantCallbackGroup(CallbackGroup): """Allow callbacks to be executed in parallel without restriction.""" - def can_execute(self, entity): + def can_execute(self, entity: 'Entity') -> Literal[True]: return True - def beginning_execution(self, entity): + def beginning_execution(self, entity: 'Entity') -> Literal[True]: return True - def ending_execution(self, entity): + def ending_execution(self, entity: 'Entity') -> None: pass class MutuallyExclusiveCallbackGroup(CallbackGroup): """Allow only one callback to be executing at a time.""" - def __init__(self): + def __init__(self) -> None: super().__init__() - self._active_entity = None + self._active_entity: Optional['Entity'] = None self._lock = Lock() - def can_execute(self, entity): + def can_execute(self, entity: 'Entity') -> bool: with self._lock: assert weakref.ref(entity) in self.entities return self._active_entity is None - def beginning_execution(self, entity): + def beginning_execution(self, entity: 'Entity') -> bool: with self._lock: assert weakref.ref(entity) in self.entities if self._active_entity is None: @@ -111,7 +121,7 @@ def beginning_execution(self, entity): return True return False - def ending_execution(self, entity): + def ending_execution(self, entity: 'Entity') -> None: with self._lock: assert self._active_entity == entity self._active_entity = None