Skip to content

Commit

Permalink
Add types to task.py (#1254)
Browse files Browse the repository at this point in the history
* Add types to task.py

Signed-off-by: Michael Carlstrom <[email protected]>
  • Loading branch information
InvincibleRMC authored Jul 28, 2024
1 parent 2a8f23e commit 573d9c8
Showing 1 changed file with 47 additions and 33 deletions.
80 changes: 47 additions & 33 deletions rclpy/rclpy/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,86 +15,93 @@
import inspect
import sys
import threading
from typing import (Callable, cast, Coroutine, Dict, Generator, Generic, List,
Optional, TYPE_CHECKING, TypeVar, Union)
import warnings
import weakref

if TYPE_CHECKING:
from rclpy.executors import Executor

def _fake_weakref():
T = TypeVar('T')


def _fake_weakref() -> None:
"""Return None when called to simulate a weak reference that has been garbage collected."""
return None


class Future:
class Future(Generic[T]):
"""Represent the outcome of a task in the future."""

def __init__(self, *, executor=None):
def __init__(self, *, executor: Optional['Executor'] = None) -> None:
# true if the task is done or cancelled
self._done = False
# true if the task is cancelled
self._cancelled = False
# the final return value of the handler
self._result = None
self._result: Optional[T] = None
# An exception raised by the handler when called
self._exception = None
self._exception: Optional[Exception] = None
self._exception_fetched = False
# callbacks to be scheduled after this task completes
self._callbacks = []
self._callbacks: List[Callable[['Future[T]'], None]] = []
# Lock for threadsafety
self._lock = threading.Lock()
# An executor to use when scheduling done callbacks
self._executor = None
self._executor: Optional[Union[weakref.ReferenceType['Executor'],
Callable[[], None]]] = None
self._set_executor(executor)

def __del__(self):
def __del__(self) -> None:
if self._exception is not None and not self._exception_fetched:
print(
'The following exception was never retrieved: ' + str(self._exception),
file=sys.stderr)

def __await__(self):
def __await__(self) -> Generator[None, None, Optional[T]]:
# Yield if the task is not finished
while not self._done:
yield
return self.result()

def cancel(self):
def cancel(self) -> None:
"""Request cancellation of the running task if it is not done already."""
with self._lock:
if not self._done:
self._cancelled = True
self._schedule_or_invoke_done_callbacks()

def cancelled(self):
def cancelled(self) -> bool:
"""
Indicate if the task has been cancelled.
:return: True if the task was cancelled
:rtype: bool
"""
return self._cancelled

def done(self):
def done(self) -> bool:
"""
Indicate if the task has finished executing.
:return: True if the task is finished or raised while it was executing
:rtype: bool
"""
return self._done

def result(self):
def result(self) -> Optional[T]:
"""
Get the result of a done task.
:raises: Exception if one was set during the task.
:return: The result set by the task, or None if no result was set.
"""
if self._exception:
raise self.exception()
exception = self.exception()
if exception:
raise exception
return self._result

def exception(self):
def exception(self) -> Optional[Exception]:
"""
Get an exception raised by a done task.
Expand All @@ -103,7 +110,7 @@ def exception(self):
self._exception_fetched = True
return self._exception

def set_result(self, result):
def set_result(self, result: T) -> None:
"""
Set the result returned by a task.
Expand All @@ -115,7 +122,7 @@ def set_result(self, result):
self._cancelled = False
self._schedule_or_invoke_done_callbacks()

def set_exception(self, exception):
def set_exception(self, exception: Exception) -> None:
"""
Set the exception raised by the task.
Expand All @@ -128,13 +135,14 @@ def set_exception(self, exception):
self._cancelled = False
self._schedule_or_invoke_done_callbacks()

def _schedule_or_invoke_done_callbacks(self):
def _schedule_or_invoke_done_callbacks(self) -> None:
"""
Schedule done callbacks on the executor if possible, else run them directly.
This function assumes self._lock is not held.
"""
with self._lock:
assert self._executor is not None
executor = self._executor()
callbacks = self._callbacks
self._callbacks = []
Expand All @@ -152,15 +160,15 @@ def _schedule_or_invoke_done_callbacks(self):
# Don't let exceptions be raised because there may be more callbacks to call
warnings.warn('Unhandled exception in done callback: {}'.format(e))

def _set_executor(self, executor):
def _set_executor(self, executor: Optional['Executor']) -> None:
"""Set the executor this future is associated with."""
with self._lock:
if executor is None:
self._executor = _fake_weakref
else:
self._executor = weakref.ref(executor)

def add_done_callback(self, callback):
def add_done_callback(self, callback: Callable[['Future[T]'], None]) -> None:
"""
Add a callback to be executed when the task is done.
Expand All @@ -174,6 +182,7 @@ def add_done_callback(self, callback):
invoke = False
with self._lock:
if self._done:
assert self._executor is not None
executor = self._executor()
if executor is not None:
executor.create_task(callback, self)
Expand All @@ -187,7 +196,7 @@ def add_done_callback(self, callback):
callback(self)


class Task(Future):
class Task(Future[T]):
"""
Execute a function or coroutine.
Expand All @@ -197,17 +206,21 @@ class Task(Future):
This class should only be instantiated by :class:`rclpy.executors.Executor`.
"""

def __init__(self, handler, args=None, kwargs=None, executor=None):
def __init__(self,
handler: Union[Callable[[], T], Coroutine[None, None, T], None],
args: Optional[List[object]] = None,
kwargs: Optional[Dict[str, object]] = None,
executor: Optional['Executor'] = None) -> None:
super().__init__(executor=executor)
# _handler is either a normal function or a coroutine
self._handler = handler
# Arguments passed into the function
if args is None:
args = []
self._args = args
self._args: Optional[List[object]] = args
if kwargs is None:
kwargs = {}
self._kwargs = kwargs
self._kwargs: Optional[Dict[str, object]] = kwargs
if inspect.iscoroutinefunction(handler):
self._handler = handler(*args, **kwargs)
self._args = None
Expand All @@ -217,7 +230,7 @@ def __init__(self, handler, args=None, kwargs=None, executor=None):
# Lock acquired to prevent task from executing in parallel with itself
self._task_lock = threading.Lock()

def __call__(self):
def __call__(self) -> None:
"""
Run or resume a task.
Expand All @@ -235,11 +248,12 @@ def __call__(self):

if inspect.iscoroutine(self._handler):
# Execute a coroutine
handler = cast(Coroutine[None, None, T], self._handler)
try:
self._handler.send(None)
handler.send(None)
except StopIteration as e:
# The coroutine finished; store the result
self._handler.close()
handler.close()
self.set_result(e.value)
self._complete_task()
except Exception as e:
Expand All @@ -248,6 +262,7 @@ def __call__(self):
else:
# Execute a normal function
try:
assert self._handler is not None and callable(self._handler)
self.set_result(self._handler(*self._args, **self._kwargs))
except Exception as e:
self.set_exception(e)
Expand All @@ -257,17 +272,16 @@ def __call__(self):
finally:
self._task_lock.release()

def _complete_task(self):
def _complete_task(self) -> None:
"""Cleanup after task finished."""
self._handler = None
self._args = None
self._kwargs = None

def executing(self):
def executing(self) -> bool:
"""
Check if the task is currently being executed.
:return: True if the task is currently executing.
:rtype: bool
"""
return self._executing

0 comments on commit 573d9c8

Please sign in to comment.