diff --git a/dmoj/cli.py b/dmoj/cli.py
index d8c4e9668..c56600e49 100644
--- a/dmoj/cli.py
+++ b/dmoj/cli.py
@@ -39,6 +39,12 @@ def begin_grading_packet(self, is_pretested):
     def grading_end_packet(self):
         pass
 
+    def pretest_begin_packet(self):
+        pass
+
+    def pretest_end_packet(self):
+        pass
+
     def batch_begin_packet(self):
         pass
 
diff --git a/dmoj/config.py b/dmoj/config.py
index 6aa19b8d0..36f2e4be5 100644
--- a/dmoj/config.py
+++ b/dmoj/config.py
@@ -46,6 +46,7 @@ class ConfigNode:
 
     def __init__(self, raw_config=None, parent=None, defaults=None, dynamic=True):
         self.dynamic = dynamic
+        self.raw_config_id = id(raw_config)
        if defaults:
             self.raw_config = defaults
             self.raw_config.update(raw_config or {})
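A note on the `config.py` change above: capturing `id(raw_config)` works as a reference identity because PyYAML resolves an alias back to the very same Python object as its anchor, so two `ConfigNode`s built over an anchored mapping and its alias end up with equal `raw_config_id` values. A minimal sketch (illustrative YAML, not DMOJ's test data):

```python
# Sketch only: demonstrates the object identity that raw_config_id relies on.
import yaml

doc = yaml.safe_load(
    '''
    test_cases:
      - &sample {in: sample.in, out: sample.out, points: 0}
      - {in: main.in, out: main.out, points: 100, depends: [*sample]}
    '''
)
sample, main = doc['test_cases']
assert main['depends'][0] is sample          # the alias *is* the anchored dict
assert id(main['depends'][0]) == id(sample)  # hence equal raw_config_id values
```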
diff --git a/dmoj/graders/base.py b/dmoj/graders/base.py
index 4f60f7bfa..b77279a5f 100644
--- a/dmoj/graders/base.py
+++ b/dmoj/graders/base.py
@@ -1,4 +1,7 @@
-from dmoj.problem import BatchedTestCase, TestCase
+from typing import List, Tuple
+
+from dmoj.config import InvalidInitException
+from dmoj.problem import Batch, BatchedTestCase, StandaloneTestCase, TopLevelCase
 from dmoj.utils.unicode import utf8bytes
 
@@ -29,37 +32,43 @@ def abort_grading(self):
         except OSError:
             pass
 
-    def _resolve_testcases(self, cfg, batch_no=0):
-        cases = []
-        for case_config in cfg:
+    def _resolve_testcases(self, cfg) -> List[TopLevelCase]:
+        cases: List[TopLevelCase] = []
+        for top_level_position, case_config in enumerate(cfg):
             if 'batched' in case_config.raw_config:
                 self._batch_counter += 1
                 cases.append(
-                    BatchedTestCase(
+                    Batch(
                         self._batch_counter,
                         case_config,
                         self.problem,
-                        self._resolve_testcases(case_config['batched'], self._batch_counter),
+                        self._resolve_batched_cases(case_config['batched'], self._batch_counter),
                     )
                 )
             else:
-                cases.append(TestCase(self._testcase_counter, batch_no, case_config, self.problem))
+                cases.append(StandaloneTestCase(self._testcase_counter, case_config, self.problem))
                 self._testcase_counter += 1
         return cases
 
-    def cases(self):
-        pretest_test_cases = self.problem.config.pretest_test_cases
-        if self.run_pretests_only and pretest_test_cases:
-            return self._resolve_testcases(pretest_test_cases)
+    def _resolve_batched_cases(self, cfg, batch_no) -> List[BatchedTestCase]:
+        batched_cases = []
+        for case_config in cfg:
+            if 'batched' in case_config.raw_config:
+                raise InvalidInitException('nested batches')
+            batched_cases.append(BatchedTestCase(self._testcase_counter, batch_no, case_config, self.problem))
+            self._testcase_counter += 1
+        return batched_cases
 
-        test_cases = self._resolve_testcases(self.problem.config.test_cases)
+    def cases(self) -> Tuple[List[TopLevelCase], List[TopLevelCase]]:
+        pretest_test_cases = self.problem.config.pretest_test_cases
         if pretest_test_cases:
             pretest_test_cases = self._resolve_testcases(pretest_test_cases)
-            # Hack: force short-circuiting behavior
             for case in pretest_test_cases:
+                # Hack: force short-circuiting behavior
                 case.points = 0
+        else:
+            pretest_test_cases = []
 
-            test_cases = pretest_test_cases + test_cases
-
-        return test_cases
+        # Important that this comes after the previous `_resolve_testcases` call, otherwise our
+        # underlying `position` values would be all wrong.
+        test_cases = self._resolve_testcases(self.problem.config.test_cases)
+        return (pretest_test_cases, test_cases)
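With the grader changes above, `cases()` hands back a `(pretests, main_tests)` pair instead of one flattened list, and nested batches are rejected while resolving rather than inside a constructor; because pretests are resolved first, `position` values run consecutively from the pretests into the main tests, which is what the comment in `cases()` is guarding. A small consumption sketch (hypothetical helper, not part of this diff) to illustrate the shape:

```python
from dmoj.problem import Batch

def describe_cases(grader):
    # `grader` is any grader instance exposing the new cases() API above.
    pretests, main_tests = grader.cases()
    for test in pretests + main_tests:
        if isinstance(test, Batch):
            for case in test.batched_cases:
                print('case %d in batch %d' % (case.position, case.batch))
        else:
            # StandaloneTestCase: runnable and top-level at once.
            print('standalone case %d' % test.position)
```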
diff --git a/dmoj/judge.py b/dmoj/judge.py
index 2ce1cfbc9..f08d39165 100644
--- a/dmoj/judge.py
+++ b/dmoj/judge.py
@@ -8,16 +8,14 @@
 import traceback
 from enum import Enum
 from http.server import HTTPServer
-from itertools import groupby
-from operator import itemgetter
-from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple
 
 from dmoj import packet
 from dmoj.control import JudgeControlRequestHandler
 from dmoj.error import CompileError
 from dmoj.judgeenv import clear_problem_dirs_cache, env, get_supported_problems_and_mtimes, startup_warnings
 from dmoj.monitor import Monitor
-from dmoj.problem import BatchedTestCase, Problem, TestCase
+from dmoj.problem import Batch, Problem, TopLevelCase
 from dmoj.result import Result
 from dmoj.utils import builtin_int_patch
 from dmoj.utils.ansi import ansi_style, print_ansi, strip_ansi
@@ -39,6 +37,8 @@ class IPC(Enum):
     RESULT = 'RESULT'
     BATCH_BEGIN = 'BATCH-BEGIN'
     BATCH_END = 'BATCH-END'
+    PRETEST_BEGIN = 'PRETEST-BEGIN'
+    PRETEST_END = 'PRETEST-END'
     GRADING_BEGIN = 'GRADING-BEGIN'
     GRADING_END = 'GRADING-END'
     GRADING_ABORTED = 'GRADING-ABORTED'
@@ -77,6 +77,12 @@ def __init__(self, packet_manager: packet.PacketManager) -> None:
         self.updater_signal = threading.Event()
         self.updater = threading.Thread(target=self._updater_thread)
 
+        self.in_pretests = False
+        self.pretest_batch = 0
+        self.pretest_case = 0
+        self.main_batch = 0
+        self.main_case = 0
+
     @property
     def current_submission(self):
         worker = self.current_judge_worker
@@ -151,6 +157,8 @@ def _grading_thread_main(self, ipc_ready_signal: threading.Event, report) -> Non
             IPC.GRADING_BEGIN: self._ipc_grading_begin,
             IPC.GRADING_END: self._ipc_grading_end,
             IPC.GRADING_ABORTED: self._ipc_grading_aborted,
+            IPC.PRETEST_BEGIN: self._ipc_pretest_begin,
+            IPC.PRETEST_END: self._ipc_pretest_end,
             IPC.BATCH_BEGIN: self._ipc_batch_begin,
             IPC.BATCH_END: self._ipc_batch_end,
             IPC.RESULT: self._ipc_result,
@@ -199,6 +207,15 @@ def _ipc_grading_begin(self, _report, is_pretested: bool) -> None:
     def _ipc_grading_end(self, _report) -> None:
         self.packet_manager.grading_end_packet()
 
+    def _ipc_pretest_begin(self, _report) -> None:
+        self.in_pretests = True
+        self.packet_manager.pretest_begin_packet()
+
+    def _ipc_pretest_end(self, report) -> None:
+        self.in_pretests = False
+        report('')
+        self.packet_manager.pretest_end_packet()
+
     def _ipc_result(self, report, batch_number: Optional[int], case_number: int, result: Result) -> None:
         codes = result.readable_codes()
 
@@ -218,13 +235,29 @@ def _ipc_result(self, report, batch_number: Optional[int], case_number: int, res
             colored_feedback,
             colored_aux_codes,
         )
+
         case_padding = '  ' if batch_number is not None else ''
-        report(ansi_style('%sTest case %2d %-3s %s' % (case_padding, case_number, colored_codes[0], case_info)))
+        if self.in_pretests:
+            self.pretest_case += 1
+            report(
+                ansi_style(
+                    '%sPretest case %2d %-3s %s' % (case_padding, self.pretest_case, colored_codes[0], case_info)
+                )
+            )
+        else:
+            self.main_case += 1
+            report(ansi_style('%sTest case %2d %-3s %s' % (case_padding, self.main_case, colored_codes[0], case_info)))
+
         self.packet_manager.test_case_status_packet(case_number, result)
 
-    def _ipc_batch_begin(self, report, batch_number: int) -> None:
+    def _ipc_batch_begin(self, report, _batch_number: int) -> None:
         self.packet_manager.batch_begin_packet()
-        report(ansi_style('#ansi[Batch #%d](yellow|bold)' % batch_number))
+        if self.in_pretests:
+            self.pretest_batch += 1
+            report(ansi_style('#ansi[Pretest Batch #%d](yellow|bold)' % self.pretest_batch))
+        else:
+            self.main_batch += 1
+            report(ansi_style('#ansi[Batch #%d](yellow|bold)' % self.main_batch))
 
     def _ipc_batch_end(self, _report, _batch_number: int) -> None:
         self.packet_manager.batch_end_packet()
@@ -459,69 +492,92 @@ def _grade_cases(self) -> Generator[Tuple[IPC, tuple], None, None]:
             if hasattr(binary, 'warning') and binary.warning is not None:
                 yield IPC.COMPILE_MESSAGE, (binary.warning,)
 
-        yield IPC.GRADING_BEGIN, (self.grader.run_pretests_only,)
-
-        flattened_cases: List[Tuple[Optional[int], Union[TestCase, BatchedTestCase]]] = []
+        skip_all: bool = False
+        skip_current: bool = False
         batch_number = 0
-        batch_dependencies: List[Set[int]] = []
-        for case in self.grader.cases():
-            if isinstance(case, BatchedTestCase):
-                batch_number += 1
-                for batched_case in case.batched_cases:
-                    flattened_cases.append((batch_number, batched_case))
-                batch_dependencies.append(set(case.dependencies))
-            else:
-                flattened_cases.append((None, case))
-
-        case_number = 0
-        is_short_circuiting = False
-        is_short_circuiting_enabled = self.submission.short_circuit
-        passed_batches: Set[int] = set()
-        for batch_number, cases in groupby(flattened_cases, key=itemgetter(0)):
-            if batch_number:
-                yield IPC.BATCH_BEGIN, (batch_number,)
-
-                dependencies = batch_dependencies[batch_number - 1]  # List is zero-indexed
-                if passed_batches & dependencies != dependencies:
-                    is_short_circuiting = True
-
-            for _, case in cases:
-                case_number += 1
-
-                # Stop grading if we're short circuiting
-                if is_short_circuiting:
-                    result = Result(case, result_flag=Result.SC)
+        global_test_number = 0
+
+        class GradingAbort(Exception):
+            pass
+
+        def grade_tests(
+            tests: List[TopLevelCase], should_run: Callable[[TopLevelCase, Set[Any]], bool]
+        ) -> Generator[Tuple[IPC, tuple], None, Set[Any]]:
+            nonlocal batch_number, global_test_number, skip_all, skip_current
+            failed_tests: Set[Any] = set()
+            for test in tests:
+                skip_current = skip_all or not should_run(test, failed_tests)
+                if isinstance(test, Batch):
+                    batch_number += 1
+                    yield IPC.BATCH_BEGIN, (batch_number,)
+                    for case in test.batched_cases:
+                        global_test_number += 1
+                        yield IPC.RESULT, (batch_number, global_test_number, grade_one(case))
+
+                    yield IPC.BATCH_END, (batch_number,)
                 else:
-                    result = self.grader.grade(case)
+                    global_test_number += 1
+                    yield IPC.RESULT, (None, global_test_number, grade_one(test))
 
-                    # If the submission was killed due to a user-initiated abort, any result is meaningless.
-                    if self._abort_requested:
-                        yield IPC.GRADING_ABORTED, ()
-                        return
+                if skip_current:
+                    failed_tests.add(test.config.raw_config_id)
+                    if not test.points or self.submission.short_circuit:
+                        skip_all = True
+
+            return failed_tests
+
+        def grade_one(case) -> Result:
+            nonlocal skip_current
+            if skip_current:
+                return Result(case, result_flag=Result.SC)
+            else:
+                result = self.grader.grade(case)
 
-                if result.result_flag & Result.WA:
-                    # If we failed a 0-point case, we will short-circuit every case after this.
-                    is_short_circuiting_enabled |= not case.points
+                # If the submission was killed due to a user-initiated abort, any result is meaningless.
+                if self._abort_requested:
+                    raise GradingAbort
 
-                # Short-circuit if we just failed a case in a batch, or if short-circuiting is currently enabled
-                # for all test cases (either this was requested by the site, or we failed a 0-point case in the
-                # past).
-                is_short_circuiting |= batch_number is not None or is_short_circuiting_enabled
+                if result.result_flag & result.WA:
+                    skip_current = True
 
                 # Legacy hack: we need to allow graders to read and write `proc_output` on the `Result` object, but the
                 # judge controller only cares about the trimmed output, and shouldn't waste memory buffering the full
                 # output. So, we trim it here so we don't run out of memory in the controller.
                 result.proc_output = result.output
-                yield IPC.RESULT, (batch_number, case_number, result)
+                return result
 
-            if batch_number:
-                if not is_short_circuiting:
-                    passed_batches.add(batch_number)
+        def should_run_test(case: TopLevelCase, failed_so_far: Set[int]) -> bool:
+            if case.dependencies is None:
+                # Default: depends on nothing.
+                return True
+            else:
+                return len(case.dependencies & failed_so_far) == 0
+
+        failed_pretests: Set[int] = set()
 
-                yield IPC.BATCH_END, (batch_number,)
-                is_short_circuiting &= is_short_circuiting_enabled
+        def should_run_main_test(case: TopLevelCase, failed_so_far: Set[int]) -> bool:
+            if case.dependencies is None:
+                # Default: depends on all pretests.
+                good_pretests = len(failed_pretests) == 0
+            else:
+                good_pretests = len(failed_pretests & case.dependencies) == 0
+            return good_pretests and should_run_test(case, failed_so_far)
 
-        yield IPC.GRADING_END, ()
+        yield IPC.GRADING_BEGIN, (self.grader.run_pretests_only,)
+        pretests, main_tests = self.grader.cases()
+        try:
+            if pretests:
+                yield IPC.PRETEST_BEGIN, ()
+                failed_pretests = yield from grade_tests(pretests, should_run_test)
+                yield IPC.PRETEST_END, ()
+            if not self.grader.run_pretests_only:
+                yield from grade_tests(main_tests, should_run_main_test)
+        except GradingAbort:
+            yield IPC.GRADING_ABORTED, ()
+        else:
+            yield IPC.GRADING_END, ()
 
     def _do_abort(self) -> None:
         self._abort_requested = True
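The two predicates above define the dependency model: `depends` entries are matched by `raw_config_id` against every test that has failed (or been skipped) so far, while a main test with no `depends` key implicitly requires all pretests to pass. Condensed into a single standalone rule (hypothetical signature, equivalent to the composition of `should_run_test` and `should_run_main_test`):

```python
from typing import Optional, Set

def should_run(dependencies: Optional[Set[int]],
               failed_pretests: Set[int],
               failed_so_far: Set[int]) -> bool:
    # Sets hold raw_config_id values of tests that failed or were skipped.
    if dependencies is not None:
        # Explicit depends: run unless any listed dependency failed,
        # whether it was a pretest or an earlier main test.
        return not (dependencies & (failed_pretests | failed_so_far))
    # No depends key: implicitly depends on every pretest passing,
    # but on no other main test.
    return not failed_pretests
```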
diff --git a/dmoj/packet.py b/dmoj/packet.py
index f88eb2ee3..765760f70 100644
--- a/dmoj/packet.py
+++ b/dmoj/packet.py
@@ -347,6 +347,15 @@ def grading_end_packet(self):
         self._flush_testcase_queue()
         self._send_packet({'name': 'grading-end', 'submission-id': self.judge.current_submission.id})
 
+    def pretest_begin_packet(self):
+        log.debug('Begin pretests: %d', self.judge.current_submission.id)
+        self._send_packet({'name': 'pretest-begin', 'submission-id': self.judge.current_submission.id})
+
+    def pretest_end_packet(self):
+        log.debug('End pretests: %d', self.judge.current_submission.id)
+        self._flush_testcase_queue()
+        self._send_packet({'name': 'pretest-end', 'submission-id': self.judge.current_submission.id})
+
     def batch_begin_packet(self):
         self._batch += 1
         log.debug('Enter batch number %d: %d', self._batch, self.judge.current_submission.id)
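Like `grading_end_packet`, `pretest_end_packet` flushes the queued test-case packets first, so no pretest result can trail in after `pretest-end`. For a pretested submission the site would therefore see roughly this packet order (inferred from this diff; not an authoritative protocol spec):

```python
# Names only; payloads elided. Inferred ordering, not a protocol spec.
expected_order = [
    'grading-begin',     # carries the is_pretested flag
    'pretest-begin',
    'test-case-status',  # covers the pretest cases (may arrive in bulk)
    'pretest-end',       # preceded by a test-case queue flush
    'test-case-status',  # covers the main cases (absent in pretests-only runs)
    'grading-end',
]
```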
diff --git a/dmoj/problem.py b/dmoj/problem.py
index ef49eb721..4ac457329 100644
--- a/dmoj/problem.py
+++ b/dmoj/problem.py
@@ -5,6 +5,7 @@
 import zipfile
 from collections import defaultdict
 from functools import partial
+from typing import List
 
 import yaml
 from yaml.parser import ParserError
@@ -218,29 +219,23 @@ def __init__(self, problem_data, meta={}):
         )
 
 
-class BatchedTestCase:
-    def __init__(self, batch_no, config, problem, cases):
+class TopLevelCase:
+    def __init__(self, config):
         self.config = config
-        self.batch_no = batch_no
         self.points = config.points
-        self.dependencies = config.dependencies
-        self.batched_cases = cases
-        if any(isinstance(case, BatchedTestCase) for case in self.batched_cases):
-            raise InvalidInitException('nested batches')
-        self.problem = problem
-        if any(dependency >= batch_no for dependency in self.dependencies):
-            raise InvalidInitException('dependencies depends on non-earlier batch')
-        if any(dependency < 1 for dependency in self.dependencies):
-            raise InvalidInitException('dependencies must be positive integers')
+        if config.depends is not None:
+            if any(not isinstance(dependency, ConfigNode) for dependency in config.depends):
+                raise InvalidInitException('dependencies should use YAML references')
 
-    def __str__(self):
-        return 'BatchedTestCase{cases=%s}' % str(self.batched_cases)
+            self.dependencies = {dependency.raw_config_id for dependency in config.depends}
+        else:
+            self.dependencies = None
 
 
-class TestCase:
-    def __init__(self, count, batch_no, config, problem):
-        self.position = count
-        self.batch = batch_no
+class AbstractTestCase:
+    def __init__(self, position, batch, config, problem):
+        self.position = position
+        self.batch = batch
         self.config = config
         self.problem = problem
         self.points = config.points
@@ -396,3 +391,25 @@ def __getstate__(self):
 
     def __setstate__(self, state):
         self.__dict__.update(state)
+
+
+class BatchedTestCase(AbstractTestCase):
+    def __init__(self, position, batch, config, problem):
+        super().__init__(position, batch, config, problem)
+
+
+class StandaloneTestCase(AbstractTestCase, TopLevelCase):
+    def __init__(self, position, config, problem):
+        AbstractTestCase.__init__(self, position, None, config, problem)
+        TopLevelCase.__init__(self, config)
+
+
+class Batch(TopLevelCase):
+    def __init__(self, batch_no, config, problem, cases: List[BatchedTestCase]):
+        super().__init__(config)
+        self.batch = batch_no
+        self.batched_cases = cases
+        self.problem = problem
+
+    def __str__(self):
+        return 'Batch{cases=%s}' % str(self.batched_cases)
diff --git a/dmoj/testsuite.py b/dmoj/testsuite.py
index cc30104df..d8f8e26eb 100644
--- a/dmoj/testsuite.py
+++ b/dmoj/testsuite.py
@@ -119,6 +119,12 @@ def begin_grading_packet(self, is_pretested):
     def grading_end_packet(self):
         pass
 
+    def pretest_begin_packet(self):
+        pass
+
+    def pretest_end_packet(self):
+        pass
+
     def batch_begin_packet(self):
         pass
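Finally, the class layout that replaces the old `BatchedTestCase`/`TestCase` pair separates "can appear at the top level" from "is actually runnable". `StandaloneTestCase` is both, which appears to be why it invokes each parent `__init__` explicitly instead of relying on cooperative `super()` across two different signatures. A sketch of the resulting relationships:

```python
from dmoj.problem import Batch, BatchedTestCase, StandaloneTestCase, TopLevelCase

# TopLevelCase       -- owns points and identity-based `dependencies`
# AbstractTestCase   -- runnable: carries position, batch, config, problem
# Batch              -- top-level container whose children are BatchedTestCases
# BatchedTestCase    -- runnable only inside a Batch
# StandaloneTestCase -- runnable *and* top-level
assert issubclass(StandaloneTestCase, TopLevelCase)
assert issubclass(Batch, TopLevelCase)
assert not issubclass(BatchedTestCase, TopLevelCase)  # cannot declare `depends`
```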