From 9ad836a4917b177f9e68d2fbaa7fd2c462c2f6d7 Mon Sep 17 00:00:00 2001 From: Wenjun Si Date: Wed, 8 Jun 2022 16:19:16 +0800 Subject: [PATCH] black --- .../services/scheduling/supervisor/manager.py | 32 ++++++++++++------- .../supervisor/tests/test_manager.py | 8 +++-- mars/services/scheduling/worker/execution.py | 4 ++- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/mars/services/scheduling/supervisor/manager.py b/mars/services/scheduling/supervisor/manager.py index 6ea910a8e2..d860ef96e9 100644 --- a/mars/services/scheduling/supervisor/manager.py +++ b/mars/services/scheduling/supervisor/manager.py @@ -172,9 +172,7 @@ async def _get_execution_ref(self, band: BandType): return await mo.actor_ref(SubtaskExecutionActor.default_uid(), address=band[0]) - async def set_subtask_result( - self, result: SubtaskResult, band: BandType - ): + async def set_subtask_result(self, result: SubtaskResult, band: BandType): info = self._subtask_infos[result.subtask_id] subtask_id = info.subtask.subtask_id notify_task_service = True @@ -346,6 +344,7 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list): async def _submit_subtasks_to_band(self, band: BandType, subtask_ids: List[str]): execution_ref = await self._get_execution_ref(band) delays = [] + task_stage_count = defaultdict(lambda: 0) async with redirect_subtask_errors( self, self._get_subtasks_by_ids(subtask_ids) @@ -353,21 +352,30 @@ async def _submit_subtasks_to_band(self, band: BandType, subtask_ids: List[str]) for subtask_id in subtask_ids: subtask_info = self._subtask_infos[subtask_id] subtask = subtask_info.subtask - self._submitted_subtask_count.record( - 1, - { - "session_id": self._session_id, - "task_id": subtask.task_id, - "stage_id": subtask.stage_id, - }, - ) - logger.debug("Start run subtask %s in band %s.", subtask_id, band) + task_stage_count[(subtask.task_id, subtask.stage_id)] += 1 delays.append( execution_ref.run_subtask.delay(subtask, band[1], self.address) ) subtask_info.band_futures[band] = asyncio.Future() subtask_info.start_time = time.time() self._speculation_execution_scheduler.add_subtask(subtask_info) + + for (task_id, stage_id), cnt in task_stage_count.items(): + self._submitted_subtask_count.record( + cnt, + { + "session_id": self._session_id, + "task_id": task_id, + "stage_id": stage_id, + }, + ) + + logger.debug( + "Start run %d subtasks %r in band %s.", + len(subtask_ids), + subtask_ids, + band, + ) await execution_ref.run_subtask.batch(*delays, send=False) async def cancel_subtasks( diff --git a/mars/services/scheduling/supervisor/tests/test_manager.py b/mars/services/scheduling/supervisor/tests/test_manager.py index afa7136cd0..465859f7af 100644 --- a/mars/services/scheduling/supervisor/tests/test_manager.py +++ b/mars/services/scheduling/supervisor/tests/test_manager.py @@ -109,12 +109,16 @@ async def task_fun(): result.status = SubtaskStatus.cancelled result.error = ex result.traceback = ex.__traceback__ - await manager_ref.set_subtask_result.tell(result, (self.address, band_name)) + await manager_ref.set_subtask_result.tell( + result, (self.address, band_name) + ) raise else: result.status = SubtaskStatus.succeeded result.execution_end_time = time.time() - await manager_ref.set_subtask_result.tell(result, (self.address, band_name)) + await manager_ref.set_subtask_result.tell( + result, (self.address, band_name) + ) self._subtask_aiotasks[subtask.subtask_id][band_name] = asyncio.create_task( task_fun() diff --git a/mars/services/scheduling/worker/execution.py b/mars/services/scheduling/worker/execution.py index e108c3e9e5..23f7e9b159 100644 --- a/mars/services/scheduling/worker/execution.py +++ b/mars/services/scheduling/worker/execution.py @@ -572,7 +572,9 @@ async def subtask_caller(): manager_ref = await self._get_manager_ref( subtask.session_id, supervisor_address ) - await manager_ref.set_subtask_result.tell(res, (self.address, band_name)) + await manager_ref.set_subtask_result.tell( + res, (self.address, band_name) + ) finally: self._subtask_info.pop(subtask_id, None) self._finished_subtask_count.record(1, {"band": self.address})