Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add utilities in monitor.py to monitor message-related informations #705

Open
wants to merge 2 commits into
base: llm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions federatedscope/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from collections import deque

from federatedscope.core.monitors.monitor import Monitor
from federatedscope.core.proto import gRPC_comm_manager_pb2, \
gRPC_comm_manager_pb2_grpc
from federatedscope.core.gRPC_server import gRPCComServeFunc
Expand Down Expand Up @@ -44,6 +45,7 @@ def get_neighbors(self, neighbor_id=None):
# Get all neighbors
return self.neighbors

@Monitor.efficiency_comp_message_send_time
def send(self, message):
# All the workers share one comm_queue
self.comm_queue.append(message)
Expand Down Expand Up @@ -105,7 +107,12 @@ class gRPCCommManager(object):
The implementation of gRPCCommManager is referred to the tutorial on
https://grpc.io/docs/languages/python/
"""
def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
def __init__(self,
host='0.0.0.0',
port='50050',
client_num=2,
cfg=None,
monitor=None):
self.host = host
self.port = port
options = [
Expand All @@ -128,7 +135,8 @@ def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
port=port,
options=options)
self.neighbors = dict()
self.monitor = None # used to track the communication related metrics
self.monitor = monitor
# used to track the communication related metrics

def serve(self, max_workers, host, port, options):
"""
Expand Down Expand Up @@ -169,6 +177,7 @@ def get_neighbors(self, neighbor_id=None):
# Get all neighbors
return self.neighbors

@Monitor.efficiency_comp_message_send_time
def _send(self, receiver_address, message):
def _create_stub(receiver_address):
"""
Expand Down
8 changes: 8 additions & 0 deletions federatedscope/core/configs/cfg_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def extend_evaluation_cfg(cfg):
cfg.wandb.online_track = True
cfg.wandb.client_train_info = False

# ---------------------------------------------------------------------- #
# efficiency related options # This works only for FS-LLM temporarily.
# ---------------------------------------------------------------------- #

cfg.eval.efficiency = CN()
cfg.eval.efficiency.use = False
cfg.eval.efficiency.freq = 1

# --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_evaluation_cfg)

Expand Down
55 changes: 55 additions & 0 deletions federatedscope/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from federatedscope.core.auxiliaries.utils import b64serializer
from federatedscope.core.proto import gRPC_comm_manager_pb2
from federatedscope.core.monitors.monitor import Monitor
from federatedscope.core.compression import (symmetric_uniform_quantization,
symmetric_uniform_dequantization)


class Message(object):
Expand Down Expand Up @@ -263,3 +266,55 @@ def count_bytes(self):
list) else 1
upload_bytes = download_bytes * upload_cnt
return download_bytes, upload_bytes

@Monitor.efficiency_comp_message_compression_time
def quantization(content,
role=None,
model_num=None,
msg_type=None,
flag=None,
method=None,
nbits=None,
monitor=None):
if role == 'server':
if (msg_type == 'model_para' and flag and method == 'uniform'):
if model_num > 1:
content = [
symmetric_uniform_quantization(x, nbits)
for x in content
]
else:
content = symmetric_uniform_quantization(content, nbits)
elif role == 'client':
if method == 'uniform':
if isinstance(content, list):
content = [
symmetric_uniform_quantization(x, nbits)
for x in content
]
else:
content = symmetric_uniform_quantization(content, nbits)
return content

@Monitor.efficiency_comp_message_compression_time
def dequantization(content, role=None, method=None, monitor=None):
if role == 'server':
if method == 'uniform':
if isinstance(content[1], list): # multiple model
sample_size = content[0]
quant_model = [
symmetric_uniform_dequantization(x) for x in content[1]
]
else:
sample_size = content[0]
quant_model = symmetric_uniform_dequantization(content[1])
content = (sample_size, quant_model)
elif role == 'client':
if method == 'uniform':
if isinstance(content, list): # multiple model
content = [
symmetric_uniform_dequantization(x) for x in content
]
else:
content = symmetric_uniform_dequantization(content)
return content
129 changes: 129 additions & 0 deletions federatedscope/core/monitors/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import gzip
import shutil
import datetime
import sys
import time
import psutil
from collections import defaultdict
from importlib import import_module

Expand Down Expand Up @@ -110,6 +113,132 @@ def __init__(self, cfg, monitored_object=None):
"cfg.wandb.use=True but not install the wandb package")
exit()

self.efficiency_memory = 0
self.efficiency_gpu = 0

self.efficiency_round_start_time = 0
self.efficiency_round_training_time = 0
self.efficiency_total_training_time = 0

self.efficiency_message_compression_time = 0
self.efficiency_message_send_time = 0
self.efficiency_total_message_compression_time = 0
self.efficiency_total_message_send_time = 0

def efficiency_compare(func):
"""
Decorate functions in trainer to get memory, gpu consumption
"""
def wrapper(*args, **kwargs):
if args[-1].monitor.cfg.eval.efficiency.use:
func(*args, **kwargs)
efficiency_memory = round(
psutil.Process(os.getpid()).memory_info().rss / 1024 /
1024 / 1024, 2)
efficiency_gpu = torch.cuda.max_memory_allocated(
0) / 1024 / 1024 / 1024

args[-1].monitor.efficiency_memory = max(
args[-1].monitor.efficiency_memory, efficiency_memory)
args[-1].monitor.efficiency_gpu = max(
args[-1].monitor.efficiency_gpu, efficiency_gpu)
else:
return func(*args, **kwargs)

return wrapper

def efficiency_training_start_time(func):
"""
Decorate the start function in trainer
to get the starting time for training
"""
def wrapper(*args, **kwargs):
if args[-1].monitor.cfg.eval.efficiency.use:
args[-1].monitor.efficiency_round_start_time = time.time()
return func(*args, **kwargs)

return wrapper

def efficiency_training_end_time(func):
"""
Decorate the end function in trainer to get the end time for training,
and get the total training time
"""
def wrapper(*args, **kwargs):
if args[-1].monitor.cfg.eval.efficiency.use:
res = func(*args, **kwargs)
args[-1].monitor.efficiency_round_training_time = time.time(
) - args[-1].monitor.efficiency_round_start_time
args[-1].monitor.efficiency_total_training_time += args[
-1].monitor.efficiency_round_training_time
return res
else:
return func(*args, **kwargs)

return wrapper

def efficiency_comp_message_compression_time(func):
"""
Decorate the message-compression functions in message.py
to get the time for message compression
"""
def wrapper(*args, **kwargs):
if kwargs['monitor'].cfg.eval.efficiency.use:
start = time.time()
res = func(*args, **kwargs)
compression_time = time.time() - start
if kwargs['monitor']:
kwargs[
'monitor'].efficiency_total_message_compression_time \
+= compression_time
return res
else:
return func(*args, **kwargs)

return wrapper

def efficiency_comp_message_send_time(func):
"""
Decorate message-sending functions in communication.py
to get the time for message sending
Note: in standalone mode,
we simulate the behavior for sening messages, i.e.,
we assume the bandwidth of the network to be 100Mib/s, thus
the sending time equals to sys.getsizeof(message) / 1024 / 100 S.
"""
def wrapper(*args, **kwargs):
if args[0].monitor.cfg.eval.efficiency.use:
if args[0].monitor.cfg.federate.mode == 'standalone':
args[0].monitor.efficiency_total_message_send_time +=\
sys.getsizeof(args[1]) / 1024 / 100
func(*args, **kwargs)
elif args[0].monitor.cfg.federate.mode == 'distributed':
start_time = time.time()
func(*args, **kwargs)
efficiency_message_send_time = time.time() - start_time
args[0].monitor.efficiency_total_message_send_time +=\
efficiency_message_send_time
else:
return func(*args, **kwargs)

return wrapper

def format_efficiency_result(self, rnd, role=-1):
res_dict = dict()
res_dict['Role'] = role
res_dict['Round'] = rnd
res_dict['Round_training_time'] = str(
self.efficiency_round_training_time) + ' S'
res_dict['Total_training_time'] = str(
self.efficiency_total_training_time) + ' S'
res_dict['Total_message_compression_time'] = str(
self.efficiency_total_message_compression_time) + ' S'
res_dict['Total_message_send_time'] = str(
self.efficiency_total_message_send_time) + ' S'
res_dict['Max memory usage'] = str(self.efficiency_memory) + ' GB'
res_dict['Max GPU usage'] = str(self.efficiency_gpu) + ' GB'
return res_dict

def eval(self, ctx):
"""
Evaluates the given context with ``metric_calculator``.
Expand Down
66 changes: 44 additions & 22 deletions federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
calculate_time_cost, add_prefix_to_path, get_ds_rank
from federatedscope.core.workers.base_client import BaseClient
from federatedscope.core.monitors.monitor import Monitor

logger = logging.getLogger(__name__)
if get_ds_rank() == 0:
Expand Down Expand Up @@ -175,7 +176,8 @@ def __init__(self,
host=host,
port=port,
client_num=self._cfg.federate.client_num,
cfg=self._cfg.distribute)
cfg=self._cfg.distribute,
monitor=self._monitor)
logger.info('Client: Listen to {}:{}...'.format(host, port))
self.comm_manager.add_neighbors(neighbor_id=server_id,
address={
Expand Down Expand Up @@ -303,16 +305,22 @@ def callback_funcs_for_model_para(self, message: Message):
timestamp = message.timestamp
content = message.content

content = Message.dequantization(
content=content,
role='client',
method=self._cfg.quantization.method,
monitor=self._monitor)

# dequantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_dequantization
if isinstance(content, list): # multiple model
content = [
symmetric_uniform_dequantization(x) for x in content
]
else:
content = symmetric_uniform_dequantization(content)
# if self._cfg.quantization.method == 'uniform':
# from federatedscope.core.compression import \
# symmetric_uniform_dequantization
# if isinstance(content, list): # multiple model
# content = [
# symmetric_uniform_dequantization(x) for x in content
# ]
# else:
# content = symmetric_uniform_dequantization(content)

# When clients share the local model, we must set strict=True to
# ensure all the model params (which might be updated by other
Expand Down Expand Up @@ -417,19 +425,25 @@ def callback_funcs_for_model_para(self, message: Message):
else:
shared_model_para = model_para_all

shared_model_para = Message.quantization(
content=shared_model_para,
role='client',
method=self._cfg.quantization.method,
nbits=self._cfg.quantization.nbits,
monitor=self._monitor)
# quantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_quantization
nbits = self._cfg.quantization.nbits
if isinstance(shared_model_para, list):
shared_model_para = [
symmetric_uniform_quantization(x, nbits)
for x in shared_model_para
]
else:
shared_model_para = symmetric_uniform_quantization(
shared_model_para, nbits)
# if self._cfg.quantization.method == 'uniform':
# from federatedscope.core.compression import \
# symmetric_uniform_quantization
# nbits = self._cfg.quantization.nbits
# if isinstance(shared_model_para, list):
# shared_model_para = [
# symmetric_uniform_quantization(x, nbits)
# for x in shared_model_para
# ]
# else:
# shared_model_para = symmetric_uniform_quantization(
# shared_model_para, nbits)

self.comm_manager.send(
Message(msg_type='model_para',
Expand All @@ -440,6 +454,14 @@ def callback_funcs_for_model_para(self, message: Message):
init_timestamp=timestamp,
instance_number=sample_size),
content=(sample_size, shared_model_para)))
if ((self._cfg.eval.efficiency.use
and self._cfg.eval.efficiency.freq > 0
and self.state % self._cfg.eval.efficiency.freq == 0) or
(self._cfg.eval.efficiency.use
and self.state == self._cfg.federate.total_round_num)):
logger.info(
self._monitor.format_efficiency_result(
rnd=self.state, role='Client #{}'.format(self.ID)))

def callback_funcs_for_assign_id(self, message: Message):
"""
Expand Down
Loading