Skip to content

Commit

Permalink
grader: implement general dependencies & pretests
Browse files Browse the repository at this point in the history
This commit changes the previous "batch dependencies" in a few ways:
1. Dependencies can now be specified on any top-level unit, case or batch.
2. Dependencies are specified with YAML alias notation, not numbers.

It also implements pretests, including a `PRETEST-END` IPC packet.
  • Loading branch information
Riolku committed Nov 5, 2022
1 parent c346cec commit fae7b90
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 91 deletions.
6 changes: 6 additions & 0 deletions dmoj/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def begin_grading_packet(self, is_pretested):
def grading_end_packet(self):
    """Intentionally a no-op: this packet manager discards grading-end events."""

def pretest_begin_packet(self):
    """Intentionally a no-op: this packet manager discards pretest-begin events."""

def pretest_end_packet(self):
    """Intentionally a no-op: this packet manager discards pretest-end events."""

def batch_begin_packet(self):
    """Intentionally a no-op: this packet manager discards batch-begin events."""

Expand Down
1 change: 1 addition & 0 deletions dmoj/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ConfigNode:

def __init__(self, raw_config=None, parent=None, defaults=None, dynamic=True):
self.dynamic = dynamic
self.raw_config_id = id(raw_config)
if defaults:
self.raw_config = defaults
self.raw_config.update(raw_config or {})
Expand Down
41 changes: 25 additions & 16 deletions dmoj/graders/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from dmoj.problem import BatchedTestCase, TestCase
from typing import List, Tuple

from dmoj.config import InvalidInitException
from dmoj.problem import Batch, BatchedTestCase, StandaloneTestCase, TopLevelCase
from dmoj.utils.unicode import utf8bytes


Expand Down Expand Up @@ -29,37 +32,43 @@ def abort_grading(self):
except OSError:
pass

def _resolve_testcases(self, cfg, batch_no=0):
cases = []
for case_config in cfg:
def _resolve_testcases(self, cfg) -> List[TopLevelCase]:
    """Resolve a top-level case config sequence into TopLevelCase objects.

    Each entry in `cfg` is either a batch (its raw config carries a
    'batched' key, expanded via `_resolve_batched_cases`) or a standalone
    test case.  Fix: the previous loop bound an `enumerate` index
    (`top_level_position`) that was never used anywhere in the body; the
    unused variable has been removed.
    """
    cases: List[TopLevelCase] = []
    for case_config in cfg:
        if 'batched' in case_config.raw_config:
            # Batch numbers are counted globally across the whole problem.
            self._batch_counter += 1
            cases.append(
                Batch(
                    self._batch_counter,
                    case_config,
                    self.problem,
                    self._resolve_batched_cases(case_config['batched'], self._batch_counter),
                )
            )
        else:
            # Standalone cases consume one global test-case number each.
            cases.append(StandaloneTestCase(self._testcase_counter, case_config, self.problem))
            self._testcase_counter += 1
    return cases

def cases(self):
pretest_test_cases = self.problem.config.pretest_test_cases
if self.run_pretests_only and pretest_test_cases:
return self._resolve_testcases(pretest_test_cases)
def _resolve_batched_cases(self, cfg, batch_no) -> List[BatchedTestCase]:
    """Resolve the member cases of batch `batch_no`; batches may not nest."""
    resolved: List[BatchedTestCase] = []
    for case_config in cfg:
        # A 'batched' key inside a batch member means a nested batch, which
        # is rejected outright.
        if 'batched' in case_config.raw_config:
            raise InvalidInitException('nested batches')
        # NOTE(review): `_testcase_counter` is not advanced in this loop, so
        # every case in the batch is built with the same counter value —
        # confirm this is intended.
        resolved.append(BatchedTestCase(self._testcase_counter, batch_no, case_config, self.problem))
    return resolved

test_cases = self._resolve_testcases(self.problem.config.test_cases)
def cases(self) -> Tuple[List[TopLevelCase], List[TopLevelCase]]:
    """Resolve and return the pair (pretest cases, main test cases)."""
    raw_pretests = self.problem.config.pretest_test_cases
    if raw_pretests:
        pretests = self._resolve_testcases(raw_pretests)
        for case in pretests:
            # Hack: force short-circuiting behavior
            case.points = 0
    else:
        pretests = []

    # Important that this comes after the previous `_resolve_testcases` call,
    # otherwise our underlying `position` values would be all wrong.
    main_tests = self._resolve_testcases(self.problem.config.test_cases)
    return (pretests, main_tests)
170 changes: 113 additions & 57 deletions dmoj/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
import traceback
from enum import Enum
from http.server import HTTPServer
from itertools import groupby
from operator import itemgetter
from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple

from dmoj import packet
from dmoj.control import JudgeControlRequestHandler
from dmoj.error import CompileError
from dmoj.judgeenv import clear_problem_dirs_cache, env, get_supported_problems_and_mtimes, startup_warnings
from dmoj.monitor import Monitor
from dmoj.problem import BatchedTestCase, Problem, TestCase
from dmoj.problem import Batch, Problem, TopLevelCase
from dmoj.result import Result
from dmoj.utils import builtin_int_patch
from dmoj.utils.ansi import ansi_style, print_ansi, strip_ansi
Expand All @@ -39,6 +37,8 @@ class IPC(Enum):
RESULT = 'RESULT'
BATCH_BEGIN = 'BATCH-BEGIN'
BATCH_END = 'BATCH-END'
PRETEST_BEGIN = 'PRETEST-BEGIN'
PRETEST_END = 'PRETEST-END'
GRADING_BEGIN = 'GRADING-BEGIN'
GRADING_END = 'GRADING-END'
GRADING_ABORTED = 'GRADING-ABORTED'
Expand Down Expand Up @@ -77,6 +77,12 @@ def __init__(self, packet_manager: packet.PacketManager) -> None:
self.updater_signal = threading.Event()
self.updater = threading.Thread(target=self._updater_thread)

self.in_pretests = False
self.pretest_batch = 0
self.pretest_case = 0
self.main_batch = 0
self.main_case = 0

@property
def current_submission(self):
worker = self.current_judge_worker
Expand Down Expand Up @@ -151,6 +157,8 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
IPC.GRADING_BEGIN: self._ipc_grading_begin,
IPC.GRADING_END: self._ipc_grading_end,
IPC.GRADING_ABORTED: self._ipc_grading_aborted,
IPC.PRETEST_BEGIN: self._ipc_pretest_begin,
IPC.PRETEST_END: self._ipc_pretest_end,
IPC.BATCH_BEGIN: self._ipc_batch_begin,
IPC.BATCH_END: self._ipc_batch_end,
IPC.RESULT: self._ipc_result,
Expand Down Expand Up @@ -199,6 +207,15 @@ def _ipc_grading_begin(self, _report, is_pretested: bool) -> None:
def _ipc_grading_end(self, _report) -> None:
self.packet_manager.grading_end_packet()

def _ipc_pretest_begin(self, _report) -> None:
self.in_pretests = True
self.packet_manager.pretest_begin_packet()

def _ipc_pretest_end(self, report) -> None:
self.in_pretests = False
report('')
self.packet_manager.pretest_end_packet()

def _ipc_result(self, report, batch_number: Optional[int], case_number: int, result: Result) -> None:
codes = result.readable_codes()

Expand All @@ -218,13 +235,29 @@ def _ipc_result(self, report, batch_number: Optional[int], case_number: int, res
colored_feedback,
colored_aux_codes,
)

case_padding = ' ' if batch_number is not None else ''
report(ansi_style('%sTest case %2d %-3s %s' % (case_padding, case_number, colored_codes[0], case_info)))
if self.in_pretests:
self.pretest_case += 1
report(
ansi_style(
'%sPretest case %2d %-3s %s' % (case_padding, self.pretest_case, colored_codes[0], case_info)
)
)
else:
self.main_case += 1
report(ansi_style('%sTest case %2d %-3s %s' % (case_padding, self.main_case, colored_codes[0], case_info)))

self.packet_manager.test_case_status_packet(case_number, result)

def _ipc_batch_begin(self, report, batch_number: int) -> None:
def _ipc_batch_begin(self, report, _batch_number: int) -> None:
self.packet_manager.batch_begin_packet()
report(ansi_style('#ansi[Batch #%d](yellow|bold)' % batch_number))
if self.in_pretests:
self.pretest_batch += 1
report(ansi_style('#ansi[Pretest Batch #%d](yellow|bold)' % self.pretest_batch))
else:
self.main_batch += 1
report(ansi_style('#ansi[Batch #%d](yellow|bold)' % self.main_batch))

def _ipc_batch_end(self, _report, _batch_number: int) -> None:
self.packet_manager.batch_end_packet()
Expand Down Expand Up @@ -459,69 +492,92 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
if hasattr(binary, 'warning') and binary.warning is not None:
yield IPC.COMPILE_MESSAGE, (binary.warning,)

yield IPC.GRADING_BEGIN, (self.grader.run_pretests_only,)

flattened_cases: List[Tuple[Optional[int], Union[TestCase, BatchedTestCase]]] = []
skip_all: bool = False
skip_current: bool = False
batch_number = 0
batch_dependencies: List[Set[int]] = []
for case in self.grader.cases():
if isinstance(case, BatchedTestCase):
batch_number += 1
for batched_case in case.batched_cases:
flattened_cases.append((batch_number, batched_case))
batch_dependencies.append(set(case.dependencies))
else:
flattened_cases.append((None, case))

case_number = 0
is_short_circuiting = False
is_short_circuiting_enabled = self.submission.short_circuit
passed_batches: Set[int] = set()
for batch_number, cases in groupby(flattened_cases, key=itemgetter(0)):
if batch_number:
yield IPC.BATCH_BEGIN, (batch_number,)

dependencies = batch_dependencies[batch_number - 1] # List is zero-indexed
if passed_batches & dependencies != dependencies:
is_short_circuiting = True

for _, case in cases:
case_number += 1

# Stop grading if we're short circuiting
if is_short_circuiting:
result = Result(case, result_flag=Result.SC)
global_test_number = 0

class GradingAbort(Exception):
pass

def grade_tests(
tests: List[TopLevelCase], should_run: Callable[[TopLevelCase, Set[Any]], bool]
) -> Generator[Tuple[IPC, tuple], None, Set[Any]]:
nonlocal batch_number, global_test_number, skip_all, skip_current
failed_tests: Set[Any] = set()
for test in tests:
skip_current = skip_all or not should_run(test, failed_tests)
if isinstance(test, Batch):
batch_number += 1
yield IPC.BATCH_BEGIN, (batch_number,)
for case in test.batched_cases:
global_test_number += 1
yield IPC.RESULT, (batch_number, global_test_number, grade_one(case))

yield IPC.BATCH_END, (batch_number,)

else:
result = self.grader.grade(case)
global_test_number += 1
yield IPC.RESULT, (None, global_test_number, grade_one(test))

# If the submission was killed due to a user-initiated abort, any result is meaningless.
if self._abort_requested:
yield IPC.GRADING_ABORTED, ()
return
if skip_current:
failed_tests.add(test.config.raw_config_id)
if not test.points or self.submission.short_circuit:
skip_all = True

return failed_tests

def grade_one(case) -> Result:
nonlocal skip_current
if skip_current:
return Result(case, result_flag=Result.SC)
else:
result = self.grader.grade(case)

if result.result_flag & Result.WA:
# If we failed a 0-point case, we will short-circuit every case after this.
is_short_circuiting_enabled |= not case.points
# If the submission was killed due to a user-initiated abort, any result is meaningless.
if self._abort_requested:
raise GradingAbort

# Short-circuit if we just failed a case in a batch, or if short-circuiting is currently enabled
# for all test cases (either this was requested by the site, or we failed a 0-point case in the
# past).
is_short_circuiting |= batch_number is not None or is_short_circuiting_enabled
if result.result_flag & result.WA:
skip_current = True

# Legacy hack: we need to allow graders to read and write `proc_output` on the `Result` object, but the
# judge controller only cares about the trimmed output, and shouldn't waste memory buffering the full
# output. So, we trim it here so we don't run out of memory in the controller.
result.proc_output = result.output
yield IPC.RESULT, (batch_number, case_number, result)
return result

if batch_number:
if not is_short_circuiting:
passed_batches.add(batch_number)
def should_run_test(case: TopLevelCase, failed_so_far: Set[int]) -> bool:
if case.dependencies is None:
# Default: depends on nothing.
return True
else:
return len(case.dependencies & failed_so_far) == 0

failed_pretests: Set[int] = set()

yield IPC.BATCH_END, (batch_number,)
is_short_circuiting &= is_short_circuiting_enabled
def should_run_main_test(case: TopLevelCase, failed_so_far: Set[int]) -> bool:
good_pretests = False
if case.dependencies is None:
# Default: depends on all pretests.
good_pretests = len(failed_pretests) == 0
else:
good_pretests = len(failed_pretests & case.dependencies) == 0
return good_pretests and should_run_test(case, failed_so_far)

yield IPC.GRADING_END, ()
yield IPC.GRADING_BEGIN, (self.grader.run_pretests_only,)
pretests, main_tests = self.grader.cases()
try:
if pretests:
yield IPC.PRETEST_BEGIN, ()
failed_pretests = yield from grade_tests(pretests, should_run_test)
yield IPC.PRETEST_END, ()
if not self.grader.run_pretests_only:
yield from grade_tests(main_tests, should_run_main_test)
except GradingAbort:
yield IPC.GRADING_ABORTED, ()
else:
yield IPC.GRADING_END, ()

def _do_abort(self) -> None:
self._abort_requested = True
Expand Down
9 changes: 9 additions & 0 deletions dmoj/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,15 @@ def grading_end_packet(self):
self._flush_testcase_queue()
self._send_packet({'name': 'grading-end', 'submission-id': self.judge.current_submission.id})

def pretest_begin_packet(self):
    """Tell the site that pretest grading has begun for the current submission."""
    log.debug('Begin pretests: %d', self.judge.current_submission.id)
    packet = {'name': 'pretest-begin', 'submission-id': self.judge.current_submission.id}
    self._send_packet(packet)

def pretest_end_packet(self):
    """Flush queued test-case packets, then tell the site pretests are done."""
    log.debug('End pretests: %d', self.judge.current_submission.id)
    # Flush first so every pretest result reaches the site before the end marker.
    self._flush_testcase_queue()
    packet = {'name': 'pretest-end', 'submission-id': self.judge.current_submission.id}
    self._send_packet(packet)

def batch_begin_packet(self):
self._batch += 1
log.debug('Enter batch number %d: %d', self._batch, self.judge.current_submission.id)
Expand Down
Loading

0 comments on commit fae7b90

Please sign in to comment.