Skip to content

Commit

Permalink
Add types to callback_groups.py (#1251)
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Carlstrom <[email protected]>
Co-authored-by: Chris Lalancette <[email protected]>
  • Loading branch information
InvincibleRMC and clalancette authored Jul 31, 2024
1 parent dc72d8c commit a7aea8e
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions rclpy/rclpy/callback_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,19 @@
# limitations under the License.

from threading import Lock
from typing import Literal, Optional, TYPE_CHECKING, Union
import weakref


if TYPE_CHECKING:
from rclpy.subscription import Subscription
from rclpy.timer import Timer
from rclpy.client import Client
from rclpy.service import Service
from rclpy.waitable import Waitable
Entity = Union[Subscription, Timer, Client, Service, Waitable]


class CallbackGroup:
"""
The base class for a callback group.
Expand All @@ -29,25 +39,25 @@ class CallbackGroup:

def __init__(self) -> None:
super().__init__()
self.entities: set = set()
self.entities: set[weakref.ReferenceType['Entity']] = set()

def add_entity(self, entity) -> None:
def add_entity(self, entity: 'Entity') -> None:
"""
Add an entity to the callback group.
:param entity: a subscription, timer, client, service, or waitable instance.
"""
self.entities.add(weakref.ref(entity))

def has_entity(self, entity) -> bool:
def has_entity(self, entity: 'Entity') -> bool:
"""
Determine if an entity has been added to this group.
:param entity: a subscription, timer, client, service, or waitable instance.
"""
return weakref.ref(entity) in self.entities

def can_execute(self, entity) -> bool:
def can_execute(self, entity: 'Entity') -> bool:
"""
Determine if an entity can be executed.
Expand All @@ -56,7 +66,7 @@ def can_execute(self, entity) -> bool:
"""
raise NotImplementedError()

def beginning_execution(self, entity) -> bool:
def beginning_execution(self, entity: 'Entity') -> bool:
"""
Get permission for the callback from the group to begin executing an entity.
Expand All @@ -68,7 +78,7 @@ def beginning_execution(self, entity) -> bool:
"""
raise NotImplementedError()

def ending_execution(self, entity) -> None:
def ending_execution(self, entity: 'Entity') -> None:
"""
Notify group that a callback has finished executing.
Expand All @@ -80,38 +90,38 @@ def ending_execution(self, entity) -> None:
class ReentrantCallbackGroup(CallbackGroup):
"""Allow callbacks to be executed in parallel without restriction."""

def can_execute(self, entity):
def can_execute(self, entity: 'Entity') -> Literal[True]:
return True

def beginning_execution(self, entity):
def beginning_execution(self, entity: 'Entity') -> Literal[True]:
return True

def ending_execution(self, entity):
def ending_execution(self, entity: 'Entity') -> None:
pass


class MutuallyExclusiveCallbackGroup(CallbackGroup):
"""Allow only one callback to be executing at a time."""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._active_entity = None
self._active_entity: Optional['Entity'] = None
self._lock = Lock()

def can_execute(self, entity):
def can_execute(self, entity: 'Entity') -> bool:
with self._lock:
assert weakref.ref(entity) in self.entities
return self._active_entity is None

def beginning_execution(self, entity):
def beginning_execution(self, entity: 'Entity') -> bool:
with self._lock:
assert weakref.ref(entity) in self.entities
if self._active_entity is None:
self._active_entity = entity
return True
return False

def ending_execution(self, entity):
def ending_execution(self, entity: 'Entity') -> None:
with self._lock:
assert self._active_entity == entity
self._active_entity = None

0 comments on commit a7aea8e

Please sign in to comment.