Skip to content

Commit

Permalink
Replace _IOPubThread with BaseThread
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 15, 2024
1 parent 615ec12 commit 1834b58
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 39 deletions.
6 changes: 5 additions & 1 deletion ipykernel/inprocess/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,8 @@ async def poll(self, timeout=0):
return statistics.current_buffer_used != 0

def close(self):
pass
if self.is_shell:
self.in_send_stream.close()
self.in_receive_stream.close()
self.out_send_stream.close()
self.out_receive_stream.close()
46 changes: 9 additions & 37 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
from binascii import b2a_hex
from collections import defaultdict, deque
from io import StringIO, TextIOBase
from threading import Event, Thread, local
from threading import local
from typing import Any, Callable

import zmq
import zmq_anyio
from anyio import create_task_group, run, sleep, to_thread
from anyio import sleep
from jupyter_client.session import extract_header

from .thread import BaseThread

# -----------------------------------------------------------------------------
# Globals
# -----------------------------------------------------------------------------
Expand All @@ -38,38 +40,6 @@
# -----------------------------------------------------------------------------


class _IOPubThread(Thread):
"""A thread for a IOPub."""

def __init__(self, tasks, **kwargs):
"""Initialize the thread."""
super().__init__(name="IOPub", **kwargs)
self._tasks = tasks
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True
self.daemon = True
self.__stop = Event()

def run(self):
"""Run the thread."""
self.name = "IOPub"
run(self._main)

async def _main(self):
async with create_task_group() as self._task_group:
for task in self._tasks:
self._task_group.start_soon(task)
await to_thread.run_sync(self.__stop.wait)
self._task_group.cancel_scope.cancel()

def stop(self):
"""Stop the thread.
This method is threadsafe.
"""
self.__stop.set()


class IOPubThread:
"""An object for sending IOPub messages in a background thread
Expand Down Expand Up @@ -109,7 +79,9 @@ def __init__(self, socket: zmq_anyio.Socket, pipe=False):
tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start]
if pipe:
tasks.append(self._handle_pipe_msgs)
self.thread = _IOPubThread(tasks)
self.thread = BaseThread(name="IOPub", daemon=True)
for task in tasks:
self.thread.start_soon(task)

def _setup_event_pipe(self):
"""Create the PULL socket listening for events that should fire in this thread."""
Expand Down Expand Up @@ -179,7 +151,7 @@ async def _handle_event(self):
event_f = self._events.popleft()
event_f()
except Exception:
if self.thread.__stop.is_set():
if self.thread.stopped.is_set():
return
raise

Expand Down Expand Up @@ -211,7 +183,7 @@ async def _handle_pipe_msgs(self):
while True:
await self._handle_pipe_msg()
except Exception:
if self.thread.__stop.is_set():
if self.thread.stopped.is_set():
return
raise

Expand Down
6 changes: 5 additions & 1 deletion ipykernel/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections.abc import Awaitable
from queue import Queue
from threading import Thread
from threading import Event, Thread
from typing import Callable

from anyio import create_task_group, run, to_thread
Expand All @@ -18,6 +18,8 @@ class BaseThread(Thread):
def __init__(self, **kwargs):
"""Initialize the thread."""
super().__init__(**kwargs)
self.started = Event()
self.stopped = Event()
self.pydev_do_not_trace = True
self.is_pydev_daemon_thread = True
self._tasks: Queue[Callable[[], Awaitable[None]] | None] = Queue()
Expand All @@ -31,6 +33,7 @@ def run(self) -> None:

async def _main(self) -> None:
async with create_task_group() as tg:
self.started.set()
while True:
task = await to_thread.run_sync(self._tasks.get)
if task is None:
Expand All @@ -44,3 +47,4 @@ def stop(self) -> None:
This method is threadsafe.
"""
self._tasks.put(None)
self.stopped.set()

0 comments on commit 1834b58

Please sign in to comment.