diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index a72036d7a..db41f58b3 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -230,7 +230,7 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool: :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. Don't wait if 0. :return: ``True`` if all outstanding callbacks finished executing, or ``False`` if the - timeot expires before all outstanding work is done. + timeout expires before all outstanding work is done. """ with self._shutdown_lock: if not self._is_shutdown: @@ -931,3 +931,23 @@ def spin_once_until_future_complete( ) -> None: future.add_done_callback(lambda x: self.wake()) self._spin_once_impl(timeout_sec, future.done) + + def shutdown( + self, + timeout_sec: float = None, + *, + wait_for_threads: bool = True + ) -> bool: + """ + Stop executing callbacks and wait for their completion. + + :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. + Don't wait if 0. + :param wait_for_threads: If true, this function will block until all executor threads + have joined. + :return: ``True`` if all outstanding callbacks finished executing, or ``False`` if the + timeout expires before all outstanding work is done. + """ + success: bool = super().shutdown(timeout_sec) + self._executor.shutdown(wait=wait_for_threads) + return success diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index 223a34db2..e663d1e3c 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -168,6 +168,24 @@ def test_multi_threaded_executor_executes(self): finally: executor.shutdown() + def test_multi_threaded_executor_closes_threads(self): + self.assertIsNotNone(self.node.handle) + + def get_threads(): + return {t.name for t in threading.enumerate()} + + main_thread_name = get_threads() + # Explicitly specify 2_threads for single thread system failure + executor = MultiThreadedExecutor(context=self.context, num_threads=2) + + try: + # Give the executor a callback so at least one thread gets spun up + self.assertTrue(self.func_execution(executor)) + finally: + self.assertTrue(main_thread_name != get_threads()) + executor.shutdown(wait_for_threads=True) + self.assertTrue(main_thread_name == get_threads()) + def test_add_node_to_executor(self): self.assertIsNotNone(self.node.handle) executor = SingleThreadedExecutor(context=self.context)