Skip to content

Commit

Permalink
refactor: remove warmup (#6114)
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Fontanals Martinez <[email protected]>
Co-authored-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
JoanFM and JoanFM authored Nov 21, 2023
1 parent bbf21de commit 6cd3311
Show file tree
Hide file tree
Showing 9 changed files with 2 additions and 288 deletions.
65 changes: 0 additions & 65 deletions jina/serve/networking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,71 +615,6 @@ async def task_coroutine():

return task_coroutine()

async def warmup(
self,
deployment: str,
stop_event: 'threading.Event',
):
"""Executes JinaInfoRPC against the provided deployment. A single task is created for each replica connection.
:param deployment: deployment name and the replicas that needs to be warmed up.
:param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
"""
self._logger.debug(f'starting warmup task for deployment {deployment}')

async def task_wrapper(target_warmup_responses, stub):
try:
call_result = stub.send_info_rpc(timeout=0.5)
await call_result
target_warmup_responses[stub.address] = True
except asyncio.CancelledError:
self._logger.debug(f'warmup task got cancelled')
target_warmup_responses[stub.address] = False
raise
except Exception:
target_warmup_responses[stub.address] = False


try:
start_time = time.time()
timeout = start_time + 60 * 5 # 5 minutes from now
warmed_up_targets = set()
replicas = self._get_all_replicas(deployment)

while not stop_event.is_set():
replica_warmup_responses = {}
tasks = []
try:
for replica in replicas:
for stub in replica.warmup_stubs:
if stub.address not in warmed_up_targets:
tasks.append(
asyncio.create_task(
task_wrapper(replica_warmup_responses, stub)
)
)

await asyncio.gather(*tasks, return_exceptions=True)
for target, response in replica_warmup_responses.items():
if response:
warmed_up_targets.add(target)

now = time.time()
if now > timeout or all(list(replica_warmup_responses.values())):
self._logger.debug(
f'completed warmup task in {now - start_time}s.'
)
return
await asyncio.sleep(0.2)
except asyncio.CancelledError:
self._logger.debug(f'warmup task got cancelled')
if tasks:
for task in tasks:
task.cancel()
raise
except Exception as ex:
self._logger.error(f'error with warmup up task: {ex}')
return

def _get_all_replicas(self, deployment):
replica_set = set()
replica_set.update(self._connections.get_replicas_all_shards(deployment))
Expand Down
16 changes: 0 additions & 16 deletions jina/serve/networking/replica_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ def __init__(
self.tracing_client_interceptors = tracing_client_interceptor
self._deployment_name = deployment_name
self.channel_options = channel_options
# a set containing all the ConnectionStubs that will be created using add_connection
# this set is not updated in reset_connection and remove_connection
self._warmup_stubs = set()

async def reset_connection(self, address: str, deployment_name: str):
"""
Expand Down Expand Up @@ -90,10 +87,7 @@ def add_connection(self, address: str, deployment_name: str):
stubs, channel = self._create_connection(address, deployment_name)
self._address_to_channel[resolved_address] = channel
self._connections.append(stubs)
# create a new set of stubs and channels for warmup to avoid
# loosing channel during remove_connection or reset_connection
stubs, _ = self._create_connection(address, deployment_name)
self._warmup_stubs.add(stubs)

async def remove_connection(self, address: str):
"""
Expand Down Expand Up @@ -213,13 +207,3 @@ async def close(self):
self._address_to_connection_idx.clear()
self._connections.clear()
self._rr_counter = 0
for stub in self._warmup_stubs:
await stub.channel.close(0.5)
self._warmup_stubs.clear()

@property
def warmup_stubs(self):
"""Return set of warmup stubs
:returns: Set of stubs. The set doesn't remove any items once added.
"""
return self._warmup_stubs
26 changes: 0 additions & 26 deletions jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(
meter=None,
aio_tracing_client_interceptors=None,
tracing_client_interceptor=None,
works_as_load_balancer: bool = False,
**kwargs,
):
import json
Expand Down Expand Up @@ -102,37 +101,12 @@ def __init__(
if isinstance(addresses, Dict):
servers.extend(addresses.get(ProtocolType.HTTP.to_string(), []))
self.load_balancer_servers = itertools.cycle(servers)
self.warmup_stop_event = threading.Event()
self.warmup_task = None
if not works_as_load_balancer:
try:
self.warmup_task = asyncio.create_task(
self.streamer.warmup(self.warmup_stop_event)
)
except RuntimeError:
# when Gateway is started locally, it may not have loop
pass

def cancel_warmup_task(self):
"""Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the
task is successful or hasn't reached the max timeout.
"""
if self.warmup_task:
try:
if not self.warmup_task.done():
self.logger.debug(f'Cancelling warmup task.')
self.warmup_stop_event.set() # this event is useless if simply cancel
self.warmup_task.cancel()
except Exception as ex:
self.logger.debug(f'exception during warmup task cancellation: {ex}')
pass

async def close(self):
"""
Gratefully closes the object making sure all the floating requests are taken care and the connections are closed gracefully
"""
self.logger.debug(f'Closing Request Handler')
self.cancel_warmup_task()
await self.streamer.close()
self.logger.debug(f'Request Handler closed')

Expand Down
32 changes: 0 additions & 32 deletions jina/serve/runtimes/gateway/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,38 +427,6 @@ def get_streamer():
def _set_env_streamer_args(**kwargs):
os.environ['JINA_STREAMER_ARGS'] = json.dumps(kwargs)

async def warmup(self, stop_event: threading.Event):
"""Executes warmup task on each deployment. This forces the gateway to establish connection and open a
gRPC channel to each executor so that the first request doesn't need to experience the penalty of
eastablishing a brand new gRPC channel.
:param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
"""
self.logger.debug(f'Running GatewayRuntime warmup')
deployments = {key for key in self._executor_addresses.keys()}

try:
deployment_warmup_tasks = []
try:
for deployment in deployments:
deployment_warmup_tasks.append(
asyncio.create_task(
self._connection_pool.warmup(
deployment=deployment, stop_event=stop_event
)
)
)

await asyncio.gather(*deployment_warmup_tasks, return_exceptions=True)
except asyncio.CancelledError:
self.logger.debug(f'Warmup task got cancelled')
if deployment_warmup_tasks:
for task in deployment_warmup_tasks:
task.cancel()
raise
except Exception as ex:
self.logger.error(f'error with GatewayRuntime warmup up task: {ex}')
return


class _ExecutorStreamer:
def __init__(self, connection_pool: GrpcConnectionPool, executor_name: str) -> None:
Expand Down
42 changes: 0 additions & 42 deletions jina/serve/runtimes/head/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,6 @@ def __init__(
self._executor_endpoint_mapping = None
self._gathering_endpoints = False
self.runtime_name = runtime_name
self.warmup_stop_event = threading.Event()
self.warmup_task = asyncio.create_task(
self.warmup(
connection_pool=self.connection_pool,
stop_event=self.warmup_stop_event,
deployment=self._deployment_name,
)
)
self._pydantic_models_by_endpoint = None
self.endpoints_discovery_stop_event = threading.Event()
self.endpoints_discovery_task = None
Expand Down Expand Up @@ -383,39 +375,6 @@ async def task():

return task()

async def warmup(
self,
connection_pool: GrpcConnectionPool,
stop_event: 'threading.Event',
deployment: str,
):
"""Executes warmup task against the deployments from the connection pool.
:param connection_pool: GrpcConnectionPool that implements the warmup to the connected deployments.
:param stop_event: signal to indicate if an early termination of the task is required for graceful teardown.
:param deployment: deployment name that need to be warmed up.
"""
self.logger.debug(f'Running HeadRuntime warmup')

try:
await connection_pool.warmup(deployment=deployment, stop_event=stop_event)
except Exception as ex:
self.logger.error(f'error with HeadRuntime warmup up task: {ex}')
return

def cancel_warmup_task(self):
"""Cancel warmup task if exists and is not completed. Cancellation is required if the Flow is being terminated before the
task is successful or hasn't reached the max timeout.
"""
if self.warmup_task:
try:
if not self.warmup_task.done():
self.logger.debug(f'Cancelling warmup task.')
self.warmup_stop_event.set() # this event is useless if simply cancel
self.warmup_task.cancel()
except Exception as ex:
self.logger.debug(f'exception during warmup task cancellation: {ex}')
pass

def cancel_endpoint_discovery_from_workers_task(self):
"""Cancel endpoint_discovery_from_worker task if exists and is not completed. Cancellation is required if the Flow is being terminated before the
task is successful or hasn't reached the max timeout.
Expand All @@ -433,7 +392,6 @@ def cancel_endpoint_discovery_from_workers_task(self):
async def close(self):
"""Close the data request handler, by closing the executor and the batch queues."""
self.logger.debug(f'Closing Request Handler')
self.cancel_warmup_task()
self.cancel_endpoint_discovery_from_workers_task()
await self.connection_pool.close()
self.logger.debug(f'Request Handler closed')
Expand Down
1 change: 0 additions & 1 deletion jina/serve/runtimes/servers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def _get_request_handler(self):
aio_tracing_client_interceptors=self.aio_tracing_client_interceptors(),
tracing_client_interceptor=self.tracing_client_interceptor(),
deployment_name=self.name.split('/')[0],
works_as_load_balancer=self.works_as_load_balancer,
)

def _add_gateway_args(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def test_multiprotocol_gateway_instrumentation(
(server_spans, client_spans, executor_spans) = partition_spans_by_kind(
gateway_traces
)
assert len(client_spans) == 11
assert len(server_spans) == 12
assert len(client_spans) == 9
assert len(server_spans) == 10


def test_executor_instrumentation(jaeger_port, otlp_collector, otlp_receiver_port):
Expand Down
100 changes: 0 additions & 100 deletions tests/integration/runtimes/test_warmup.py

This file was deleted.

4 changes: 0 additions & 4 deletions tests/unit/serve/networking/test_replica_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_add_connection(replica_list):
replica_list.add_connection('executor0', 'executor-0')
assert replica_list.has_connections()
assert replica_list.has_connection('executor0')
assert len(replica_list.warmup_stubs)
assert not replica_list.has_connection('random-address')
assert len(replica_list.get_all_connections()) == 1

Expand All @@ -34,8 +33,6 @@ async def test_remove_connection(replica_list):
await replica_list.remove_connection('executor0')
assert not replica_list.has_connections()
assert not replica_list.has_connection('executor0')
# warmup stubs are not updated in the remove_connection method
assert len(replica_list.warmup_stubs)
# unknown/unmanaged connections
removed_connection_invalid = await replica_list.remove_connection('random-address')
assert removed_connection_invalid is None
Expand Down Expand Up @@ -64,7 +61,6 @@ async def test_close(replica_list):
assert replica_list.has_connection('executor1')
await replica_list.close()
assert not replica_list.has_connections()
assert not len(replica_list.warmup_stubs)


async def _print_channel_attributes(connection_stub: _ConnectionStubs):
Expand Down

0 comments on commit 6cd3311

Please sign in to comment.