diff --git a/sw_utils/common.py b/sw_utils/common.py index 534b7c2..841cec8 100644 --- a/sw_utils/common.py +++ b/sw_utils/common.py @@ -7,16 +7,30 @@ class InterruptHandler: """ Tracks SIGINT and SIGTERM signals. + Usage: + with InterruptHandler() as interrupt_handler: + while not interrupt_handler.exit: + ... """ exit = False - def __init__(self) -> None: + def __enter__(self) -> 'InterruptHandler': signal.signal(signal.SIGINT, self.exit_gracefully) signal.signal(signal.SIGTERM, self.exit_gracefully) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + signal.signal(signal.SIGINT, self.exit_default) + signal.signal(signal.SIGTERM, self.exit_default) - # noinspection PyUnusedLocal def exit_gracefully(self, signum: int, *args, **kwargs) -> None: # pylint: disable=unused-argument + if self.exit: + raise KeyboardInterrupt logger.info('Received interrupt signal %s, exiting...', signum) self.exit = True + + def exit_default(self, signum: int, *args, **kwargs) -> None: + # pylint: disable=unused-argument + raise KeyboardInterrupt