diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 226568242..651555740 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -1,3 +1,4 @@ +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.workers import Server from federatedscope.core.message import Message @@ -244,15 +245,15 @@ def _reconstruct(self, model_para, batch_size, state, sender): def run_reconstruct(self, state_list=None, sender_list=None): if state_list is None: - state_list = self.msg_buffer['train'].keys() + state_list = self.msg_buffer[STAGE.TRAIN].keys() # After FL running, using gradient based reconstruction method to # recover client's private training data for state in state_list: if sender_list is None: - sender_list = self.msg_buffer['train'][state].keys() + sender_list = self.msg_buffer[STAGE.TRAIN][state].keys() for sender in sender_list: - content = self.msg_buffer['train'][state][sender] + content = self.msg_buffer[STAGE.TRAIN][state][sender] self._reconstruct(model_para=content[1], batch_size=content[0], state=state, @@ -264,9 +265,9 @@ def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content self.sampler.change_state(sender, 'idle') - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = content + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() + self.msg_buffer[STAGE.TRAIN][round][sender] = content # run reconstruction before the clear of self.msg_buffer @@ -284,7 +285,7 @@ def callback_funcs_model_para(self, message: Message): name='image_state_{}_client_{}.png'.format( message.state, message.sender)) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) class 
PassivePIAServer(Server): @@ -342,9 +343,9 @@ def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content self.sampler.change_state(sender, 'idle') - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = content + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() + self.msg_buffer[STAGE.TRAIN][round][sender] = content # collect the updates self.pia_attacker.collect_updates( @@ -359,7 +360,7 @@ def callback_funcs_model_para(self, message: Message): # TODO: put this line to `check_and_move_on` # currently, no way to know the latest `sender` self.aggregator.inc(content) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) if self.state == self.total_round_num: self.pia_attacker.train_property_classifier() diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py index fac5e6488..2946ce0de 100644 --- a/federatedscope/autotune/fedex/server.py +++ b/federatedscope/autotune/fedex/server.py @@ -8,6 +8,7 @@ from numpy.linalg import norm from scipy.special import logsumexp +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers import Server from federatedscope.core.auxiliaries.utils import merge_dict @@ -204,15 +205,15 @@ def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content self.sampler.change_state(sender, 'idle') # For a new round - if round not in self.msg_buffer['train'].keys(): - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN].keys(): + self.msg_buffer[STAGE.TRAIN][round] = dict() - self.msg_buffer['train'][round][sender] = content + self.msg_buffer[STAGE.TRAIN][round][sender] = content if self._cfg.federate.online_aggr: 
self.aggregator.inc(tuple(content[0:2])) - return self.check_and_move_on() + return self.check_and_move_on(stage=STAGE.TRAIN) def update_policy(self, feedbacks): """Update the policy. This implementation is borrowed from the @@ -284,7 +285,7 @@ def update_policy(self, feedbacks): self._trace['mle'][-1])) def check_and_move_on(self, - check_eval_result=False, + stage, min_received_num=None): """ To check the message_buffer, when enough messages are receiving, @@ -295,17 +296,17 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if check_eval_result: + if stage == STAGE.EVAL: min_received_num = len(list(self.comm_manager.neighbors.keys())) move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, check_eval_result): + if self.check_buffer(self.state, min_received_num, stage): - if not check_eval_result: # in the training process + if stage == STAGE.TRAIN: # in the training process mab_feedbacks = list() # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -363,7 +364,7 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() self.broadcast_model_para( msg_type='model_para', diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/auxiliaries/enums.py index ffaf0808a..53d9d1ec9 100644 --- a/federatedscope/core/auxiliaries/enums.py +++ b/federatedscope/core/auxiliaries/enums.py @@ -11,6 +11,12 @@ class MODE: FINETUNE = 'finetune' +class STAGE: + TRAIN = 'train' + EVAL = 'eval' + 
CONSULT = 'consult' + + class TRIGGER: ON_FIT_START = 'on_fit_start' ON_EPOCH_START = 'on_epoch_start' diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index 2dbc4bdd7..9f47aa60c 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -3,6 +3,7 @@ import sys import pickle +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.communication import StandaloneCommManager, \ gRPCCommManager @@ -225,12 +226,12 @@ def callback_funcs_for_model_para(self, message: Message): # A fragment of the shared secret state, content, timestamp = message.state, message.content, \ message.timestamp - self.msg_buffer['train'][state].append(content) + self.msg_buffer[STAGE.TRAIN][state].append(content) - if len(self.msg_buffer['train'] + if len(self.msg_buffer[STAGE.TRAIN] [state]) == self._cfg.federate.client_num: # Check whether the received fragments are enough - model_list = self.msg_buffer['train'][state] + model_list = self.msg_buffer[STAGE.TRAIN][state] sample_size, first_aggregate_model_para = model_list[0] single_model_case = True if isinstance(first_aggregate_model_para, list): @@ -355,7 +356,7 @@ def callback_funcs_for_model_para(self, message: Message): single_model_case else \ [model_para_list[frame_idx] for model_para_list in model_para_list_all] - self.msg_buffer['train'][self.state] = [(sample_size, + self.msg_buffer[STAGE.TRAIN][self.state] = [(sample_size, content_frame)] else: if self._cfg.asyn.use: diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 38ceb6435..226224de2 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -16,6 +16,7 @@ from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ merge_param_dict from federatedscope.core.auxiliaries.trainer_builder import get_trainer +from 
federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.secret_sharing import AdditiveSecretSharing logger = logging.getLogger(__name__) @@ -251,9 +252,10 @@ def run(self): logger.info('Time out at the training round #{}'.format( self.state)) move_on_flag_eval = self.check_and_move_on( - min_received_num=min_received_num, - check_eval_result=True) + stage=STAGE.EVAL, + min_received_num=min_received_num) move_on_flag = self.check_and_move_on( + stage=STAGE.TRAIN, min_received_num=min_received_num) if not move_on_flag and not move_on_flag_eval: num_failure += 1 @@ -271,8 +273,8 @@ def run(self): f'Round #{self.state}) for {num_failure} time ' f'-------------') # TODO: Clean the msg_buffer - if self.state in self.msg_buffer['train']: - self.msg_buffer['train'][self.state].clear() + if self.state in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][self.state].clear() self.broadcast_model_para( msg_type='model_para', @@ -284,7 +286,7 @@ def run(self): self.terminate(msg_type='finish') def check_and_move_on(self, - check_eval_result=False, + stage, min_received_num=None): """ To check the message_buffer. When enough messages are receiving, @@ -292,8 +294,8 @@ def check_and_move_on(self, the next training round) would be triggered. Arguments: - check_eval_result (bool): If True, check the message buffer for - evaluation; and check the message buffer for training otherwise. 
+ stage (str): The type of checked message buffer, chosen from STAGE.TRAIN, STAGE.EVAL and STAGE.CONSULT + min_received_num (int): The minimum number of received messages """ if min_received_num is None: if self._cfg.asyn.use: @@ -302,7 +304,7 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if check_eval_result and self._cfg.federate.mode.lower( + if stage == STAGE.EVAL and self._cfg.federate.mode.lower( ) == "standalone": # in evaluation stage and standalone simulation mode, we assume # strong synchronization that receives responses from all clients @@ -310,8 +312,8 @@ def check_and_move_on(self, move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, check_eval_result): - if not check_eval_result: + if self.check_buffer(self.state, min_received_num, stage): + if stage == STAGE.TRAIN: # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() @@ -329,8 +331,8 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() - self.msg_buffer['train'][self.state] = dict() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state] = dict() self.staled_msg_buffer.clear() # Start a new training round self._start_new_training_round(aggregated_num) @@ -340,9 +342,12 @@ def check_and_move_on(self, 'evaluation.') self.eval() - else: + elif stage == STAGE.EVAL: # Receiving enough feedback in the evaluation process self._merge_and_format_eval_results() + else: + # TODO: consultation + pass else: move_on_flag = False @@ -393,8 +398,8 @@ def check_and_save(self): # Clean the clients evaluation msg buffer if not self._cfg.federate.make_global_eval: - round = 
max(self.msg_buffer['eval'].keys()) - self.msg_buffer['eval'][round].clear() + round = max(self.msg_buffer[STAGE.EVAL].keys()) + self.msg_buffer[STAGE.EVAL][round].clear() if self.state == self.total_round_num: # break out the loop for distributed mode @@ -404,7 +409,7 @@ def _perform_federated_aggregation(self): """ Perform federated aggregation and update the global model """ - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -523,8 +528,8 @@ def save_client_eval_results(self): :return: """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + round = max(self.msg_buffer[STAGE.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[STAGE.EVAL][round] with open(os.path.join(self._cfg.outdir, "eval_results.log"), "a") as outfile: @@ -545,8 +550,8 @@ def merge_eval_results_from_all_clients(self): :returns: the formatted merged results """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + round = max(self.msg_buffer[STAGE.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[STAGE.EVAL][round] eval_res_participated_clients = [] eval_res_unseen_clients = [] for client_id in eval_msg_buffer: @@ -688,33 +693,21 @@ def broadcast_client_address(self): def check_buffer(self, cur_round, min_received_num, - check_eval_result=False): - """ - To check the message buffer + buffer_type): + """Check if the message buffer receives enough messages Arguments: - cur_round (int): The current round number - min_received_num (int): The minimal number of the receiving messages - check_eval_result (bool): To check training results for evaluation - results - :returns: Whether enough messages have been received or not - :rtype: bool - """ - - if check_eval_result: - if 'eval' not in self.msg_buffer.keys() or len( - 
self.msg_buffer['eval'].keys()) == 0: - return False - - buffer = self.msg_buffer['eval'] - cur_round = max(buffer.keys()) - cur_buffer = buffer[cur_round] - return len(cur_buffer) >= min_received_num - else: - if cur_round not in self.msg_buffer['train']: - cur_buffer = dict() - else: - cur_buffer = self.msg_buffer['train'][cur_round] + cur_round (int): The current round number + min_received_num (int): The minimal number of the receiving messages + buffer_type (str): Which field to check, chosen from STAGE.TRAIN, STAGE.EVAL and STAGE.CONSULT + + Return: + Whether enough messages have been received or not + """ + buffer = self.msg_buffer.get(buffer_type, dict()) + + if buffer_type == STAGE.TRAIN: + cur_buffer = buffer.get(cur_round, dict()) if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': if self.cur_timestamp >= self.deadline_for_cur_round and len( cur_buffer) + len(self.staled_msg_buffer) == 0: @@ -738,6 +731,18 @@ def check_buffer(self, return len(cur_buffer)+len(self.staled_msg_buffer) >= \ min_received_num + elif buffer_type == STAGE.EVAL: + # Evaluation won't block the training process + cur_buffer = buffer.get(max(buffer.keys()), dict()) + return len(cur_buffer) >= min_received_num + + elif buffer_type == STAGE.CONSULT: + cur_buffer = buffer.get(cur_round, dict()) + return len(cur_buffer) >= min_received_num + + else: + raise NotImplementedError(f'Type of message buffer {buffer_type} is not implemented.') + def check_client_join_in(self): """ To check whether all the clients have joined in the FL course. 
@@ -802,7 +807,7 @@ def trigger_for_time_up(self, check_timestamp=None): return False self.cur_timestamp = self.deadline_for_cur_round - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) return True def terminate(self, msg_type='finish'): @@ -891,10 +896,10 @@ def callback_funcs_model_para(self, message: Message): self.cur_timestamp = timestamp if round == self.state: - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() # Save the messages in this round - self.msg_buffer['train'][round][sender] = content + self.msg_buffer[STAGE.TRAIN][round][sender] = content elif round >= self.state - self.staleness_toleration: # Save the staled messages self.staled_msg_buffer.append((round, sender, content)) @@ -906,7 +911,7 @@ def callback_funcs_model_para(self, message: Message): if self._cfg.federate.online_aggr: self.aggregator.inc(content) - move_on_flag = self.check_and_move_on() + move_on_flag = self.check_and_move_on(stage=STAGE.TRAIN) if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \ 'after_receiving': self.broadcast_model_para(msg_type='model_para', @@ -975,9 +980,9 @@ def callback_funcs_for_metrics(self, message: Message): sender = message.sender content = message.content - if round not in self.msg_buffer['eval'].keys(): - self.msg_buffer['eval'][round] = dict() + if round not in self.msg_buffer[STAGE.EVAL].keys(): + self.msg_buffer[STAGE.EVAL][round] = dict() - self.msg_buffer['eval'][round][sender] = content + self.msg_buffer[STAGE.EVAL][round][sender] = content - return self.check_and_move_on(check_eval_result=True) + return self.check_and_move_on(stage=STAGE.EVAL) diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py index f1812598d..db758aeeb 100644 --- a/federatedscope/gfl/fedsageplus/worker.py +++ b/federatedscope/gfl/fedsageplus/worker.py @@ -4,6 +4,7 @@ from 
torch_geometric.loader import NeighborSampler +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -106,39 +107,39 @@ def callback_funcs_gradient(self, message: Message): round, _, content = message.state, message.sender, message.content gen_grad, ID = content # For a new round - if round not in self.msg_buffer['train'].keys(): - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN].keys(): + self.msg_buffer[STAGE.TRAIN][round] = dict() self.grad_cnt += 1 # Sum up all grad from other client - if ID not in self.msg_buffer['train'][round]: - self.msg_buffer['train'][round][ID] = dict() + if ID not in self.msg_buffer[STAGE.TRAIN][round]: + self.msg_buffer[STAGE.TRAIN][round][ID] = dict() for key in gen_grad.keys(): - self.msg_buffer['train'][round][ID][key] = torch.FloatTensor( + self.msg_buffer[STAGE.TRAIN][round][ID][key] = torch.FloatTensor( gen_grad[key].cpu()) else: for key in gen_grad.keys(): - self.msg_buffer['train'][round][ID][key] += torch.FloatTensor( + self.msg_buffer[STAGE.TRAIN][round][ID][key] += torch.FloatTensor( gen_grad[key].cpu()) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) - def check_and_move_on(self, check_eval_result=False): + def check_and_move_on(self, stage, **kwargs): client_IDs = [i for i in range(1, self.client_num + 1)] - if check_eval_result: + if stage == STAGE.EVAL or stage == STAGE.CONSULT: # all clients are participating in evaluation minimal_number = self.client_num - else: + elif stage == STAGE.TRAIN: # sampled clients are participating in training minimal_number = self.sample_client_num # Transmit model and embedding to get gradient back if self.check_buffer( - self.state, self.client_num + self.state, self.client_num, STAGE.TRAIN ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\ % 2 == 0: 
# FedGen: we should wait for all messages - for sender in self.msg_buffer['train'][self.state]: - content = self.msg_buffer['train'][self.state][sender] + for sender in self.msg_buffer[STAGE.TRAIN][self.state]: + content = self.msg_buffer[STAGE.TRAIN][self.state][sender] gen_para, embedding, label = content receiver_IDs = client_IDs[:sender - 1] + client_IDs[sender:] self.comm_manager.send( @@ -153,12 +154,12 @@ def check_and_move_on(self, check_eval_result=False): # Sum up gradient client-wisely and send back if self.check_buffer( - self.state, self.client_num + self.state, self.client_num, STAGE.TRAIN ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\ % 2 == 1 and self.grad_cnt == self.client_num * ( self.client_num - 1): - for ID in self.msg_buffer['train'][self.state]: - grad = self.msg_buffer['train'][self.state][ID] + for ID in self.msg_buffer[STAGE.TRAIN][self.state]: + grad = self.msg_buffer[STAGE.TRAIN][self.state][ID] self.comm_manager.send( Message(msg_type='gradient', sender=self.ID, @@ -170,7 +171,7 @@ def check_and_move_on(self, check_eval_result=False): self.state += 1 if self.check_buffer( - self.state, self.client_num + self.state, self.client_num, STAGE.TRAIN ) and self.state == self._cfg.fedsageplus.fedgen_epoch: self.state += 1 # Setup Clf_trainer for each client @@ -181,12 +182,12 @@ def check_and_move_on(self, check_eval_result=False): state=self.state)) if self.check_buffer( - self.state, minimal_number, check_eval_result + self.state, minimal_number, stage ) and self.state >= self._cfg.fedsageplus.fedgen_epoch: - if not check_eval_result: # in the training process + if stage == STAGE.TRAIN: # in the training process # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] msg_list = list() for client_id in train_msg_buffer: msg_list.append(train_msg_buffer[client_id]) @@ -231,12 +232,15 @@ def check_and_move_on(self, 
check_eval_result=False): 'evaluation.') self.eval() - else: # in the evaluation process + elif stage == STAGE.EVAL: # in the evaluation process # Get all the message & aggregate formatted_eval_res = self.merge_eval_results_from_all_clients() self.history_results = merge_dict(self.history_results, formatted_eval_res) self.check_and_save() + else: + # TODO: consult stage + pass class FedSagePlusClient(Client): diff --git a/federatedscope/gfl/gcflplus/worker.py b/federatedscope/gfl/gcflplus/worker.py index 26a1d62df..5ecd4623a 100644 --- a/federatedscope/gfl/gcflplus/worker.py +++ b/federatedscope/gfl/gcflplus/worker.py @@ -3,6 +3,7 @@ import copy import numpy as np +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -44,7 +45,7 @@ def compute_update_norm(self, cluster): max_norm = -np.inf cluster_dWs = [] for key in cluster: - content = self.msg_buffer['train'][self.state][key] + content = self.msg_buffer[STAGE.TRAIN][self.state][key] _, model_para, client_dw, _ = content dW = {} for k in model_para.keys(): @@ -71,7 +72,7 @@ def check_and_move_on(self, check_eval_result=False): if not check_eval_result: # in the training process # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -135,7 +136,7 @@ def check_and_move_on(self, check_eval_result=False): for cluster in self.cluster_indices: msg_list = list() for key in cluster: - content = self.msg_buffer['train'][self.state - + content = self.msg_buffer[STAGE.TRAIN][self.state - 1][key] train_data_size, model_para, client_dw, \ convGradsNorm = content @@ -161,7 +162,7 @@ def check_and_move_on(self, check_eval_result=False): f'----------- Starting a new 
traininground(Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() else: # Final Evaluate