judge: implement instant aborts
The way this works is:

- The worker creates a tempdir and sets `tempfile.tempdir` to this directory.
- The worker sends the tempdir path back; the parent process is responsible for
  cleaning it up when the worker exits.

Aborts are then implemented by sending `SIGKILL` to the worker.

As a side benefit of this implementation, we also get to drop the hacky
`CompiledExecutor` cache deletion.
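
To make the handshake concrete, here is a minimal, self-contained sketch of the scheme described above. It is not dmoj code: `worker_main` and the `('HELLO', path)` message shape are made up for illustration, and the real implementation is in the `JudgeWorker` diff below. The worker creates a scratch directory, points `tempfile.tempdir` at it, and reports the path in its hello message; the parent aborts by sending SIGKILL and removes the directory once the worker is gone.

    # Hypothetical sketch only; names and message shapes are illustrative, not
    # the actual dmoj IPC protocol.
    import multiprocessing
    import os
    import shutil
    import signal
    import tempfile


    def worker_main(conn) -> None:
        # The worker owns a scratch directory; every tempfile.mkstemp()/mkdtemp()
        # call it makes while grading lands inside it.
        scratch = tempfile.mkdtemp(prefix='judge-worker-')
        tempfile.tempdir = scratch
        conn.send(('HELLO', scratch))
        # ... grade the submission, creating temporary files as usual ...
        conn.send(('BYE', ()))


    if __name__ == '__main__':
        parent_conn, child_conn = multiprocessing.Pipe()
        proc = multiprocessing.Process(target=worker_main, args=(child_conn,))
        proc.start()

        _msg, worker_tempdir = parent_conn.recv()  # ('HELLO', '/tmp/judge-worker-...')
        try:
            # An "instant" abort: SIGKILL needs no cooperation from the worker,
            # so it takes effect immediately, even in the middle of grading.
            os.kill(proc.pid, signal.SIGKILL)
            proc.join()
        finally:
            # The parent owns cleanup, so nothing leaks even though the worker
            # never got the chance to delete its own files.
            shutil.rmtree(worker_tempdir, ignore_errors=True)
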
Xyene committed Dec 31, 2023
1 parent accffd0 commit a33b5bf
Showing 2 changed files with 48 additions and 53 deletions.
5 changes: 5 additions & 0 deletions dmoj/graders/base.py
@@ -37,6 +37,7 @@ def grade(self, case: TestCase) -> Result:
def _generate_binary(self) -> BaseExecutor:
raise NotImplementedError

def abort_grading(self) -> None:
self._abort_requested = True
if self._current_proc:
@@ -47,6 +48,10 @@ def abort_grading(self) -> None:

def _resolve_testcases(self, cfg, batch_no=0) -> List[BaseTestCase]:
cases: List[BaseTestCase] = []
for case_config in cfg:
if 'batched' in case_config.raw_config:
self._batch_counter += 1
96 changes: 43 additions & 53 deletions dmoj/judge.py
@@ -2,8 +2,10 @@
import logging
import multiprocessing
import os
import shutil
import signal
import sys
import tempfile
import threading
import traceback
from enum import Enum
@@ -41,14 +43,14 @@ class IPC(Enum):
BATCH_END = 'BATCH-END'
GRADING_BEGIN = 'GRADING-BEGIN'
GRADING_END = 'GRADING-END'
GRADING_ABORTED = 'GRADING-ABORTED'
UNHANDLED_EXCEPTION = 'UNHANDLED-EXCEPTION'
REQUEST_ABORT = 'REQUEST-ABORT'


class JudgeWorkerAborted(Exception):
pass


# This needs to be at least as large as the timeout for the largest compiler time limit, but we don't enforce that here.
# (Otherwise, aborting during a compilation that exceeds this time limit would result in a `TimeoutError` IE instead of
# a `CompileError`.)
IPC_TIMEOUT = 60 # seconds


@@ -128,8 +130,6 @@ def begin_grading(self, submission: Submission, report=logger.info, blocking=Fal
)
)

# FIXME(tbrindus): what if we receive an abort from the judge before IPC handshake completes? We'll send
# an abort request down the pipe, possibly messing up the handshake.
self.current_judge_worker = JudgeWorker(submission)

ipc_ready_signal = threading.Event()
@@ -147,13 +147,19 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
assert self.current_judge_worker is not None

try:
worker_tempdir = None

def _ipc_hello(_report, tempdir: str):
nonlocal worker_tempdir
ipc_ready_signal.set()
worker_tempdir = tempdir

ipc_handler_dispatch: Dict[IPC, Callable] = {
IPC.HELLO: lambda _report: ipc_ready_signal.set(),
IPC.HELLO: _ipc_hello,
IPC.COMPILE_ERROR: self._ipc_compile_error,
IPC.COMPILE_MESSAGE: self._ipc_compile_message,
IPC.GRADING_BEGIN: self._ipc_grading_begin,
IPC.GRADING_END: self._ipc_grading_end,
IPC.GRADING_ABORTED: self._ipc_grading_aborted,
IPC.BATCH_BEGIN: self._ipc_batch_begin,
IPC.BATCH_END: self._ipc_batch_end,
IPC.RESULT: self._ipc_result,
@@ -176,12 +182,17 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
% (self.current_submission.problem_id, self.current_submission.id)
)
)
except JudgeWorkerAborted:
self.packet_manager.submission_aborted_packet()
except Exception: # noqa: E722, we want to catch everything
self.log_internal_error()
finally:
self.current_judge_worker.wait_with_timeout()
self.current_judge_worker = None

if worker_tempdir:
shutil.rmtree(worker_tempdir)

# Might not have been set if an exception was encountered before HELLO message, so signal here to keep the
# other side from waiting forever.
ipc_ready_signal.set()
@@ -232,10 +243,6 @@ def _ipc_batch_begin(self, report, batch_number: int) -> None:
def _ipc_batch_end(self, _report, _batch_number: int) -> None:
self.packet_manager.batch_end_packet()

def _ipc_grading_aborted(self, report) -> None:
self.packet_manager.submission_aborted_packet()
report(ansi_style('#ansi[Forcefully terminating grading. Temporary files may not be deleted.](red|bold)'))

def _ipc_unhandled_exception(self, _report, message: str) -> None:
logger.error('Unhandled exception in worker process')
self.log_internal_error(message=message)
@@ -254,10 +261,9 @@ def abort_grading(self, submission_id: Optional[int] = None) -> None:
'Received abortion request for %d, but %d is currently running', submission_id, worker.submission.id
)
else:
logger.info('Received abortion request for %d', worker.submission.id)
# These calls are idempotent, so it doesn't matter if we raced and the worker has exited already.
worker.request_abort_grading()
worker.wait_with_timeout()
logger.info('Received abortion request for %d, killing worker', worker.submission.id)
# This call is idempotent, so it doesn't matter if we raced and the worker has exited already.
worker.abort_grading__kill_worker()

def listen(self) -> None:
"""
@@ -270,7 +276,8 @@ def murder(self) -> None:
"""
End any submission currently executing, and exit the judge.
"""
self.abort_grading()
if self.current_judge_worker:
self.current_judge_worker.abort_grading__kill_worker()
self.updater_exit = True
self.updater_signal.set()
if self.packet_manager:
@@ -304,8 +311,8 @@ def log_internal_error(self, exc: Optional[BaseException] = None, message: Optio
class JudgeWorker:
def __init__(self, submission: Submission) -> None:
self.submission = submission
self._abort_requested = False
self._sent_sigkill_to_worker_process = False
self._aborted = False
self._timed_out = False
# FIXME(tbrindus): marked Any pending grader cleanups.
self.grader: Any = None

@@ -331,8 +338,12 @@ def communicate(self) -> Generator[Tuple[IPC, tuple], None, None]:
self.worker_process.kill()
raise
except EOFError:
if self._sent_sigkill_to_worker_process:
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT)
if self._aborted:
raise JudgeWorkerAborted() from None

if self._timed_out:
raise TimeoutError('worker did not exit in %d seconds, so it was killed' % IPC_TIMEOUT) from None

raise
except Exception:
logger.error('Failed to read IPC message from worker!')
@@ -354,16 +365,14 @@ def wait_with_timeout(self) -> None:
finally:
if self.worker_process.is_alive():
logger.error('Worker is still alive, sending SIGKILL!')
self._sent_sigkill_to_worker_process = True
self._timed_out = True
self.worker_process.kill()

def request_abort_grading(self) -> None:
assert self.worker_process_conn

try:
self.worker_process_conn.send((IPC.REQUEST_ABORT, ()))
except Exception:
logger.exception('Failed to send abort request to worker, did it race?')
def abort_grading__kill_worker(self) -> None:
if self.worker_process and self.worker_process.is_alive():
self._aborted = True
self.worker_process.kill()
self.worker_process.join(timeout=1)

def _worker_process_main(
self,
@@ -384,15 +393,12 @@ def _ipc_recv_thread_main() -> None:
while True:
try:
ipc_type, data = judge_process_conn.recv()
except: # noqa: E722, whatever happened, we have to abort now.
except: # noqa: E722, whatever happened, we have to exit now.
logger.exception('Judge unexpectedly hung up!')
self._do_abort()
return

if ipc_type == IPC.BYE:
return
elif ipc_type == IPC.REQUEST_ABORT:
self._do_abort()
else:
raise RuntimeError('worker got unexpected IPC message from judge: %s' % ((ipc_type, data),))

@@ -402,9 +408,12 @@ def _report_unhandled_exception() -> None:
judge_process_conn.send((IPC.UNHANDLED_EXCEPTION, (message,)))
judge_process_conn.send((IPC.BYE, ()))

tempdir = tempfile.mkdtemp('dmoj-judge-worker')
tempfile.tempdir = tempdir

ipc_recv_thread = None
try:
judge_process_conn.send((IPC.HELLO, ()))
judge_process_conn.send((IPC.HELLO, (tempdir,)))

ipc_recv_thread = threading.Thread(target=_ipc_recv_thread_main, daemon=True)
ipc_recv_thread.start()
@@ -439,15 +448,6 @@ def _report_unhandled_exception() -> None:
if ipc_recv_thread.is_alive():
logger.error('Judge IPC recv thread is still alive after timeout, shutting worker down anyway!')

# FIXME(tbrindus): we need to do this because cleaning up temporary directories happens on __del__, which
# won't get called if we exit the process right now (so we'd leak all files created by the grader). This
# should be refactored to have an explicit `cleanup()` or similar, rather than relying on refcounting
# working out.
from dmoj.executors.compiled_executor import _CompiledExecutorMeta

for cached_executor in _CompiledExecutorMeta.compiled_binary_cache.values():
cached_executor.is_cached = False
cached_executor.cleanup()
self.grader = None

def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
@@ -505,11 +505,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
assert isinstance(case, TestCase)
result = self.grader.grade(case)

# If the submission was killed due to a user-initiated abort, any result is meaningless.
if self._abort_requested:
yield IPC.GRADING_ABORTED, ()
return

if result.result_flag & Result.WA:
# If we failed a 0-point case, we will short-circuit every case after this.
is_short_circuiting_enabled |= not case.points
@@ -534,11 +529,6 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:

yield IPC.GRADING_END, ()

def _do_abort(self) -> None:
self._abort_requested = True
if self.grader:
self.grader.abort_grading()


class ClassicJudge(Judge):
def __init__(self, host, port, **kwargs) -> None:
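
Taken together, the judge-side abort path in this diff is: `abort_grading__kill_worker()` SIGKILLs the worker, the worker's end of the IPC pipe disappears, `communicate()` turns the resulting `EOFError` into `JudgeWorkerAborted`, and the grading thread replies with `submission_aborted_packet()` before removing the worker's tempdir. The standalone sketch below (again hypothetical, with toy names; only `JudgeWorkerAborted` is taken from the diff) demonstrates the property this relies on: once the killed child and the parent's copy of the child end of the pipe are both gone, `recv()` raises `EOFError` instead of blocking.

    # Standalone illustration (not dmoj code) of the behaviour the new abort
    # path relies on: after SIGKILL, the parent's pipe read raises EOFError.
    import multiprocessing
    import time


    class JudgeWorkerAborted(Exception):
        pass


    def worker_main(conn) -> None:
        conn.send(('GRADING-BEGIN', ()))
        time.sleep(60)  # pretend to grade for a long time


    if __name__ == '__main__':
        parent_conn, child_conn = multiprocessing.Pipe()
        proc = multiprocessing.Process(target=worker_main, args=(child_conn,))
        proc.start()
        child_conn.close()  # drop the parent's copy so EOF becomes observable

        print(parent_conn.recv())  # ('GRADING-BEGIN', ())

        aborted = True
        proc.kill()  # sends SIGKILL on POSIX: the "instant" abort
        proc.join()

        try:
            try:
                parent_conn.recv()  # the worker is gone, so this raises EOFError
            except EOFError:
                if aborted:
                    raise JudgeWorkerAborted() from None
                raise
        except JudgeWorkerAborted:
            # In the diff, this is where the grading thread sends
            # submission_aborted_packet() and removes the worker's tempdir.
            print('grading aborted')
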
