From 70cb07b85fc8379043f9cd8c6c9e9c312ca37e16 Mon Sep 17 00:00:00 2001 From: "gaodawei.gdw" Date: Wed, 21 Sep 2022 16:43:17 +0800 Subject: [PATCH 1/7] add MSGBUFFER in enums.py; modify function `check_and_move_on` to support new messages --- federatedscope/core/auxiliaries/enums.py | 6 ++ federatedscope/core/workers/server.py | 108 ++++++++++++----------- 2 files changed, 61 insertions(+), 53 deletions(-) diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/auxiliaries/enums.py index ffaf0808a..7f6d1e805 100644 --- a/federatedscope/core/auxiliaries/enums.py +++ b/federatedscope/core/auxiliaries/enums.py @@ -11,6 +11,12 @@ class MODE: FINETUNE = 'finetune' +class MSGBUFFER: + TRAIN = 'train' + EVAL = 'eval' + CONSULT = 'consult' + + class TRIGGER: ON_FIT_START = 'on_fit_start' ON_EPOCH_START = 'on_epoch_start' diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 38ceb6435..696c73f02 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -16,6 +16,7 @@ from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ merge_param_dict from federatedscope.core.auxiliaries.trainer_builder import get_trainer +from federatedscope.core.auxiliaries.enums import MSGBUFFER from federatedscope.core.secret_sharing import AdditiveSecretSharing logger = logging.getLogger(__name__) @@ -251,9 +252,10 @@ def run(self): logger.info('Time out at the training round #{}'.format( self.state)) move_on_flag_eval = self.check_and_move_on( - min_received_num=min_received_num, - check_eval_result=True) + buffer_type=MSGBUFFER.EVAL, + min_received_num=min_received_num) move_on_flag = self.check_and_move_on( + buffer_type=MSGBUFFER.TRAIN, min_received_num=min_received_num) if not move_on_flag and not move_on_flag_eval: num_failure += 1 @@ -271,8 +273,8 @@ def run(self): f'Round #{self.state}) for {num_failure} time ' f'-------------') # TODO: Clean the msg_buffer - if self.state in self.msg_buffer['train']: - self.msg_buffer['train'][self.state].clear() + if self.state in self.msg_buffer[MSGBUFFER.TRAIN]: + self.msg_buffer[MSGBUFFER.TRAIN][self.state].clear() self.broadcast_model_para( msg_type='model_para', @@ -284,7 +286,7 @@ def run(self): self.terminate(msg_type='finish') def check_and_move_on(self, - check_eval_result=False, + buffer_type, min_received_num=None): """ To check the message_buffer. When enough messages are receiving, @@ -292,8 +294,8 @@ def check_and_move_on(self, the next training round) would be triggered. Arguments: - check_eval_result (bool): If True, check the message buffer for - evaluation; and check the message buffer for training otherwise. + buffer_type (str): The type of checked message buffer, chosen from MSGBUFFER.TRAIN, MSGBUFFER.EVAL and MSGBUFFER.CONSULT + min_received_num (int): The minimum number of received messages """ if min_received_num is None: if self._cfg.asyn.use: @@ -302,7 +304,7 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if check_eval_result and self._cfg.federate.mode.lower( + if buffer_type == MSGBUFFER.EVAL and self._cfg.federate.mode.lower( ) == "standalone": # in evaluation stage and standalone simulation mode, we assume # strong synchronization that receives responses from all clients @@ -310,8 +312,8 @@ def check_and_move_on(self, move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, check_eval_result): - if not check_eval_result: + if self.check_buffer(self.state, min_received_num, buffer_type): + if buffer_type == MSGBUFFER.TRAIN: # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() @@ -329,8 +331,8 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() - self.msg_buffer['train'][self.state] = dict() + self.msg_buffer[MSGBUFFER.TRAIN][self.state - 1].clear() + self.msg_buffer[MSGBUFFER.TRAIN][self.state] = dict() self.staled_msg_buffer.clear() # Start a new training round self._start_new_training_round(aggregated_num) @@ -393,8 +395,8 @@ def check_and_save(self): # Clean the clients evaluation msg buffer if not self._cfg.federate.make_global_eval: - round = max(self.msg_buffer['eval'].keys()) - self.msg_buffer['eval'][round].clear() + round = max(self.msg_buffer[MSGBUFFER.EVAL].keys()) + self.msg_buffer[MSGBUFFER.EVAL][round].clear() if self.state == self.total_round_num: # break out the loop for distributed mode @@ -404,7 +406,7 @@ def _perform_federated_aggregation(self): """ Perform federated aggregation and update the global model """ - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[MSGBUFFER.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -523,8 +525,8 @@ def save_client_eval_results(self): :return: """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + round = max(self.msg_buffer[MSGBUFFER.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[MSGBUFFER.EVAL][round] with open(os.path.join(self._cfg.outdir, "eval_results.log"), "a") as outfile: @@ -545,8 +547,8 @@ def merge_eval_results_from_all_clients(self): :returns: the formatted merged results """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + round = max(self.msg_buffer[MSGBUFFER.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[MSGBUFFER.EVAL][round] eval_res_participated_clients = [] eval_res_unseen_clients = [] for client_id in eval_msg_buffer: @@ -688,33 +690,21 @@ def broadcast_client_address(self): def check_buffer(self, cur_round, min_received_num, - check_eval_result=False): - """ - To check the message buffer + buffer_type): + """Check if the message buffer receives enough messages Arguments: - cur_round (int): The current round number - min_received_num (int): The minimal number of the receiving messages - check_eval_result (bool): To check training results for evaluation - results - :returns: Whether enough messages have been received or not - :rtype: bool - """ - - if check_eval_result: - if 'eval' not in self.msg_buffer.keys() or len( - self.msg_buffer['eval'].keys()) == 0: - return False - - buffer = self.msg_buffer['eval'] - cur_round = max(buffer.keys()) - cur_buffer = buffer[cur_round] - return len(cur_buffer) >= min_received_num - else: - if cur_round not in self.msg_buffer['train']: - cur_buffer = dict() - else: - cur_buffer = self.msg_buffer['train'][cur_round] + cur_round (int): The current round number + min_received_num (int): The minimal number of the receiving messages + buffer_type (str): Which field to check, chosen from MSGBUFFER.TRAIN, MSGBUFFER.EVAL and MSGBUFFER.CONSULT + + Return: + Whether enough messages have been received or not + """ + buffer = self.msg_buffer.get(buffer_type, dict()) + + if buffer_type == MSGBUFFER.TRAIN: + cur_buffer = buffer.get(cur_round, dict()) if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': if self.cur_timestamp >= self.deadline_for_cur_round and len( cur_buffer) + len(self.staled_msg_buffer) == 0: @@ -738,6 +728,18 @@ def check_buffer(self, return len(cur_buffer)+len(self.staled_msg_buffer) >= \ min_received_num + elif buffer_type == MSGBUFFER.EVAL: + # Evaluation won't block the training process + cur_buffer = buffer.get(max(buffer.keys()), dict()) + return len(cur_buffer) >= min_received_num + + elif buffer_type == MSGBUFFER.CONSULT: + cur_buffer = buffer.get(cur_round, dict()) + return len(cur_buffer) >= min_received_num + + else: + raise NotImplementedError(f'Type of message buffer {buffer_type} is not implemented.') + def check_client_join_in(self): """ To check whether all the clients have joined in the FL course. @@ -802,7 +804,7 @@ def trigger_for_time_up(self, check_timestamp=None): return False self.cur_timestamp = self.deadline_for_cur_round - self.check_and_move_on() + self.check_and_move_on(buffer_type=MSGBUFFER.TRAIN) return True def terminate(self, msg_type='finish'): @@ -891,10 +893,10 @@ def callback_funcs_model_para(self, message: Message): self.cur_timestamp = timestamp if round == self.state: - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[MSGBUFFER.TRAIN]: + self.msg_buffer[MSGBUFFER.TRAIN][round] = dict() # Save the messages in this round - self.msg_buffer['train'][round][sender] = content + self.msg_buffer[MSGBUFFER.TRAIN][round][sender] = content elif round >= self.state - self.staleness_toleration: # Save the staled messages self.staled_msg_buffer.append((round, sender, content)) @@ -906,7 +908,7 @@ def callback_funcs_model_para(self, message: Message): if self._cfg.federate.online_aggr: self.aggregator.inc(content) - move_on_flag = self.check_and_move_on() + move_on_flag = self.check_and_move_on(buffer_type=MSGBUFFER.TRAIN) if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \ 'after_receiving': self.broadcast_model_para(msg_type='model_para', @@ -975,9 +977,9 @@ def callback_funcs_for_metrics(self, message: Message): sender = message.sender content = message.content - if round not in self.msg_buffer['eval'].keys(): - self.msg_buffer['eval'][round] = dict() + if round not in self.msg_buffer[MSGBUFFER.EVAL].keys(): + self.msg_buffer[MSGBUFFER.EVAL][round] = dict() - self.msg_buffer['eval'][round][sender] = content + self.msg_buffer[MSGBUFFER.EVAL][round][sender] = content - return self.check_and_move_on(check_eval_result=True) + return self.check_and_move_on(buffer_type=MSGBUFFER.EVAL) From 30c7854498337cc9b5304d4f5f6dad2cadcd5ad3 Mon Sep 17 00:00:00 2001 From: "gaodawei.gdw" Date: Wed, 21 Sep 2022 17:55:54 +0800 Subject: [PATCH 2/7] bug fix --- .../worker_as_attacker/server_attacker.py | 5 +- federatedscope/core/auxiliaries/enums.py | 2 +- federatedscope/core/workers/server.py | 62 +++++++++---------- federatedscope/gfl/fedsageplus/worker.py | 2 +- 4 files changed, 36 insertions(+), 35 deletions(-) diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 226568242..577384026 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -1,3 +1,4 @@ +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.workers import Server from federatedscope.core.message import Message @@ -284,7 +285,7 @@ def callback_funcs_model_para(self, message: Message): name='image_state_{}_client_{}.png'.format( message.state, message.sender)) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) class PassivePIAServer(Server): @@ -359,7 +360,7 @@ def callback_funcs_model_para(self, message: Message): # TODO: put this line to `check_and_move_on` # currently, no way to know the latest `sender` self.aggregator.inc(content) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) if self.state == self.total_round_num: self.pia_attacker.train_property_classifier() diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/auxiliaries/enums.py index 7f6d1e805..53d9d1ec9 100644 --- a/federatedscope/core/auxiliaries/enums.py +++ b/federatedscope/core/auxiliaries/enums.py @@ -11,7 +11,7 @@ class MODE: FINETUNE = 'finetune' -class MSGBUFFER: +class STAGE: TRAIN = 'train' EVAL = 'eval' CONSULT = 'consult' diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 696c73f02..11c0f439b 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -16,7 +16,7 @@ from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ merge_param_dict from federatedscope.core.auxiliaries.trainer_builder import get_trainer -from federatedscope.core.auxiliaries.enums import MSGBUFFER +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.secret_sharing import AdditiveSecretSharing logger = logging.getLogger(__name__) @@ -252,10 +252,10 @@ def run(self): logger.info('Time out at the training round #{}'.format( self.state)) move_on_flag_eval = self.check_and_move_on( - buffer_type=MSGBUFFER.EVAL, + stage=STAGE.EVAL, min_received_num=min_received_num) move_on_flag = self.check_and_move_on( - buffer_type=MSGBUFFER.TRAIN, + stage=STAGE.TRAIN, min_received_num=min_received_num) if not move_on_flag and not move_on_flag_eval: num_failure += 1 @@ -273,8 +273,8 @@ def run(self): f'Round #{self.state}) for {num_failure} time ' f'-------------') # TODO: Clean the msg_buffer - if self.state in self.msg_buffer[MSGBUFFER.TRAIN]: - self.msg_buffer[MSGBUFFER.TRAIN][self.state].clear() + if self.state in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][self.state].clear() self.broadcast_model_para( msg_type='model_para', @@ -286,7 +286,7 @@ def run(self): self.terminate(msg_type='finish') def check_and_move_on(self, - buffer_type, + stage, min_received_num=None): """ To check the message_buffer. When enough messages are receiving, @@ -294,7 +294,7 @@ def check_and_move_on(self, the next training round) would be triggered. Arguments: - buffer_type (str): The type of checked message buffer, chosen from MSGBUFFER.TRAIN, MSGBUFFER.EVAL and MSGBUFFER.CONSULT + stage (str): The type of checked message buffer, chosen from MSGBUFFER.TRAIN, MSGBUFFER.EVAL and MSGBUFFER.CONSULT min_received_num (int): The minimum number of received messages """ if min_received_num is None: @@ -304,7 +304,7 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if buffer_type == MSGBUFFER.EVAL and self._cfg.federate.mode.lower( + if stage == STAGE.EVAL and self._cfg.federate.mode.lower( ) == "standalone": # in evaluation stage and standalone simulation mode, we assume # strong synchronization that receives responses from all clients @@ -312,8 +312,8 @@ def check_and_move_on(self, move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, buffer_type): - if buffer_type == MSGBUFFER.TRAIN: + if self.check_buffer(self.state, min_received_num, stage): + if stage == STAGE.TRAIN: # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() @@ -331,8 +331,8 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer[MSGBUFFER.TRAIN][self.state - 1].clear() - self.msg_buffer[MSGBUFFER.TRAIN][self.state] = dict() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state] = dict() self.staled_msg_buffer.clear() # Start a new training round self._start_new_training_round(aggregated_num) @@ -395,8 +395,8 @@ def check_and_save(self): # Clean the clients evaluation msg buffer if not self._cfg.federate.make_global_eval: - round = max(self.msg_buffer[MSGBUFFER.EVAL].keys()) - self.msg_buffer[MSGBUFFER.EVAL][round].clear() + round = max(self.msg_buffer[STAGE.EVAL].keys()) + self.msg_buffer[STAGE.EVAL][round].clear() if self.state == self.total_round_num: # break out the loop for distributed mode @@ -406,7 +406,7 @@ def _perform_federated_aggregation(self): """ Perform federated aggregation and update the global model """ - train_msg_buffer = self.msg_buffer[MSGBUFFER.TRAIN][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -525,8 +525,8 @@ def save_client_eval_results(self): :return: """ - round = max(self.msg_buffer[MSGBUFFER.EVAL].keys()) - eval_msg_buffer = self.msg_buffer[MSGBUFFER.EVAL][round] + round = max(self.msg_buffer[STAGE.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[STAGE.EVAL][round] with open(os.path.join(self._cfg.outdir, "eval_results.log"), "a") as outfile: @@ -547,8 +547,8 @@ def merge_eval_results_from_all_clients(self): :returns: the formatted merged results """ - round = max(self.msg_buffer[MSGBUFFER.EVAL].keys()) - eval_msg_buffer = self.msg_buffer[MSGBUFFER.EVAL][round] + round = max(self.msg_buffer[STAGE.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[STAGE.EVAL][round] eval_res_participated_clients = [] eval_res_unseen_clients = [] for client_id in eval_msg_buffer: @@ -703,7 +703,7 @@ def check_buffer(self, """ buffer = self.msg_buffer.get(buffer_type, dict()) - if buffer_type == MSGBUFFER.TRAIN: + if buffer_type == STAGE.TRAIN: cur_buffer = buffer.get(cur_round, dict()) if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': if self.cur_timestamp >= self.deadline_for_cur_round and len( @@ -728,12 +728,12 @@ def check_buffer(self, return len(cur_buffer)+len(self.staled_msg_buffer) >= \ min_received_num - elif buffer_type == MSGBUFFER.EVAL: + elif buffer_type == STAGE.EVAL: # Evaluation won't block the training process cur_buffer = buffer.get(max(buffer.keys()), dict()) return len(cur_buffer) >= min_received_num - elif buffer_type == MSGBUFFER.CONSULT: + elif buffer_type == STAGE.CONSULT: cur_buffer = buffer.get(cur_round, dict()) return len(cur_buffer) >= min_received_num @@ -804,7 +804,7 @@ def trigger_for_time_up(self, check_timestamp=None): return False self.cur_timestamp = self.deadline_for_cur_round - self.check_and_move_on(buffer_type=MSGBUFFER.TRAIN) + self.check_and_move_on(stage=STAGE.TRAIN) return True def terminate(self, msg_type='finish'): @@ -893,10 +893,10 @@ def callback_funcs_model_para(self, message: Message): self.cur_timestamp = timestamp if round == self.state: - if round not in self.msg_buffer[MSGBUFFER.TRAIN]: - self.msg_buffer[MSGBUFFER.TRAIN][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() # Save the messages in this round - self.msg_buffer[MSGBUFFER.TRAIN][round][sender] = content + self.msg_buffer[STAGE.TRAIN][round][sender] = content elif round >= self.state - self.staleness_toleration: # Save the staled messages self.staled_msg_buffer.append((round, sender, content)) @@ -908,7 +908,7 @@ def callback_funcs_model_para(self, message: Message): if self._cfg.federate.online_aggr: self.aggregator.inc(content) - move_on_flag = self.check_and_move_on(buffer_type=MSGBUFFER.TRAIN) + move_on_flag = self.check_and_move_on(stage=STAGE.TRAIN) if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \ 'after_receiving': self.broadcast_model_para(msg_type='model_para', @@ -977,9 +977,9 @@ def callback_funcs_for_metrics(self, message: Message): sender = message.sender content = message.content - if round not in self.msg_buffer[MSGBUFFER.EVAL].keys(): - self.msg_buffer[MSGBUFFER.EVAL][round] = dict() + if round not in self.msg_buffer[STAGE.EVAL].keys(): + self.msg_buffer[STAGE.EVAL][round] = dict() - self.msg_buffer[MSGBUFFER.EVAL][round][sender] = content + self.msg_buffer[STAGE.EVAL][round][sender] = content - return self.check_and_move_on(buffer_type=MSGBUFFER.EVAL) + return self.check_and_move_on(stage=STAGE.EVAL) diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py index f1812598d..93628de29 100644 --- a/federatedscope/gfl/fedsageplus/worker.py +++ b/federatedscope/gfl/fedsageplus/worker.py @@ -121,7 +121,7 @@ def callback_funcs_gradient(self, message: Message): gen_grad[key].cpu()) self.check_and_move_on() - def check_and_move_on(self, check_eval_result=False): + def check_and_move_on(self, check_eval_result=False, **kwargs): client_IDs = [i for i in range(1, self.client_num + 1)] if check_eval_result: From c7d183233c05ac23902d1e94156186ad5fa07a01 Mon Sep 17 00:00:00 2001 From: "gaodawei.gdw" Date: Wed, 21 Sep 2022 17:59:20 +0800 Subject: [PATCH 3/7] use enums rather than string --- .../worker_as_attacker/server_attacker.py | 18 +++++++-------- federatedscope/autotune/fedex/server.py | 11 +++++---- federatedscope/core/workers/client.py | 9 ++++---- federatedscope/gfl/fedsageplus/worker.py | 23 ++++++++++--------- federatedscope/gfl/gcflplus/worker.py | 9 ++++---- 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 577384026..651555740 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -245,15 +245,15 @@ def _reconstruct(self, model_para, batch_size, state, sender): def run_reconstruct(self, state_list=None, sender_list=None): if state_list is None: - state_list = self.msg_buffer['train'].keys() + state_list = self.msg_buffer[STAGE.TRAIN].keys() # After FL running, using gradient based reconstruction method to # recover client's private training data for state in state_list: if sender_list is None: - sender_list = self.msg_buffer['train'][state].keys() + sender_list = self.msg_buffer[STAGE.TRAIN][state].keys() for sender in sender_list: - content = self.msg_buffer['train'][state][sender] + content = self.msg_buffer[STAGE.TRAIN][state][sender] self._reconstruct(model_para=content[1], batch_size=content[0], state=state, @@ -265,9 +265,9 @@ def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content self.sampler.change_state(sender, 'idle') - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = content + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() + self.msg_buffer[STAGE.TRAIN][round][sender] = content # run reconstruction before the clear of self.msg_buffer @@ -343,9 +343,9 @@ def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content self.sampler.change_state(sender, 'idle') - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = content + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() + self.msg_buffer[STAGE.TRAIN][round][sender] = content # collect the updates self.pia_attacker.collect_updates( diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py index fac5e6488..17b998c30 100644 --- a/federatedscope/autotune/fedex/server.py +++ b/federatedscope/autotune/fedex/server.py @@ -8,6 +8,7 @@ from numpy.linalg import norm from scipy.special import logsumexp +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers import Server from federatedscope.core.auxiliaries.utils import merge_dict @@ -204,10 +205,10 @@ def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content self.sampler.change_state(sender, 'idle') # For a new round - if round not in self.msg_buffer['train'].keys(): - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN].keys(): + self.msg_buffer[STAGE.TRAIN][round] = dict() - self.msg_buffer['train'][round][sender] = content + self.msg_buffer[STAGE.TRAIN][round][sender] = content if self._cfg.federate.online_aggr: self.aggregator.inc(tuple(content[0:2])) @@ -305,7 +306,7 @@ def check_and_move_on(self, if not check_eval_result: # in the training process mab_feedbacks = list() # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -363,7 +364,7 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() self.broadcast_model_para( msg_type='model_para', diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index 2dbc4bdd7..9f47aa60c 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -3,6 +3,7 @@ import sys import pickle +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.communication import StandaloneCommManager, \ gRPCCommManager @@ -225,12 +226,12 @@ def callback_funcs_for_model_para(self, message: Message): # A fragment of the shared secret state, content, timestamp = message.state, message.content, \ message.timestamp - self.msg_buffer['train'][state].append(content) + self.msg_buffer[STAGE.TRAIN][state].append(content) - if len(self.msg_buffer['train'] + if len(self.msg_buffer[STAGE.TRAIN] [state]) == self._cfg.federate.client_num: # Check whether the received fragments are enough - model_list = self.msg_buffer['train'][state] + model_list = self.msg_buffer[STAGE.TRAIN][state] sample_size, first_aggregate_model_para = model_list[0] single_model_case = True if isinstance(first_aggregate_model_para, list): @@ -355,7 +356,7 @@ def callback_funcs_for_model_para(self, message: Message): single_model_case else \ [model_para_list[frame_idx] for model_para_list in model_para_list_all] - self.msg_buffer['train'][self.state] = [(sample_size, + self.msg_buffer[STAGE.TRAIN][self.state] = [(sample_size, content_frame)] else: if self._cfg.asyn.use: diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py index 93628de29..442e31249 100644 --- a/federatedscope/gfl/fedsageplus/worker.py +++ b/federatedscope/gfl/fedsageplus/worker.py @@ -4,6 +4,7 @@ from torch_geometric.loader import NeighborSampler +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -106,18 +107,18 @@ def callback_funcs_gradient(self, message: Message): round, _, content = message.state, message.sender, message.content gen_grad, ID = content # For a new round - if round not in self.msg_buffer['train'].keys(): - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN].keys(): + self.msg_buffer[STAGE.TRAIN][round] = dict() self.grad_cnt += 1 # Sum up all grad from other client - if ID not in self.msg_buffer['train'][round]: - self.msg_buffer['train'][round][ID] = dict() + if ID not in self.msg_buffer[STAGE.TRAIN][round]: + self.msg_buffer[STAGE.TRAIN][round][ID] = dict() for key in gen_grad.keys(): - self.msg_buffer['train'][round][ID][key] = torch.FloatTensor( + self.msg_buffer[STAGE.TRAIN][round][ID][key] = torch.FloatTensor( gen_grad[key].cpu()) else: for key in gen_grad.keys(): - self.msg_buffer['train'][round][ID][key] += torch.FloatTensor( + self.msg_buffer[STAGE.TRAIN][round][ID][key] += torch.FloatTensor( gen_grad[key].cpu()) self.check_and_move_on() @@ -137,8 +138,8 @@ def check_and_move_on(self, check_eval_result=False, **kwargs): ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\ % 2 == 0: # FedGen: we should wait for all messages - for sender in self.msg_buffer['train'][self.state]: - content = self.msg_buffer['train'][self.state][sender] + for sender in self.msg_buffer[STAGE.TRAIN][self.state]: + content = self.msg_buffer[STAGE.TRAIN][self.state][sender] gen_para, embedding, label = content receiver_IDs = client_IDs[:sender - 1] + client_IDs[sender:] self.comm_manager.send( @@ -157,8 +158,8 @@ def check_and_move_on(self, check_eval_result=False, **kwargs): ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\ % 2 == 1 and self.grad_cnt == self.client_num * ( self.client_num - 1): - for ID in self.msg_buffer['train'][self.state]: - grad = self.msg_buffer['train'][self.state][ID] + for ID in self.msg_buffer[STAGE.TRAIN][self.state]: + grad = self.msg_buffer[STAGE.TRAIN][self.state][ID] self.comm_manager.send( Message(msg_type='gradient', sender=self.ID, @@ -186,7 +187,7 @@ def check_and_move_on(self, check_eval_result=False, **kwargs): if not check_eval_result: # in the training process # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] msg_list = list() for client_id in train_msg_buffer: msg_list.append(train_msg_buffer[client_id]) diff --git a/federatedscope/gfl/gcflplus/worker.py b/federatedscope/gfl/gcflplus/worker.py index 26a1d62df..5ecd4623a 100644 --- a/federatedscope/gfl/gcflplus/worker.py +++ b/federatedscope/gfl/gcflplus/worker.py @@ -3,6 +3,7 @@ import copy import numpy as np +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -44,7 +45,7 @@ def compute_update_norm(self, cluster): max_norm = -np.inf cluster_dWs = [] for key in cluster: - content = self.msg_buffer['train'][self.state][key] + content = self.msg_buffer[STAGE.TRAIN][self.state][key] _, model_para, client_dw, _ = content dW = {} for k in model_para.keys(): @@ -71,7 +72,7 @@ def check_and_move_on(self, check_eval_result=False): if not check_eval_result: # in the training process # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -135,7 +136,7 @@ def check_and_move_on(self, check_eval_result=False): for cluster in self.cluster_indices: msg_list = list() for key in cluster: - content = self.msg_buffer['train'][self.state - + content = self.msg_buffer[STAGE.TRAIN][self.state - 1][key] train_data_size, model_para, client_dw, \ convGradsNorm = content @@ -161,7 +162,7 @@ def check_and_move_on(self, check_eval_result=False): f'----------- Starting a new traininground(Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() else: # Final Evaluate From 2295c70182299b43f87a98f5a503fc4b8fab3fe5 Mon Sep 17 00:00:00 2001 From: "gaodawei.gdw" Date: Wed, 21 Sep 2022 19:05:12 +0800 Subject: [PATCH 4/7] [WIP] add client_manager --- federatedscope/core/auxiliaries/enums.py | 26 ++- .../core/auxiliaries/sampler_builder.py | 9 +- federatedscope/core/sampler.py | 86 +++----- federatedscope/core/workers/client_manager.py | 193 ++++++++++++++++++ federatedscope/core/workers/server.py | 38 +--- 5 files changed, 256 insertions(+), 96 deletions(-) create mode 100644 federatedscope/core/workers/client_manager.py diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/auxiliaries/enums.py index 53d9d1ec9..d67104933 100644 --- a/federatedscope/core/auxiliaries/enums.py +++ b/federatedscope/core/auxiliaries/enums.py @@ -1,4 +1,14 @@ -class MODE: +class BasicEnum(object): + @classmethod + def assert_value(cls, value): + """ + Check if the **value** is legal for the given class (If the value equals one of the class attributes) + """ + if not value in [v for k, v in cls.__dict__.items() if not k.startswith('__')]: + raise ValueError(f"Value {value} is not in {cls.__name__}.") + + +class MODE(BasicEnum): """ Note: @@ -11,13 +21,13 @@ class MODE: FINETUNE = 'finetune' -class STAGE: +class STAGE(BasicEnum): TRAIN = 'train' EVAL = 'eval' CONSULT = 'consult' -class TRIGGER: +class TRIGGER(BasicEnum): ON_FIT_START = 'on_fit_start' ON_EPOCH_START = 'on_epoch_start' ON_BATCH_START = 'on_batch_start' @@ -36,8 +46,16 @@ def contains(cls, item): ] -class LIFECYCLE: +class LIFECYCLE(BasicEnum): ROUTINE = 'routine' EPOCH = 'epoch' BATCH = 'batch' NONE = None + + +class CLIENT_STATE(BasicEnum): + OFFLINE = -1 # not join + CONSULTING = 2 # join in and is consulting + IDLE = 1 # join in but not working, available for training + WORKING = 0 # join in and is working + SIDELINE = 0 # join in but won't participate in federated training diff --git a/federatedscope/core/auxiliaries/sampler_builder.py b/federatedscope/core/auxiliaries/sampler_builder.py index 0b7d2ff44..f5492fad3 100644 --- a/federatedscope/core/auxiliaries/sampler_builder.py +++ b/federatedscope/core/auxiliaries/sampler_builder.py @@ -5,16 +5,15 @@ logger = logging.getLogger(__name__) -def get_sampler(sample_strategy='uniform', +def get_sampler(sample_strategy, client_num=None, client_info=None, bins=10): if sample_strategy == 'uniform': - return UniformSampler(client_num=client_num) + return UniformSampler() elif sample_strategy == 'group': return GroupSampler(client_num=client_num, - client_info=client_info, + client_info=client_info['client_resource'], bins=bins) else: - raise ValueError( - f"The sample strategy {sample_strategy} has not been provided.") + return None \ No newline at end of file diff --git a/federatedscope/core/sampler.py b/federatedscope/core/sampler.py index 6438cf822..cc3c12af3 100644 --- a/federatedscope/core/sampler.py +++ b/federatedscope/core/sampler.py @@ -5,56 +5,31 @@ class Sampler(ABC): """ - The strategies of sampling clients for each training round - - Arguments: - client_state: a dict to manager the state of clients (idle or busy) + The strategy of sampling """ - def __init__(self, client_num): - self.client_state = np.asarray([1] * (client_num + 1)) - # Set the state of server (index=0) to 'working' - self.client_state[0] = 0 + def __init__(self): + pass @abstractmethod - def sample(self, size): + def sample(self, *args, **kwargs): raise NotImplementedError - def change_state(self, indices, state): - """ - To modify the state of clients (idle or working) - """ - if isinstance(indices, list) or isinstance(indices, np.ndarray): - all_idx = indices - else: - all_idx = [indices] - for idx in all_idx: - if state in ['idle', 'seen']: - self.client_state[idx] = 1 - elif state in ['working', 'unseen']: - self.client_state[idx] = 0 - else: - raise ValueError( - f"The state of client should be one of " - f"['idle', 'working', 'unseen], but got {state}") - class UniformSampler(Sampler): """ - To uniformly sample the clients from all the idle clients + A stateless sampler that samples items uniformly """ - def __init__(self, client_num): - super(UniformSampler, self).__init__(client_num) + def __init__(self): + super(UniformSampler, self).__init__() - def sample(self, size): + def sample(self, client_idle, size, *args, **kwargs): """ To sample clients """ - idle_clients = np.nonzero(self.client_state)[0] - sampled_clients = np.random.choice(idle_clients, - size=size, - replace=False).tolist() - self.change_state(sampled_clients, 'working') - return sampled_clients + sampled_items = np.random.choice(a, + size=size, + replace=False).tolist() + return sampled_items class GroupSampler(Sampler): @@ -62,23 +37,12 @@ class GroupSampler(Sampler): To grouply sample the clients based on their responsiveness (or other client information of the clients) """ - def __init__(self, client_num, client_info, bins=10): - super(GroupSampler, self).__init__(client_num) + def __init__(self, client_info, bins=10): + super(GroupSampler, self).__init__() self.bins = bins - self.update_client_info(client_info) + self.client_info = client_info self.candidate_iterator = self.partition() - def update_client_info(self, client_info): - """ - To update the client information - """ - self.client_info = np.asarray( - [1.0] + [x for x in client_info - ]) # client_info[0] is preversed for the server - assert len(self.client_info) == len( - self.client_state - ), "The first dimension of client_info is mismatched with client_num" - def partition(self): """ To partition the clients into groups according to the client @@ -87,7 +51,9 @@ def partition(self): Arguments: :returns: a iteration of candidates """ - sorted_index = np.argsort(self.client_info) + # sort client_info by xx + sorted_index = sorted(self.client_info.keys(), key=lambda x: self.client_info[x]) + # bin的长度 num_per_bins = np.int(len(sorted_index) / self.bins) # grouped clients @@ -105,11 +71,14 @@ def permutation(self): return iter(candidates) - def sample(self, size, shuffle=False): + def sample(self, clients_idle, size, perturb=False): """ To sample clients """ - if shuffle: + if self.candidate_iterator is None: + self.partition() + + if perturb: self.candidate_iterator = self.permutation() sampled_clients = list() @@ -117,15 +86,14 @@ def sample(self, size, shuffle=False): # To find an idle client while True: try: - item = next(self.candidate_iterator) + client_id = next(self.candidate_iterator) except StopIteration: self.candidate_iterator = self.permutation() - item = next(self.candidate_iterator) + client_id = next(self.candidate_iterator) - if self.client_state[item] == 1: + if client_id in clients_idle: break - sampled_clients.append(item) - self.change_state(item, 'working') + sampled_clients.append(client_id) return sampled_clients diff --git a/federatedscope/core/workers/client_manager.py b/federatedscope/core/workers/client_manager.py new file mode 100644 index 000000000..23f8c15dc --- /dev/null +++ b/federatedscope/core/workers/client_manager.py @@ -0,0 +1,193 @@ +import collections + +import numpy as np + +from federatedscope.core.auxiliaries.sampler_builder import get_sampler +from federatedscope.core.auxiliaries.enums import CLIENT_STATE + +import logging + +logger = logging.getLogger(__name__) + + +class ClientManager(object): + def __init__(self, + num_client, + sample_strategy, + id_client_unseen=[]): + self._num_client_join = 0 + self._num_client_total = num_client + + # the unseen clients indicate the ones that do not contribute to FL + # process by training on their local data and uploading their local + # model update. The splitting is useful to check participation + # generalization gap in + # [ICLR'22, What Do We Mean by Generalization in Federated Learning?] + self._num_client_unseen = len(id_client_unseen) + self._id_client_unseen = id_client_unseen + + # TODO: achieve by a two-leveled index rather than saving twice + # Used to maintain the information collected from the clients + self._info_by_client = collections.defaultdict(dict) + self._info_by_key = collections.defaultdict(dict) + + # Record the state of the clients (client_id counts from 1) + self._state_client = {client_id: CLIENT_STATE.OFFLINE for client_id in range(1, self._num_client_total+1)} + + self.sampler = None + self.sample_strategy = sample_strategy + + def __assert_client(self, client_id): + """ + Check if the client_id is legal. + """ + if client_id >= 1 and client_id <= self._num_client_join + 1: + pass + else: + raise IndexError(f"Client ID {client_id} doesn't exist.") + + @property + def _num_join_client(self): + return self._num_join_client + + def join_in(self, client_id): + """ + Register the client as online + """ + self.__assert_client(client_id) + + # Step into consulting (exchange information between server and client) + self._state_client[client_id] = CLIENT_STATE.CONSULTING + self._num_join_client += 1 + + def set_offline(self, client_id): + """ + Set the state of the client as CLIENT_STATE.OFFLINE + """ + self.__assert_client(client_id) + + self.change_state(client_id, CLIENT_STATE.OFFLINE) + self._num_join_client -= 1 + + def finish_consult(self, client_id): + """ + Set the state of the client that finishes consulting as CLIENT_STATE.IDLE + """ + self.change_state(client_id, CLIENT_STATE.IDLE) + + def check_client_join_in(self): + """ + Check if enough clients has joined in. + """ + return self._num_join_client == self._num_client + + def check_client_info(self): + """ + Check if enough information is collected from the clients + """ + return len(self._info_by_client) == self._num_client + + def check_client_consult(self): + """ + Check if all clients finish requiring information from the server (The state is CLIENT_STATE.IDLE) + """ + return len(self.get_idle_client()) == self._num_client + + def update_client_info(self, client_id, info: dict): + """ + Update client information in the manager. + """ + if client_id in self._info_by_client: + logger.info(f"Information of Client #{client_id} is updated by {info}.") + self._info_by_client.update(info) + + for k, v in info.items(): + self._info_by_key[k][client_id] = v + + def del_client_info(self, client_id): + """ + Delete the client information. + """ + if client_id in self._info_by_client: + del self._info_by_client[client_id] + for k in self._info_by_client: + if client_id in self._info_by_client[k]: + del self._info_by_client[k][client_id] + + def get_info_by_client(self, client_id): + """ + Get the client information by client_id + """ + return self._info_by_client.get(client_id, None) + + def get_info_by_key(self, key): + """ + Get the client information by key + """ + return self._info_by_client.get(key, None) + + def change_state(self, indices, state): + """ + To modify the state of clients (idle or working) + """ + CLIENT_STATE.assert_value(state) + + if isinstance(indices, list) or isinstance(indices, np.ndarray): + client_idx = indices + else: + client_idx = [indices] + + for client_id in client_idx: + self._state_client[client_id] = state + + def get_client_by_state(self, state): + CLIENT_STATE.assert_value(state) + + return [client_id for client_id, client_state in self._state_client if client_state == state] + + def get_consult_client(self): + """ + Return all the clients with state CLIENT_STATE.CONSULTING + """ + self.get_client_by_state(CLIENT_STATE.CONSULTING) + + def get_idle_client(self): + """ + Return all the clients with state CLIENT_STATE.IDLE + """ + return self.get_client_by_state(CLIENT_STATE.IDLE) + + def init_sampler(self): + """ + Considering the sampling strategy may need client information, we should + initialize it after finishing client information collection, or + re-initialize it after updating client information. + """ + # To sample clients during training + self.sampler = get_sampler( + sample_strategy=self.sample_strategy, + client_num=self._num_client_total, + client_info=self._info_by_key + ) + + def sample(self, size, perturb=False): + """Sample idle clients with the specific sampler + + Args: + size: the number of sample size + perturb: if perturb the clients before sampling + + Returns: + index of the sampled clients + """ + if self.sampler is None: + # When we call the function sample, we default that all client information have been collected. + self.init_sampler() + + # Obtain the idle clients + clients_idle = self.get_idle_client() + # Sampling + clients_sampled = self.sampler.sample(clients_idle, size, perturb) + # Change state for the sampled clients + self.change_state(clients_sampled, CLIENT_STATE.WORKING) + return clients_sampled diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 11c0f439b..f83d1c9c3 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -11,6 +11,7 @@ from federatedscope.core.communication import StandaloneCommManager, \ gRPCCommManager from federatedscope.core.workers import Worker +from federatedscope.core.workers.client_manager import ClientManager from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator from federatedscope.core.auxiliaries.sampler_builder import get_sampler from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ @@ -49,7 +50,7 @@ def __init__(self, total_round_num=10, device='cpu', strategy=None, - unseen_clients_id=None, + id_clients_unseen=[], **kwargs): super(Server, self).__init__(ID, state, config, model, strategy) @@ -124,18 +125,16 @@ def __init__(self, ]) # Initialize the number of joined-in clients + self.client_manager = ClientManager( + client_num, + sample_strategy=self._cfg.federate.sampler, + id_client_unseen=id_clients_unseen) + self._client_num = client_num self._total_round_num = total_round_num self.sample_client_num = int(self._cfg.federate.sample_client_num) self.join_in_client_num = 0 self.join_in_info = dict() - # the unseen clients indicate the ones that do not contribute to FL - # process by training on their local data and uploading their local - # model update. The splitting is useful to check participation - # generalization gap in - # [ICLR'22, What Do We Mean by Generalization in Federated Learning?] - self.unseen_clients_id = [] if unseen_clients_id is None \ - else unseen_clients_id # Server state self.is_finish = False @@ -759,26 +758,9 @@ def trigger_for_start(self): if self._cfg.federate.use_ss: self.broadcast_client_address() - # get sampler - if 'client_resource' in self._cfg.federate.join_in_info: - client_resource = [ - self.join_in_info[client_index]['client_resource'] - for client_index in np.arange(1, self.client_num + 1) - ] - else: - model_size = sys.getsizeof(pickle.dumps( - self.model)) / 1024.0 * 8. - client_resource = [ - model_size / float(x['communication']) + - float(x['computation']) / 1000. - for x in self.client_resource_info - ] if self.client_resource_info is not None else None - - if self.sampler is None: - self.sampler = get_sampler( - sample_strategy=self._cfg.federate.sampler, - client_num=self.client_num, - client_info=client_resource) + # Prepare for training + # init sampler within the client manager after finishing exchanging information + self.client_manager.init_sampler() # change the deadline if the asyn.aggregator is `time up` if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': From bc7a8d99f56a1b9a4442ad7e2335b656759f320e Mon Sep 17 00:00:00 2001 From: "gaodawei.gdw" Date: Wed, 21 Sep 2022 21:31:34 +0800 Subject: [PATCH 5/7] [WIP] debug client_manager for the server --- federatedscope/core/auxiliaries/enums.py | 2 +- federatedscope/core/sampler.py | 2 +- federatedscope/core/workers/client_manager.py | 39 +++++++--- federatedscope/core/workers/server.py | 71 ++++++++----------- 4 files changed, 59 insertions(+), 55 deletions(-) diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/auxiliaries/enums.py index d67104933..024004189 100644 --- a/federatedscope/core/auxiliaries/enums.py +++ b/federatedscope/core/auxiliaries/enums.py @@ -54,7 +54,7 @@ class LIFECYCLE(BasicEnum): class CLIENT_STATE(BasicEnum): - OFFLINE = -1 # not join + OFFLINE = -1 # not join in CONSULTING = 2 # join in and is consulting IDLE = 1 # join in but not working, available for training WORKING = 0 # join in and is working diff --git a/federatedscope/core/sampler.py b/federatedscope/core/sampler.py index cc3c12af3..7cce752c9 100644 --- a/federatedscope/core/sampler.py +++ b/federatedscope/core/sampler.py @@ -26,7 +26,7 @@ def sample(self, client_idle, size, *args, **kwargs): """ To sample clients """ - sampled_items = np.random.choice(a, + sampled_items = np.random.choice(client_idle, size=size, replace=False).tolist() return sampled_items diff --git a/federatedscope/core/workers/client_manager.py b/federatedscope/core/workers/client_manager.py index 23f8c15dc..d640092dd 100644 --- a/federatedscope/core/workers/client_manager.py +++ b/federatedscope/core/workers/client_manager.py @@ -47,18 +47,28 @@ def __assert_client(self, client_id): raise IndexError(f"Client ID {client_id} doesn't exist.") @property - def _num_join_client(self): - return self._num_join_client + def join_in_client_num(self): + return self._num_client_join + + def register_client(self): + self._num_client_join += 1 + return self._num_client_join def join_in(self, client_id): """ - Register the client as online + Register the client, and assign client_id if it doesn't have """ - self.__assert_client(client_id) + # Count the number of client + self._num_client_join += 1 + + if client_id == -1: + # Doesn't have client_id + client_id = self._num_client_join # Step into consulting (exchange information between server and client) - self._state_client[client_id] = CLIENT_STATE.CONSULTING - self._num_join_client += 1 + self._state_client[client_id] = CLIENT_STATE.IDLE + # Return the client_id + return client_id def set_offline(self, client_id): """ @@ -67,7 +77,14 @@ def set_offline(self, client_id): self.__assert_client(client_id) self.change_state(client_id, CLIENT_STATE.OFFLINE) - self._num_join_client -= 1 + self._num_client_join -= 1 + + def block_unseen_client(self): + """ + Set the state of the client as CLIENT_STATE.SIDELINE + """ + if self._num_client_unseen > 0: + self.change_state(self._num_client_unseen, CLIENT_STATE.SIDELINE) def finish_consult(self, client_id): """ @@ -79,19 +96,19 @@ def check_client_join_in(self): """ Check if enough clients has joined in. """ - return self._num_join_client == self._num_client + return self._num_client_join == self._num_client_total def check_client_info(self): """ Check if enough information is collected from the clients """ - return len(self._info_by_client) == self._num_client + return len(self._info_by_client) == self._num_client_total def check_client_consult(self): """ Check if all clients finish requiring information from the server (The state is CLIENT_STATE.IDLE) """ - return len(self.get_idle_client()) == self._num_client + return len(self.get_idle_client()) == self._num_client_total def update_client_info(self, client_id, info: dict): """ @@ -143,7 +160,7 @@ def change_state(self, indices, state): def get_client_by_state(self, state): CLIENT_STATE.assert_value(state) - return [client_id for client_id, client_state in self._state_client if client_state == state] + return [client_id for client_id, client_state in self._state_client.items() if client_state == state] def get_consult_client(self): """ diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index f83d1c9c3..a9e9ed451 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -17,7 +17,7 @@ from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ merge_param_dict from federatedscope.core.auxiliaries.trainer_builder import get_trainer -from federatedscope.core.auxiliaries.enums import STAGE +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.secret_sharing import AdditiveSecretSharing logger = logging.getLogger(__name__) @@ -133,23 +133,11 @@ def __init__(self, self._client_num = client_num self._total_round_num = total_round_num self.sample_client_num = int(self._cfg.federate.sample_client_num) - self.join_in_client_num = 0 self.join_in_info = dict() # Server state self.is_finish = False - # Sampler - if self._cfg.federate.sampler in ['uniform']: - self.sampler = get_sampler( - sample_strategy=self._cfg.federate.sampler, - client_num=self.client_num, - client_info=None) - else: - # Some type of sampler would be instantiated in trigger_for_start, - # since they need more information - self.sampler = None - # Current Timestamp self.cur_timestamp = 0 self.deadline_for_cur_round = 1 @@ -230,7 +218,7 @@ def run(self): """ # Begin: Broadcast model parameters and start to FL train - while self.join_in_client_num < self.client_num: + while self.client_manager.check_client_join_in(): msg = self.comm_manager.receive() self.msg_handlers[msg.msg_type](msg) @@ -612,35 +600,27 @@ def merge_eval_results_from_all_clients(self): def broadcast_model_para(self, msg_type='model_para', - sample_client_num=-1, - filter_unseen_clients=True): + sample_client_num=None): """ To broadcast the message to all clients or sampled clients Arguments: msg_type: 'model_para' or other user defined msg_type sample_client_num: the number of sampled clients in the broadcast - behavior. And sample_client_num = -1 denotes to broadcast to + behavior. And sample_client_num = None denotes to broadcast to all the clients. - filter_unseen_clients: whether filter out the unseen clients that - do not contribute to FL process by training on their local - data and uploading their local model update. The splitting is - useful to check participation generalization gap in [ICLR'22, - What Do We Mean by Generalization in Federated Learning?] - You may want to set it to be False when in evaluation stage - """ - if filter_unseen_clients: - # to filter out the unseen clients when sampling - self.sampler.change_state(self.unseen_clients_id, 'unseen') - - if sample_client_num > 0: - receiver = self.sampler.sample(size=sample_client_num) + """ + # Get the receivers + if sample_client_num: + receiver = self.client_manager.sample(size=sample_client_num) else: - # broadcast to all clients receiver = list(self.comm_manager.neighbors.keys()) - if msg_type == 'model_para': - self.sampler.change_state(receiver, 'working') + if msg_type == 'model_para': + # Training + self.client_manager.change_state(receiver, CLIENT_STATE.WORKING) + + # Inject noise if self._noise_injector is not None and msg_type == 'model_para': # Inject noise only when broadcast parameters for model_idx_i in range(len(self.models)): @@ -650,6 +630,7 @@ def broadcast_model_para(self, self._noise_injector(self._cfg, num_sample_clients, self.models[model_idx_i]) + # Prepare model parameters skip_broadcast = self._cfg.federate.method in ["local", "global"] if self.model_num > 1: model_para = [{} if skip_broadcast else model.state_dict() @@ -657,6 +638,7 @@ def broadcast_model_para(self, else: model_para = {} if skip_broadcast else self.model.state_dict() + # Send model parameters self.comm_manager.send( Message(msg_type=msg_type, sender=self.ID, @@ -668,10 +650,6 @@ def broadcast_model_para(self, for idx in range(self.model_num): self.aggregators[idx].reset() - if filter_unseen_clients: - # restore the state of the unseen clients within sampler - self.sampler.change_state(self.unseen_clients_id, 'seen') - def broadcast_client_address(self): """ To broadcast the communication addresses of clients (used for @@ -745,9 +723,9 @@ def check_client_join_in(self): """ if len(self._cfg.federate.join_in_info) != 0: - return len(self.join_in_info) == self.client_num + return self.client_manager.check_client_info() else: - return self.join_in_client_num == self.client_num + return self.client_manager.check_client_join_in() def trigger_for_start(self): """ @@ -761,6 +739,8 @@ def trigger_for_start(self): # Prepare for training # init sampler within the client manager after finishing exchanging information self.client_manager.init_sampler() + # Block unseen clients from training + self.client_manager.block_unseen_client() # change the deadline if the asyn.aggregator is `time up` if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': @@ -868,7 +848,9 @@ def callback_funcs_model_para(self, message: Message): sender = message.sender timestamp = message.timestamp content = message.content - self.sampler.change_state(sender, 'idle') + + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) # update the currency timestamp according to the received message assert timestamp >= self.cur_timestamp # for test @@ -916,11 +898,16 @@ def callback_funcs_for_join_in(self, message: Message): assert key in info self.join_in_info[sender] = info logger.info('Server: Client #{:d} has joined in !'.format(sender)) + + # Set the client status as idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) + else: - self.join_in_client_num += 1 sender, address = message.sender, message.content + # Register client in client_manager + register_id = self.client_manager.register_client() if int(sender) == -1: # assign number to client - sender = self.join_in_client_num + sender = register_id self.comm_manager.add_neighbors(neighbor_id=sender, address=address) self.comm_manager.send( From 3e3f72dcaeda038c463d087b7ee0ade7381ac3bf Mon Sep 17 00:00:00 2001 From: "gaodawei.gdw" Date: Thu, 22 Sep 2022 13:54:14 +0800 Subject: [PATCH 6/7] record client_id in client_manager.py --- federatedscope/core/workers/client_manager.py | 70 ++++++------------- federatedscope/core/workers/server.py | 22 +++--- 2 files changed, 31 insertions(+), 61 deletions(-) diff --git a/federatedscope/core/workers/client_manager.py b/federatedscope/core/workers/client_manager.py index d640092dd..869b1c383 100644 --- a/federatedscope/core/workers/client_manager.py +++ b/federatedscope/core/workers/client_manager.py @@ -18,6 +18,8 @@ def __init__(self, self._num_client_join = 0 self._num_client_total = num_client + self._ids_client = [] + # the unseen clients indicate the ones that do not contribute to FL # process by training on their local data and uploading their local # model update. The splitting is useful to check participation @@ -32,7 +34,7 @@ def __init__(self, self._info_by_key = collections.defaultdict(dict) # Record the state of the clients (client_id counts from 1) - self._state_client = {client_id: CLIENT_STATE.OFFLINE for client_id in range(1, self._num_client_total+1)} + self._state_client = dict() self.sampler = None self.sample_strategy = sample_strategy @@ -41,43 +43,34 @@ def __assert_client(self, client_id): """ Check if the client_id is legal. """ - if client_id >= 1 and client_id <= self._num_client_join + 1: + if client_id in self._ids_client: pass else: raise IndexError(f"Client ID {client_id} doesn't exist.") + def get_all_client(self): + return self._ids_client + @property def join_in_client_num(self): return self._num_client_join - def register_client(self): - self._num_client_join += 1 - return self._num_client_join - - def join_in(self, client_id): + def register_client(self, sender): """ - Register the client, and assign client_id if it doesn't have + Register new client into client manager """ - # Count the number of client self._num_client_join += 1 + # Assign new id if necessary + if sender == -1: + register_id = self._num_client_join + else: + register_id = sender - if client_id == -1: - # Doesn't have client_id - client_id = self._num_client_join - - # Step into consulting (exchange information between server and client) - self._state_client[client_id] = CLIENT_STATE.IDLE - # Return the client_id - return client_id - - def set_offline(self, client_id): - """ - Set the state of the client as CLIENT_STATE.OFFLINE - """ - self.__assert_client(client_id) + # Record the client + self._ids_client.append(register_id) + self._state_client[register_id] = CLIENT_STATE.IDLE - self.change_state(client_id, CLIENT_STATE.OFFLINE) - self._num_client_join -= 1 + return register_id def block_unseen_client(self): """ @@ -86,12 +79,6 @@ def block_unseen_client(self): if self._num_client_unseen > 0: self.change_state(self._num_client_unseen, CLIENT_STATE.SIDELINE) - def finish_consult(self, client_id): - """ - Set the state of the client that finishes consulting as CLIENT_STATE.IDLE - """ - self.change_state(client_id, CLIENT_STATE.IDLE) - def check_client_join_in(self): """ Check if enough clients has joined in. @@ -104,12 +91,6 @@ def check_client_info(self): """ return len(self._info_by_client) == self._num_client_total - def check_client_consult(self): - """ - Check if all clients finish requiring information from the server (The state is CLIENT_STATE.IDLE) - """ - return len(self.get_idle_client()) == self._num_client_total - def update_client_info(self, client_id, info: dict): """ Update client information in the manager. @@ -162,12 +143,6 @@ def get_client_by_state(self, state): return [client_id for client_id, client_state in self._state_client.items() if client_state == state] - def get_consult_client(self): - """ - Return all the clients with state CLIENT_STATE.CONSULTING - """ - self.get_client_by_state(CLIENT_STATE.CONSULTING) - def get_idle_client(self): """ Return all the clients with state CLIENT_STATE.IDLE @@ -187,13 +162,13 @@ def init_sampler(self): client_info=self._info_by_key ) - def sample(self, size, perturb=False): + def sample(self, size, perturb=False, change_state=True): """Sample idle clients with the specific sampler Args: size: the number of sample size perturb: if perturb the clients before sampling - + change_state: if change the state into CLIENT_STATE.WORKING Returns: index of the sampled clients """ @@ -205,6 +180,7 @@ def sample(self, size, perturb=False): clients_idle = self.get_idle_client() # Sampling clients_sampled = self.sampler.sample(clients_idle, size, perturb) - # Change state for the sampled clients - self.change_state(clients_sampled, CLIENT_STATE.WORKING) + if change_state: + # Change state for the sampled clients + self.change_state(clients_sampled, CLIENT_STATE.WORKING) return clients_sampled diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index a9e9ed451..dfba9af4f 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -541,7 +541,7 @@ def merge_eval_results_from_all_clients(self): for client_id in eval_msg_buffer: if eval_msg_buffer[client_id] is None: continue - if client_id in self.unseen_clients_id: + if client_id in self.client_manager._id_client_unseen: eval_res_unseen_clients.append(eval_msg_buffer[client_id]) else: eval_res_participated_clients.append( @@ -616,10 +616,6 @@ def broadcast_model_para(self, else: receiver = list(self.comm_manager.neighbors.keys()) - if msg_type == 'model_para': - # Training - self.client_manager.change_state(receiver, CLIENT_STATE.WORKING) - # Inject noise if self._noise_injector is not None and msg_type == 'model_para': # Inject noise only when broadcast parameters @@ -740,6 +736,7 @@ def trigger_for_start(self): # init sampler within the client manager after finishing exchanging information self.client_manager.init_sampler() # Block unseen clients from training + self.client_manager.block_unseen_client() # change the deadline if the asyn.aggregator is `time up` @@ -826,8 +823,7 @@ def eval(self): self.check_and_save() else: # Preform evaluation in clients - self.broadcast_model_para(msg_type='evaluate', - filter_unseen_clients=False) + self.broadcast_model_para(msg_type='evaluate') def callback_funcs_model_para(self, message: Message): """ @@ -905,11 +901,12 @@ def callback_funcs_for_join_in(self, message: Message): else: sender, address = message.sender, message.content # Register client in client_manager - register_id = self.client_manager.register_client() - if int(sender) == -1: # assign number to client + register_id = self.client_manager.register_client(sender) + + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) + + if register_id != sender: # assign number to client sender = register_id - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) self.comm_manager.send( Message(msg_type='assign_client_id', sender=self.ID, @@ -917,9 +914,6 @@ def callback_funcs_for_join_in(self, message: Message): state=self.state, timestamp=self.cur_timestamp, content=str(sender))) - else: - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) if len(self._cfg.federate.join_in_info) != 0: self.comm_manager.send( From cfa6f2713ed564511b0af8c15c2d77856664ee05 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Thu, 22 Sep 2022 17:11:39 +0800 Subject: [PATCH 7/7] bug fix --- .../worker_as_attacker/server_attacker.py | 28 +++++----------- federatedscope/autotune/fedex/server.py | 33 +++++++++---------- .../core/auxiliaries/sampler_builder.py | 3 +- federatedscope/core/workers/server.py | 11 +++++-- federatedscope/gfl/fedsageplus/worker.py | 21 +++++++----- 5 files changed, 46 insertions(+), 50 deletions(-) diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 651555740..0b7ab2dfc 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -1,4 +1,4 @@ -from federatedscope.core.auxiliaries.enums import STAGE +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.workers import Server from federatedscope.core.message import Message @@ -47,8 +47,7 @@ def __init__(self, def broadcast_model_para(self, msg_type='model_para', - sample_client_num=-1, - filter_unseen_clients=True): + sample_client_num=-1): """ To broadcast the message to all clients or sampled clients @@ -57,18 +56,8 @@ def broadcast_model_para(self, sample_client_num: the number of sampled clients in the broadcast behavior. And sample_client_num = -1 denotes to broadcast to all the clients. - filter_unseen_clients: whether filter out the unseen clients that - do not contribute to FL process by training on their local - data and uploading their local model update. The splitting is - useful to check participation generalization gap in [ICLR'22, - What Do We Mean by Generalization in Federated Learning?] - You may want to set it to be False when in evaluation stage """ - if filter_unseen_clients: - # to filter out the unseen clients when sampling - self.sampler.change_state(self.unseen_clients_id, 'unseen') - if sample_client_num > 0: # only activated at training process attacker_id = self._cfg.attack.attacker_id setting = self._cfg.attack.setting @@ -159,10 +148,6 @@ def broadcast_model_para(self, for idx in range(self.model_num): self.aggregators[idx].reset() - if filter_unseen_clients: - # restore the state of the unseen clients within sampler - self.sampler.change_state(self.unseen_clients_id, 'seen') - class PassiveServer(Server): ''' @@ -264,7 +249,9 @@ def callback_funcs_model_para(self, message: Message): return 'finish' round, sender, content = message.state, message.sender, message.content - self.sampler.change_state(sender, 'idle') + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) + if round not in self.msg_buffer[STAGE.TRAIN]: self.msg_buffer[STAGE.TRAIN][round] = dict() self.msg_buffer[STAGE.TRAIN][round][sender] = content @@ -342,7 +329,10 @@ def callback_funcs_model_para(self, message: Message): return 'finish' round, sender, content = message.state, message.sender, message.content - self.sampler.change_state(sender, 'idle') + + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) + if round not in self.msg_buffer[STAGE.TRAIN]: self.msg_buffer[STAGE.TRAIN][round] = dict() self.msg_buffer[STAGE.TRAIN][round][sender] = content diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py index 17b998c30..4fe515dcb 100644 --- a/federatedscope/autotune/fedex/server.py +++ b/federatedscope/autotune/fedex/server.py @@ -8,7 +8,7 @@ from numpy.linalg import norm from scipy.special import logsumexp -from federatedscope.core.auxiliaries.enums import STAGE +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.message import Message from federatedscope.core.workers import Server from federatedscope.core.auxiliaries.utils import merge_dict @@ -147,23 +147,18 @@ def sample(self): def broadcast_model_para(self, msg_type='model_para', - sample_client_num=-1, - filter_unseen_clients=True): + sample_client_num=-1): """ To broadcast the message to all clients or sampled clients """ - if filter_unseen_clients: - # to filter out the unseen clients when sampling - self.sampler.change_state(self.unseen_clients_id, 'unseen') if sample_client_num > 0: - receiver = self.sampler.sample(size=sample_client_num) + receiver = self.client_manager.sample(size=sample_client_num) else: # broadcast to all clients receiver = list(self.comm_manager.neighbors.keys()) - if msg_type == 'model_para': - self.sampler.change_state(receiver, 'working') + # Inject noise if self._noise_injector is not None and msg_type == 'model_para': # Inject noise only when broadcast parameters for model_idx_i in range(len(self.models)): @@ -197,13 +192,11 @@ def broadcast_model_para(self, for idx in range(self.model_num): self.aggregators[idx].reset() - if filter_unseen_clients: - # restore the state of the unseen clients within sampler - self.sampler.change_state(self.unseen_clients_id, 'seen') def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content - self.sampler.change_state(sender, 'idle') + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) # For a new round if round not in self.msg_buffer[STAGE.TRAIN].keys(): self.msg_buffer[STAGE.TRAIN][round] = dict() @@ -285,7 +278,7 @@ def update_policy(self, feedbacks): self._trace['mle'][-1])) def check_and_move_on(self, - check_eval_result=False, + stage, min_received_num=None): """ To check the message_buffer, when enough messages are receiving, @@ -296,14 +289,14 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if check_eval_result: + if stage == STAGE.EVAL: min_received_num = len(list(self.comm_manager.neighbors.keys())) move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, check_eval_result): + if self.check_buffer(self.state, min_received_num, stage): - if not check_eval_result: # in the training process + if stage == STAGE.TRAIN: # in the training process mab_feedbacks = list() # Get all the message train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] @@ -375,12 +368,16 @@ def check_and_move_on(self, 'evaluation.') self.eval() - else: # in the evaluation process + elif stage == STAGE.EVAL: # in the evaluation process # Get all the message & aggregate formatted_eval_res = self.merge_eval_results_from_all_clients() self.history_results = merge_dict(self.history_results, formatted_eval_res) self.check_and_save() + + else: + pass + else: move_on_flag = False diff --git a/federatedscope/core/auxiliaries/sampler_builder.py b/federatedscope/core/auxiliaries/sampler_builder.py index f5492fad3..dc46a1581 100644 --- a/federatedscope/core/auxiliaries/sampler_builder.py +++ b/federatedscope/core/auxiliaries/sampler_builder.py @@ -12,8 +12,7 @@ def get_sampler(sample_strategy, if sample_strategy == 'uniform': return UniformSampler() elif sample_strategy == 'group': - return GroupSampler(client_num=client_num, - client_info=client_info['client_resource'], + return GroupSampler(client_info=client_info['client_resource'], bins=bins) else: return None \ No newline at end of file diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index dfba9af4f..cf09102be 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -329,10 +329,13 @@ def check_and_move_on(self, 'evaluation.') self.eval() - else: + elif stage == STAGE.EVAL: # Receiving enough feedback in the evaluation process self._merge_and_format_eval_results() + else: + pass + else: move_on_flag = False @@ -903,10 +906,10 @@ def callback_funcs_for_join_in(self, message: Message): # Register client in client_manager register_id = self.client_manager.register_client(sender) - self.comm_manager.add_neighbors(neighbor_id=sender, address=address) - + # TODO: maybe we shouldn't support user-defined ID if register_id != sender: # assign number to client sender = register_id + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) self.comm_manager.send( Message(msg_type='assign_client_id', sender=self.ID, @@ -914,6 +917,8 @@ def callback_funcs_for_join_in(self, message: Message): state=self.state, timestamp=self.cur_timestamp, content=str(sender))) + else: + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) if len(self._cfg.federate.join_in_info) != 0: self.comm_manager.send( diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py index 442e31249..dd96fcb76 100644 --- a/federatedscope/gfl/fedsageplus/worker.py +++ b/federatedscope/gfl/fedsageplus/worker.py @@ -4,7 +4,7 @@ from torch_geometric.loader import NeighborSampler -from federatedscope.core.auxiliaries.enums import STAGE +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -68,22 +68,27 @@ def callback_funcs_for_join_in(self, message: Message): assert key in info self.join_in_info[sender] = info logger.info('Server: Client #{:d} has joined in !'.format(sender)) + + # Set the client status as idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) else: - self.join_in_client_num += 1 sender, address = message.sender, message.content - if int(sender) == -1: # assign number to client - sender = self.join_in_client_num - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) + + # Register client in client_manager + register_id = self.client_manager.register_client(sender) + + if register_id != sender: # assign number to client + sender = register_id + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) self.comm_manager.send( Message(msg_type='assign_client_id', sender=self.ID, receiver=[sender], state=self.state, + timestamp=self.cur_timestamp, content=str(sender))) else: - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) if len(self._cfg.federate.join_in_info) != 0: self.comm_manager.send(