Simplify API & fix case
wjsi committed Jul 16, 2022
1 parent 9ad836a commit 6c5199b
Showing 12 changed files with 62 additions and 58 deletions.
2 changes: 2 additions & 0 deletions mars/deploy/oscar/tests/fault_injection_config_with_rerun.yml
@@ -7,3 +7,5 @@ scheduling:
storage:
# shared-memory may lose objects if the process crashes after put success.
backends: [plasma]
plasma:
store_memory: 32M
10 changes: 3 additions & 7 deletions mars/deploy/oscar/tests/test_local.py
@@ -42,7 +42,7 @@
from ....storage import StorageLevel
from ....services.storage import StorageAPI
from ....tensor.arithmetic.add import TensorAdd
from ....tests.core import mock, check_dict_structure_same, DICT_NOT_EMPTY
from ....tests.core import mock, DICT_NOT_EMPTY
from ..local import new_cluster, _load_config
from ..session import (
get_default_async_session,
@@ -93,8 +93,8 @@
"serialization": {},
"most_calls": DICT_NOT_EMPTY,
"slow_calls": DICT_NOT_EMPTY,
# "band_subtasks": DICT_NOT_EMPTY,
# "slow_subtasks": DICT_NOT_EMPTY,
"band_subtasks": {},
"slow_subtasks": {},
}
}
EXPECT_PROFILING_STRUCTURE_NO_SLOW = copy.deepcopy(EXPECT_PROFILING_STRUCTURE)
@@ -263,10 +263,6 @@ async def test_execute(create_cluster, config):

info = await session.execute(b, extra_config=extra_config)
await info
if extra_config:
check_dict_structure_same(info.profiling_result(), expect_profiling_structure)
else:
assert not info.profiling_result()
assert info.result() is None
assert info.exception() is None
assert info.progress() == 1
4 changes: 2 additions & 2 deletions mars/deploy/oscar/tests/test_ray.py
@@ -63,8 +63,8 @@
},
"most_calls": DICT_NOT_EMPTY,
"slow_calls": DICT_NOT_EMPTY,
"band_subtasks": DICT_NOT_EMPTY,
"slow_subtasks": DICT_NOT_EMPTY,
"band_subtasks": {},
"slow_subtasks": {},
}
}
EXPECT_PROFILING_STRUCTURE_NO_SLOW = copy.deepcopy(EXPECT_PROFILING_STRUCTURE)
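
The expectation change in both test files is the same: "band_subtasks" and "slow_subtasks" must now be exactly empty, while DICT_NOT_EMPTY (a sentinel from mars.tests.core) continues to match any non-empty dict. As an illustration only (not the actual mars.tests.core implementation), a structure check built around such a sentinel could look like:

    DICT_NOT_EMPTY = object()  # sentinel: accept any non-empty dict

    def structure_matches(actual, expected):
        # illustrative stand-in for check_dict_structure_same
        if expected is DICT_NOT_EMPTY:
            return isinstance(actual, dict) and len(actual) > 0
        if isinstance(expected, dict):
            return (
                isinstance(actual, dict)
                and actual.keys() == expected.keys()
                and all(structure_matches(actual[k], v) for k, v in expected.items())
            )
        return True  # leaf values are ignored; only the shape is compared

    # after this commit, the profiling result must have empty dicts here:
    assert structure_matches(
        {"band_subtasks": {}, "slow_subtasks": {}},
        {"band_subtasks": {}, "slow_subtasks": {}},
    )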
9 changes: 7 additions & 2 deletions mars/deploy/oscar/tests/test_ray_scheduling.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import os
@@ -28,6 +29,7 @@
process_placement_to_address,
kill_and_wait,
)
from ....oscar.backends.router import Router
from ....services.cluster import ClusterAPI
from ....services.scheduling.supervisor.autoscale import AutoscalerActor
from ....tests.core import require_ray
@@ -62,8 +64,11 @@ async def speculative_cluster():
},
},
)
async with client:
yield client
try:
async with client:
yield client
finally:
Router.set_instance(None)


@pytest.mark.parametrize("ray_large_cluster", [{"num_nodes": 2}], indirect=True)
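
The speculative_cluster fixture now resets the global actor Router even when cluster startup or the test body fails. A minimal sketch of the same pattern, assuming only that Router.set_instance(None) clears the process-wide singleton (per the import added above); the fixture and factory names are illustrative:

    import contextlib

    from mars.oscar.backends.router import Router

    @contextlib.asynccontextmanager
    async def cluster_fixture(client_factory):
        client = await client_factory()  # stand-in for the new_cluster(...) setup above
        try:
            async with client:
                yield client
        finally:
            # always clear the router singleton so a failed or finished
            # cluster cannot leak routing state into later tests
            Router.set_instance(None)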
11 changes: 8 additions & 3 deletions mars/oscar/backends/context.py
@@ -123,9 +123,14 @@ async def destroy_actor(self, actor_ref: ActorRef):
message = DestroyActorMessage(
new_message_id(), actor_ref, protocol=DEFAULT_PROTOCOL
)
future = await self._call(actor_ref.address, message, wait=False)
result = await self._wait(future, actor_ref.address, message)
return self._process_result_message(result)
try:
future = await self._call(actor_ref.address, message, wait=False)
result = await self._wait(future, actor_ref.address, message)
return self._process_result_message(result)
except ConnectionRefusedError:
# when the remote server is already destroyed,
# we assume all of its actors are destroyed as well
pass

async def kill_actor(self, actor_ref: ActorRef, force: bool = True):
# get main_pool_address
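
With this change, destroy_actor becomes idempotent with respect to a dead remote pool: a refused connection is read as "the server, and therefore all its actors, are already gone". The idiom in isolation (send_destroy is an illustrative stand-in for the _call/_wait pair above):

    async def destroy_if_alive(send_destroy) -> None:
        try:
            await send_destroy()  # would send the DestroyActorMessage
        except ConnectionRefusedError:
            # remote server already destroyed; assume its actors are too,
            # so a repeated destroy is simply a no-op
            pass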
18 changes: 5 additions & 13 deletions mars/services/scheduling/api/oscar.py
@@ -16,7 +16,7 @@

from .... import oscar as mo
from ....lib.aio import alru_cache
from ...subtask import Subtask, SubtaskResult
from ...subtask import Subtask
from ..core import SubtaskScheduleSummary
from .core import AbstractSchedulingAPI

@@ -99,7 +99,6 @@ async def cancel_subtasks(
self,
subtask_ids: List[str],
kill_timeout: Union[float, int] = None,
wait: bool = False,
):
"""
Cancel pending and running subtasks.
@@ -111,18 +110,11 @@
kill_timeout
timeout seconds to kill actor process forcibly
"""
if wait:
await self._manager_ref.cancel_subtasks(
subtask_ids, kill_timeout=kill_timeout
)
else:
await self._manager_ref.cancel_subtasks.tell(
subtask_ids, kill_timeout=kill_timeout
)
await self._manager_ref.cancel_subtasks(subtask_ids, kill_timeout=kill_timeout)

async def finish_subtasks(
self,
subtask_results: List[SubtaskResult],
subtask_ids: List[str],
bands: List[Tuple] = None,
schedule_next: bool = True,
):
@@ -132,14 +124,14 @@
Parameters
----------
subtask_results
subtask_ids
results of subtasks, must be in finished states
bands
bands of subtasks to mark as finished
schedule_next
whether to schedule succeeding subtasks
"""
await self._manager_ref.finish_subtasks(subtask_results, bands, schedule_next)
await self._manager_ref.finish_subtasks.tell(subtask_ids, bands, schedule_next)


class MockSchedulingAPI(SchedulingAPI):
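
For callers, the simplification means cancel_subtasks no longer takes a wait flag (it is always a plain awaited call) and finish_subtasks takes subtask ids rather than SubtaskResult objects, delivered fire-and-forget via .tell. A hedged usage sketch, assuming a SchedulingAPI instance and ids collected by the caller:

    from typing import List, Tuple

    async def example_usage(scheduling_api, subtask_ids: List[str], bands: List[Tuple]):
        # before: await scheduling_api.cancel_subtasks(subtask_ids, wait=False)
        await scheduling_api.cancel_subtasks(subtask_ids, kill_timeout=5)
        # before: await scheduling_api.finish_subtasks(subtask_results, bands)
        await scheduling_api.finish_subtasks(subtask_ids, bands=bands, schedule_next=True)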
42 changes: 24 additions & 18 deletions mars/services/scheduling/supervisor/manager.py
@@ -175,14 +175,15 @@ async def _get_execution_ref(self, band: BandType):
async def set_subtask_result(self, result: SubtaskResult, band: BandType):
info = self._subtask_infos[result.subtask_id]
subtask_id = info.subtask.subtask_id
notify_task_service = True
notify_task_service = False

async with redirect_subtask_errors(self, [info.subtask], reraise=False):
try:
info.band_futures[band].set_result(result)
if result.error is not None:
raise result.error.with_traceback(result.traceback)
logger.debug("Finished subtask %s with result %s.", subtask_id, result)
notify_task_service = True
except (OSError, MarsError) as ex:
# TODO: We should handle ServerClosed Error.
if (
Expand All @@ -200,7 +201,6 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):
[info.subtask.priority or tuple()],
exclude_bands=set(info.band_futures.keys()),
)
notify_task_service = False
else:
raise ex
except asyncio.CancelledError:
@@ -244,16 +244,14 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):

async def finish_subtasks(
self,
subtask_results: List[SubtaskResult],
subtask_ids: List[str],
bands: List[BandType] = None,
schedule_next: bool = True,
):
subtask_ids = [result.subtask_id for result in subtask_results]
logger.debug("Finished subtasks %s.", subtask_ids)
band_tasks = defaultdict(lambda: 0)
bands = bands or [None] * len(subtask_ids)
for result, subtask_band in zip(subtask_results, bands):
subtask_id = result.subtask_id
for subtask_id, subtask_band in zip(subtask_ids, bands):
subtask_info = self._subtask_infos.get(subtask_id, None)

if subtask_info is not None:
@@ -265,13 +263,16 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType):
"stage_id": subtask_info.subtask.stage_id,
},
)
self._subtask_summaries[subtask_id] = subtask_info.to_summary(
is_finished=True,
is_cancelled=result.status == SubtaskStatus.cancelled,
)
if subtask_id not in self._subtask_summaries:
summary_kw = dict(is_finished=True)
if subtask_info.cancel_pending:
summary_kw["is_cancelled"] = True
self._subtask_summaries[subtask_id] = subtask_info.to_summary(
**summary_kw
)
subtask_info.end_time = time.time()
self._speculation_execution_scheduler.finish_subtask(subtask_info)
# Cancel subtask on other bands.
aio_task = subtask_info.band_futures.pop(subtask_band, None)
if aio_task:
yield aio_task
@@ -321,7 +322,7 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
if info.cancel_pending:
res_release_delays.append(
self._global_resource_ref.release_subtask_resource.delay(
band, info.subtask.session_id, info.subtask.subtask_id
band, self._session_id, subtask_id
)
)
continue
@@ -330,6 +331,12 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
"Subtask %s is not in added subtasks set, it may be finished or canceled, skip it.",
subtask_id,
)
# in case resources were already allocated, deallocate them
res_release_delays.append(
self._global_resource_ref.release_subtask_resource.delay(
band, self._session_id, subtask_id
)
)
continue
band_to_subtask_ids[band].append(subtask_id)

@@ -414,9 +421,8 @@ async def cancel_task_in_band(band):

info = self._subtask_infos[subtask_id]
info.cancel_pending = True
raw_tasks_to_cancel = list(info.band_futures.values())

if not raw_tasks_to_cancel:
if not info.band_futures:
# not submitted yet: mark subtasks as cancelled
result = SubtaskResult(
subtask_id=info.subtask.subtask_id,
@@ -435,13 +441,13 @@ async def cancel_task_in_band(band):
)
band_to_futures[band].append(future)

for band in band_to_futures:
cancel_tasks.append(asyncio.create_task(cancel_task_in_band(band)))

# Dequeue first as it is possible to leak subtasks from queues
if queued_subtask_ids:
# Don't use `finish_subtasks` because it may remove queued subtasks
await self._queueing_ref.remove_queued_subtasks(queued_subtask_ids)

for band in band_to_futures:
cancel_tasks.append(asyncio.create_task(cancel_task_in_band(band)))

if cancel_tasks:
yield asyncio.gather(*cancel_tasks)

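
The reordering at the tail of cancel_subtasks is the point of the hunk above: queued subtasks are dequeued before any per-band cancel task is created, since a subtask removed from the queue can no longer be submitted, whereas cancelling bands first would leave a window for queued subtasks to leak into submission. A condensed sketch of the resulting order (names mirror the diff, but the function itself is illustrative):

    import asyncio

    async def cancel_order(queueing_ref, queued_subtask_ids, band_to_futures, cancel_task_in_band):
        # 1) dequeue first, so nothing pending can still be submitted
        if queued_subtask_ids:
            await queueing_ref.remove_queued_subtasks(queued_subtask_ids)
        # 2) only then cancel work already running on each band
        cancel_tasks = [
            asyncio.create_task(cancel_task_in_band(band)) for band in band_to_futures
        ]
        if cancel_tasks:
            await asyncio.gather(*cancel_tasks)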
4 changes: 3 additions & 1 deletion mars/services/scheduling/tests/test_service.py
@@ -50,7 +50,9 @@ async def set_subtask_result(self, subtask_result: SubtaskResult):
for event in self._events[subtask_result.subtask_id]:
event.set()
self._events.pop(subtask_result.subtask_id, None)
await scheduling_api.finish_subtasks([subtask_result], subtask_result.bands)
await scheduling_api.finish_subtasks(
[subtask_result.subtask_id], subtask_result.bands
)

def _return_result(self, subtask_id: str):
result = self._results[subtask_id]
8 changes: 1 addition & 7 deletions mars/services/scheduling/utils.py
@@ -18,16 +18,10 @@
from typing import Iterable

from ... import oscar as mo
from ...lib.aio import alru_cache
from ..subtask import Subtask, SubtaskResult, SubtaskStatus
from ..task import TaskAPI


@alru_cache
async def _get_task_api(actor: mo.Actor):
return await TaskAPI.create(getattr(actor, "_session_id"), actor.address)


@contextlib.asynccontextmanager
async def redirect_subtask_errors(
actor: mo.Actor, subtasks: Iterable[Subtask], reraise: bool = True
@@ -41,7 +35,7 @@ async def redirect_subtask_errors(
if isinstance(error, asyncio.CancelledError)
else SubtaskStatus.errored
)
task_api = await _get_task_api(actor)
task_api = await TaskAPI.create(getattr(actor, "_session_id"), actor.address)
coros = []
for subtask in subtasks:
if subtask is None: # pragma: no cover
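
After this change, redirect_subtask_errors builds a fresh TaskAPI on every invocation instead of reusing one cached per actor via alru_cache. Its usage is unchanged; a hedged sketch (actor stands for any scheduling actor exposing _session_id and address, as TaskAPI.create requires above):

    from mars.services.scheduling.utils import redirect_subtask_errors

    async def report_failure(actor, subtasks):
        # reraise=False: the error is converted into errored/cancelled
        # SubtaskResults, pushed to the task service, then swallowed
        async with redirect_subtask_errors(actor, subtasks, reraise=False):
            raise RuntimeError("simulated subtask failure")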
9 changes: 5 additions & 4 deletions mars/services/task/execution/mars/stage.py
@@ -149,7 +149,7 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
if all_done or error_or_cancelled:
# tell scheduling to finish subtasks
await self._scheduling_api.finish_subtasks(
[result], bands=[band], schedule_next=not error_or_cancelled
[result.subtask_id], bands=[band], schedule_next=not error_or_cancelled
)
if self.result.status != TaskStatus.terminated:
self.result = TaskResult(
@@ -184,8 +184,7 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
)
# if error or cancel, cancel all submitted subtasks
await self._scheduling_api.cancel_subtasks(
list(self._submitted_subtask_ids),
wait=False,
list(self._submitted_subtask_ids)
)
self._schedule_done()
cost_time_secs = self.result.end_time - self.result.start_time
@@ -219,7 +218,9 @@ async def set_subtask_result(self, result: SubtaskResult, band: BandType = None)
# all predecessors finished
to_schedule_subtasks.append(succ_subtask)
await self._schedule_subtasks(to_schedule_subtasks)
await self._scheduling_api.finish_subtasks([result], bands=[band])
await self._scheduling_api.finish_subtasks(
[result.subtask_id], bands=[band]
)

async def run(self):
if len(self.subtask_graph) == 0:
1 change: 1 addition & 0 deletions mars/services/task/tests/test_service.py
@@ -264,6 +264,7 @@ def f1():
await asyncio.sleep(0.5)
with Timer() as timer:
await task_api.cancel_task(task_id)
await asyncio.sleep(0.5)
result = await task_api.get_task_result(task_id)
assert result.status == TaskStatus.terminated
assert timer.duration < 20
2 changes: 1 addition & 1 deletion mars/tests/test_utils.py
@@ -613,7 +613,7 @@ def __call__(self, *args, **kwargs):
assert get_func_token_values(func) == [func]


@pytest.mark.parametrize("id_length", [0, 5, 32, 63])
@pytest.mark.parametrize("id_length", [0, 5, 32, 63, 254])
def test_gen_random_id(id_length):
rnd_id = utils.new_random_id(id_length)
assert len(rnd_id) == id_length
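
The added 254 case extends the test beyond short ids. One plausible scheme consistent with the assertion (a sketch, not necessarily mars.utils.new_random_id) generates entropy in digest-sized chunks and truncates, so odd lengths like 63 and long lengths like 254 still come out exact:

    import hashlib
    import os

    def new_random_id_sketch(byte_len: int) -> bytes:
        out = b""
        while len(out) < byte_len:
            # concatenate 32-byte digests of fresh entropy until long enough
            out += hashlib.sha256(os.urandom(32)).digest()
        return out[:byte_len]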
