diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 226568242..0b7ab2dfc 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -1,3 +1,4 @@ +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.workers import Server from federatedscope.core.message import Message @@ -46,8 +47,7 @@ def __init__(self, def broadcast_model_para(self, msg_type='model_para', - sample_client_num=-1, - filter_unseen_clients=True): + sample_client_num=-1): """ To broadcast the message to all clients or sampled clients @@ -56,18 +56,8 @@ def broadcast_model_para(self, sample_client_num: the number of sampled clients in the broadcast behavior. And sample_client_num = -1 denotes to broadcast to all the clients. - filter_unseen_clients: whether filter out the unseen clients that - do not contribute to FL process by training on their local - data and uploading their local model update. The splitting is - useful to check participation generalization gap in [ICLR'22, - What Do We Mean by Generalization in Federated Learning?] - You may want to set it to be False when in evaluation stage """ - if filter_unseen_clients: - # to filter out the unseen clients when sampling - self.sampler.change_state(self.unseen_clients_id, 'unseen') - if sample_client_num > 0: # only activated at training process attacker_id = self._cfg.attack.attacker_id setting = self._cfg.attack.setting @@ -158,10 +148,6 @@ def broadcast_model_para(self, for idx in range(self.model_num): self.aggregators[idx].reset() - if filter_unseen_clients: - # restore the state of the unseen clients within sampler - self.sampler.change_state(self.unseen_clients_id, 'seen') - class PassiveServer(Server): ''' @@ -244,15 +230,15 @@ def _reconstruct(self, model_para, batch_size, state, sender): def run_reconstruct(self, state_list=None, sender_list=None): if state_list is None: - state_list = self.msg_buffer['train'].keys() + state_list = self.msg_buffer[STAGE.TRAIN].keys() # After FL running, using gradient based reconstruction method to # recover client's private training data for state in state_list: if sender_list is None: - sender_list = self.msg_buffer['train'][state].keys() + sender_list = self.msg_buffer[STAGE.TRAIN][state].keys() for sender in sender_list: - content = self.msg_buffer['train'][state][sender] + content = self.msg_buffer[STAGE.TRAIN][state][sender] self._reconstruct(model_para=content[1], batch_size=content[0], state=state, @@ -263,10 +249,12 @@ def callback_funcs_model_para(self, message: Message): return 'finish' round, sender, content = message.state, message.sender, message.content - self.sampler.change_state(sender, 'idle') - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = content + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) + + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() + self.msg_buffer[STAGE.TRAIN][round][sender] = content # run reconstruction before the clear of self.msg_buffer @@ -284,7 +272,7 @@ def callback_funcs_model_para(self, message: Message): name='image_state_{}_client_{}.png'.format( message.state, message.sender)) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) class PassivePIAServer(Server): @@ -341,10 +329,13 @@ def callback_funcs_model_para(self, message: Message): return 'finish' round, sender, content = message.state, message.sender, message.content - self.sampler.change_state(sender, 'idle') - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = content + + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) + + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() + self.msg_buffer[STAGE.TRAIN][round][sender] = content # collect the updates self.pia_attacker.collect_updates( @@ -359,7 +350,7 @@ def callback_funcs_model_para(self, message: Message): # TODO: put this line to `check_and_move_on` # currently, no way to know the latest `sender` self.aggregator.inc(content) - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) if self.state == self.total_round_num: self.pia_attacker.train_property_classifier() diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py index fac5e6488..4fe515dcb 100644 --- a/federatedscope/autotune/fedex/server.py +++ b/federatedscope/autotune/fedex/server.py @@ -8,6 +8,7 @@ from numpy.linalg import norm from scipy.special import logsumexp +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.message import Message from federatedscope.core.workers import Server from federatedscope.core.auxiliaries.utils import merge_dict @@ -146,23 +147,18 @@ def sample(self): def broadcast_model_para(self, msg_type='model_para', - sample_client_num=-1, - filter_unseen_clients=True): + sample_client_num=-1): """ To broadcast the message to all clients or sampled clients """ - if filter_unseen_clients: - # to filter out the unseen clients when sampling - self.sampler.change_state(self.unseen_clients_id, 'unseen') if sample_client_num > 0: - receiver = self.sampler.sample(size=sample_client_num) + receiver = self.client_manager.sample(size=sample_client_num) else: # broadcast to all clients receiver = list(self.comm_manager.neighbors.keys()) - if msg_type == 'model_para': - self.sampler.change_state(receiver, 'working') + # Inject noise if self._noise_injector is not None and msg_type == 'model_para': # Inject noise only when broadcast parameters for model_idx_i in range(len(self.models)): @@ -196,18 +192,16 @@ def broadcast_model_para(self, for idx in range(self.model_num): self.aggregators[idx].reset() - if filter_unseen_clients: - # restore the state of the unseen clients within sampler - self.sampler.change_state(self.unseen_clients_id, 'seen') def callback_funcs_model_para(self, message: Message): round, sender, content = message.state, message.sender, message.content - self.sampler.change_state(sender, 'idle') + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) # For a new round - if round not in self.msg_buffer['train'].keys(): - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN].keys(): + self.msg_buffer[STAGE.TRAIN][round] = dict() - self.msg_buffer['train'][round][sender] = content + self.msg_buffer[STAGE.TRAIN][round][sender] = content if self._cfg.federate.online_aggr: self.aggregator.inc(tuple(content[0:2])) @@ -284,7 +278,7 @@ def update_policy(self, feedbacks): self._trace['mle'][-1])) def check_and_move_on(self, - check_eval_result=False, + stage, min_received_num=None): """ To check the message_buffer, when enough messages are receiving, @@ -295,17 +289,17 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if check_eval_result: + if stage == STAGE.EVAL: min_received_num = len(list(self.comm_manager.neighbors.keys())) move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, check_eval_result): + if self.check_buffer(self.state, min_received_num, stage): - if not check_eval_result: # in the training process + if stage == STAGE.TRAIN: # in the training process mab_feedbacks = list() # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -363,7 +357,7 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() self.broadcast_model_para( msg_type='model_para', @@ -374,12 +368,16 @@ def check_and_move_on(self, 'evaluation.') self.eval() - else: # in the evaluation process + elif stage == STAGE.EVAL: # in the evaluation process # Get all the message & aggregate formatted_eval_res = self.merge_eval_results_from_all_clients() self.history_results = merge_dict(self.history_results, formatted_eval_res) self.check_and_save() + + else: + pass + else: move_on_flag = False diff --git a/federatedscope/core/auxiliaries/enums.py b/federatedscope/core/auxiliaries/enums.py index ffaf0808a..024004189 100644 --- a/federatedscope/core/auxiliaries/enums.py +++ b/federatedscope/core/auxiliaries/enums.py @@ -1,4 +1,14 @@ -class MODE: +class BasicEnum(object): + @classmethod + def assert_value(cls, value): + """ + Check if the **value** is legal for the given class (If the value equals one of the class attributes) + """ + if not value in [v for k, v in cls.__dict__.items() if not k.startswith('__')]: + raise ValueError(f"Value {value} is not in {cls.__name__}.") + + +class MODE(BasicEnum): """ Note: @@ -11,7 +21,13 @@ class MODE: FINETUNE = 'finetune' -class TRIGGER: +class STAGE(BasicEnum): + TRAIN = 'train' + EVAL = 'eval' + CONSULT = 'consult' + + +class TRIGGER(BasicEnum): ON_FIT_START = 'on_fit_start' ON_EPOCH_START = 'on_epoch_start' ON_BATCH_START = 'on_batch_start' @@ -30,8 +46,16 @@ def contains(cls, item): ] -class LIFECYCLE: +class LIFECYCLE(BasicEnum): ROUTINE = 'routine' EPOCH = 'epoch' BATCH = 'batch' NONE = None + + +class CLIENT_STATE(BasicEnum): + OFFLINE = -1 # not join in + CONSULTING = 2 # join in and is consulting + IDLE = 1 # join in but not working, available for training + WORKING = 0 # join in and is working + SIDELINE = 0 # join in but won't participate in federated training diff --git a/federatedscope/core/auxiliaries/sampler_builder.py b/federatedscope/core/auxiliaries/sampler_builder.py index 0b7d2ff44..dc46a1581 100644 --- a/federatedscope/core/auxiliaries/sampler_builder.py +++ b/federatedscope/core/auxiliaries/sampler_builder.py @@ -5,16 +5,14 @@ logger = logging.getLogger(__name__) -def get_sampler(sample_strategy='uniform', +def get_sampler(sample_strategy, client_num=None, client_info=None, bins=10): if sample_strategy == 'uniform': - return UniformSampler(client_num=client_num) + return UniformSampler() elif sample_strategy == 'group': - return GroupSampler(client_num=client_num, - client_info=client_info, + return GroupSampler(client_info=client_info['client_resource'], bins=bins) else: - raise ValueError( - f"The sample strategy {sample_strategy} has not been provided.") + return None \ No newline at end of file diff --git a/federatedscope/core/sampler.py b/federatedscope/core/sampler.py index 6438cf822..7cce752c9 100644 --- a/federatedscope/core/sampler.py +++ b/federatedscope/core/sampler.py @@ -5,56 +5,31 @@ class Sampler(ABC): """ - The strategies of sampling clients for each training round - - Arguments: - client_state: a dict to manager the state of clients (idle or busy) + The strategy of sampling """ - def __init__(self, client_num): - self.client_state = np.asarray([1] * (client_num + 1)) - # Set the state of server (index=0) to 'working' - self.client_state[0] = 0 + def __init__(self): + pass @abstractmethod - def sample(self, size): + def sample(self, *args, **kwargs): raise NotImplementedError - def change_state(self, indices, state): - """ - To modify the state of clients (idle or working) - """ - if isinstance(indices, list) or isinstance(indices, np.ndarray): - all_idx = indices - else: - all_idx = [indices] - for idx in all_idx: - if state in ['idle', 'seen']: - self.client_state[idx] = 1 - elif state in ['working', 'unseen']: - self.client_state[idx] = 0 - else: - raise ValueError( - f"The state of client should be one of " - f"['idle', 'working', 'unseen], but got {state}") - class UniformSampler(Sampler): """ - To uniformly sample the clients from all the idle clients + A stateless sampler that samples items uniformly """ - def __init__(self, client_num): - super(UniformSampler, self).__init__(client_num) + def __init__(self): + super(UniformSampler, self).__init__() - def sample(self, size): + def sample(self, client_idle, size, *args, **kwargs): """ To sample clients """ - idle_clients = np.nonzero(self.client_state)[0] - sampled_clients = np.random.choice(idle_clients, - size=size, - replace=False).tolist() - self.change_state(sampled_clients, 'working') - return sampled_clients + sampled_items = np.random.choice(client_idle, + size=size, + replace=False).tolist() + return sampled_items class GroupSampler(Sampler): @@ -62,23 +37,12 @@ class GroupSampler(Sampler): To grouply sample the clients based on their responsiveness (or other client information of the clients) """ - def __init__(self, client_num, client_info, bins=10): - super(GroupSampler, self).__init__(client_num) + def __init__(self, client_info, bins=10): + super(GroupSampler, self).__init__() self.bins = bins - self.update_client_info(client_info) + self.client_info = client_info self.candidate_iterator = self.partition() - def update_client_info(self, client_info): - """ - To update the client information - """ - self.client_info = np.asarray( - [1.0] + [x for x in client_info - ]) # client_info[0] is preversed for the server - assert len(self.client_info) == len( - self.client_state - ), "The first dimension of client_info is mismatched with client_num" - def partition(self): """ To partition the clients into groups according to the client @@ -87,7 +51,9 @@ def partition(self): Arguments: :returns: a iteration of candidates """ - sorted_index = np.argsort(self.client_info) + # sort client_info by xx + sorted_index = sorted(self.client_info.keys(), key=lambda x: self.client_info[x]) + # bin的长度 num_per_bins = np.int(len(sorted_index) / self.bins) # grouped clients @@ -105,11 +71,14 @@ def permutation(self): return iter(candidates) - def sample(self, size, shuffle=False): + def sample(self, clients_idle, size, perturb=False): """ To sample clients """ - if shuffle: + if self.candidate_iterator is None: + self.partition() + + if perturb: self.candidate_iterator = self.permutation() sampled_clients = list() @@ -117,15 +86,14 @@ def sample(self, size, shuffle=False): # To find an idle client while True: try: - item = next(self.candidate_iterator) + client_id = next(self.candidate_iterator) except StopIteration: self.candidate_iterator = self.permutation() - item = next(self.candidate_iterator) + client_id = next(self.candidate_iterator) - if self.client_state[item] == 1: + if client_id in clients_idle: break - sampled_clients.append(item) - self.change_state(item, 'working') + sampled_clients.append(client_id) return sampled_clients diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index 2dbc4bdd7..9f47aa60c 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -3,6 +3,7 @@ import sys import pickle +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.communication import StandaloneCommManager, \ gRPCCommManager @@ -225,12 +226,12 @@ def callback_funcs_for_model_para(self, message: Message): # A fragment of the shared secret state, content, timestamp = message.state, message.content, \ message.timestamp - self.msg_buffer['train'][state].append(content) + self.msg_buffer[STAGE.TRAIN][state].append(content) - if len(self.msg_buffer['train'] + if len(self.msg_buffer[STAGE.TRAIN] [state]) == self._cfg.federate.client_num: # Check whether the received fragments are enough - model_list = self.msg_buffer['train'][state] + model_list = self.msg_buffer[STAGE.TRAIN][state] sample_size, first_aggregate_model_para = model_list[0] single_model_case = True if isinstance(first_aggregate_model_para, list): @@ -355,7 +356,7 @@ def callback_funcs_for_model_para(self, message: Message): single_model_case else \ [model_para_list[frame_idx] for model_para_list in model_para_list_all] - self.msg_buffer['train'][self.state] = [(sample_size, + self.msg_buffer[STAGE.TRAIN][self.state] = [(sample_size, content_frame)] else: if self._cfg.asyn.use: diff --git a/federatedscope/core/workers/client_manager.py b/federatedscope/core/workers/client_manager.py new file mode 100644 index 000000000..869b1c383 --- /dev/null +++ b/federatedscope/core/workers/client_manager.py @@ -0,0 +1,186 @@ +import collections + +import numpy as np + +from federatedscope.core.auxiliaries.sampler_builder import get_sampler +from federatedscope.core.auxiliaries.enums import CLIENT_STATE + +import logging + +logger = logging.getLogger(__name__) + + +class ClientManager(object): + def __init__(self, + num_client, + sample_strategy, + id_client_unseen=[]): + self._num_client_join = 0 + self._num_client_total = num_client + + self._ids_client = [] + + # the unseen clients indicate the ones that do not contribute to FL + # process by training on their local data and uploading their local + # model update. The splitting is useful to check participation + # generalization gap in + # [ICLR'22, What Do We Mean by Generalization in Federated Learning?] + self._num_client_unseen = len(id_client_unseen) + self._id_client_unseen = id_client_unseen + + # TODO: achieve by a two-leveled index rather than saving twice + # Used to maintain the information collected from the clients + self._info_by_client = collections.defaultdict(dict) + self._info_by_key = collections.defaultdict(dict) + + # Record the state of the clients (client_id counts from 1) + self._state_client = dict() + + self.sampler = None + self.sample_strategy = sample_strategy + + def __assert_client(self, client_id): + """ + Check if the client_id is legal. + """ + if client_id in self._ids_client: + pass + else: + raise IndexError(f"Client ID {client_id} doesn't exist.") + + def get_all_client(self): + return self._ids_client + + @property + def join_in_client_num(self): + return self._num_client_join + + def register_client(self, sender): + """ + Register new client into client manager + """ + self._num_client_join += 1 + # Assign new id if necessary + if sender == -1: + register_id = self._num_client_join + else: + register_id = sender + + # Record the client + self._ids_client.append(register_id) + self._state_client[register_id] = CLIENT_STATE.IDLE + + return register_id + + def block_unseen_client(self): + """ + Set the state of the client as CLIENT_STATE.SIDELINE + """ + if self._num_client_unseen > 0: + self.change_state(self._num_client_unseen, CLIENT_STATE.SIDELINE) + + def check_client_join_in(self): + """ + Check if enough clients has joined in. + """ + return self._num_client_join == self._num_client_total + + def check_client_info(self): + """ + Check if enough information is collected from the clients + """ + return len(self._info_by_client) == self._num_client_total + + def update_client_info(self, client_id, info: dict): + """ + Update client information in the manager. + """ + if client_id in self._info_by_client: + logger.info(f"Information of Client #{client_id} is updated by {info}.") + self._info_by_client.update(info) + + for k, v in info.items(): + self._info_by_key[k][client_id] = v + + def del_client_info(self, client_id): + """ + Delete the client information. + """ + if client_id in self._info_by_client: + del self._info_by_client[client_id] + for k in self._info_by_client: + if client_id in self._info_by_client[k]: + del self._info_by_client[k][client_id] + + def get_info_by_client(self, client_id): + """ + Get the client information by client_id + """ + return self._info_by_client.get(client_id, None) + + def get_info_by_key(self, key): + """ + Get the client information by key + """ + return self._info_by_client.get(key, None) + + def change_state(self, indices, state): + """ + To modify the state of clients (idle or working) + """ + CLIENT_STATE.assert_value(state) + + if isinstance(indices, list) or isinstance(indices, np.ndarray): + client_idx = indices + else: + client_idx = [indices] + + for client_id in client_idx: + self._state_client[client_id] = state + + def get_client_by_state(self, state): + CLIENT_STATE.assert_value(state) + + return [client_id for client_id, client_state in self._state_client.items() if client_state == state] + + def get_idle_client(self): + """ + Return all the clients with state CLIENT_STATE.IDLE + """ + return self.get_client_by_state(CLIENT_STATE.IDLE) + + def init_sampler(self): + """ + Considering the sampling strategy may need client information, we should + initialize it after finishing client information collection, or + re-initialize it after updating client information. + """ + # To sample clients during training + self.sampler = get_sampler( + sample_strategy=self.sample_strategy, + client_num=self._num_client_total, + client_info=self._info_by_key + ) + + def sample(self, size, perturb=False, change_state=True): + """Sample idle clients with the specific sampler + + Args: + size: the number of sample size + perturb: if perturb the clients before sampling + change_state: if change the state into CLIENT_STATE.WORKING + Returns: + index of the sampled clients + """ + if self.sampler is None: + # When we call the function sample, we default that all client information have been collected. + self.init_sampler() + + # Obtain the idle clients + clients_idle = self.get_idle_client() + # Sampling + clients_sampled = self.sampler.sample(clients_idle, size, perturb) + if change_state: + # Change state for the sampled clients + self.change_state(clients_sampled, CLIENT_STATE.WORKING) + return clients_sampled diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 38ceb6435..cf09102be 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -11,11 +11,13 @@ from federatedscope.core.communication import StandaloneCommManager, \ gRPCCommManager from federatedscope.core.workers import Worker +from federatedscope.core.workers.client_manager import ClientManager from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator from federatedscope.core.auxiliaries.sampler_builder import get_sampler from federatedscope.core.auxiliaries.utils import merge_dict, Timeout, \ merge_param_dict from federatedscope.core.auxiliaries.trainer_builder import get_trainer +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.secret_sharing import AdditiveSecretSharing logger = logging.getLogger(__name__) @@ -48,7 +50,7 @@ def __init__(self, total_round_num=10, device='cpu', strategy=None, - unseen_clients_id=None, + id_clients_unseen=[], **kwargs): super(Server, self).__init__(ID, state, config, model, strategy) @@ -123,33 +125,19 @@ def __init__(self, ]) # Initialize the number of joined-in clients + self.client_manager = ClientManager( + client_num, + sample_strategy=self._cfg.federate.sampler, + id_client_unseen=id_clients_unseen) + self._client_num = client_num self._total_round_num = total_round_num self.sample_client_num = int(self._cfg.federate.sample_client_num) - self.join_in_client_num = 0 self.join_in_info = dict() - # the unseen clients indicate the ones that do not contribute to FL - # process by training on their local data and uploading their local - # model update. The splitting is useful to check participation - # generalization gap in - # [ICLR'22, What Do We Mean by Generalization in Federated Learning?] - self.unseen_clients_id = [] if unseen_clients_id is None \ - else unseen_clients_id # Server state self.is_finish = False - # Sampler - if self._cfg.federate.sampler in ['uniform']: - self.sampler = get_sampler( - sample_strategy=self._cfg.federate.sampler, - client_num=self.client_num, - client_info=None) - else: - # Some type of sampler would be instantiated in trigger_for_start, - # since they need more information - self.sampler = None - # Current Timestamp self.cur_timestamp = 0 self.deadline_for_cur_round = 1 @@ -230,7 +218,7 @@ def run(self): """ # Begin: Broadcast model parameters and start to FL train - while self.join_in_client_num < self.client_num: + while self.client_manager.check_client_join_in(): msg = self.comm_manager.receive() self.msg_handlers[msg.msg_type](msg) @@ -251,9 +239,10 @@ def run(self): logger.info('Time out at the training round #{}'.format( self.state)) move_on_flag_eval = self.check_and_move_on( - min_received_num=min_received_num, - check_eval_result=True) + stage=STAGE.EVAL, + min_received_num=min_received_num) move_on_flag = self.check_and_move_on( + stage=STAGE.TRAIN, min_received_num=min_received_num) if not move_on_flag and not move_on_flag_eval: num_failure += 1 @@ -271,8 +260,8 @@ def run(self): f'Round #{self.state}) for {num_failure} time ' f'-------------') # TODO: Clean the msg_buffer - if self.state in self.msg_buffer['train']: - self.msg_buffer['train'][self.state].clear() + if self.state in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][self.state].clear() self.broadcast_model_para( msg_type='model_para', @@ -284,7 +273,7 @@ def run(self): self.terminate(msg_type='finish') def check_and_move_on(self, - check_eval_result=False, + stage, min_received_num=None): """ To check the message_buffer. When enough messages are receiving, @@ -292,8 +281,8 @@ def check_and_move_on(self, the next training round) would be triggered. Arguments: - check_eval_result (bool): If True, check the message buffer for - evaluation; and check the message buffer for training otherwise. + stage (str): The type of checked message buffer, chosen from MSGBUFFER.TRAIN, MSGBUFFER.EVAL and MSGBUFFER.CONSULT + min_received_num (int): The minimum number of received messages """ if min_received_num is None: if self._cfg.asyn.use: @@ -302,7 +291,7 @@ def check_and_move_on(self, min_received_num = self._cfg.federate.sample_client_num assert min_received_num <= self.sample_client_num - if check_eval_result and self._cfg.federate.mode.lower( + if stage == STAGE.EVAL and self._cfg.federate.mode.lower( ) == "standalone": # in evaluation stage and standalone simulation mode, we assume # strong synchronization that receives responses from all clients @@ -310,8 +299,8 @@ def check_and_move_on(self, move_on_flag = True # To record whether moving to a new training # round or finishing the evaluation - if self.check_buffer(self.state, min_received_num, check_eval_result): - if not check_eval_result: + if self.check_buffer(self.state, min_received_num, stage): + if stage == STAGE.TRAIN: # Receiving enough feedback in the training process aggregated_num = self._perform_federated_aggregation() @@ -329,8 +318,8 @@ def check_and_move_on(self, f'----------- Starting a new training round (Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() - self.msg_buffer['train'][self.state] = dict() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state] = dict() self.staled_msg_buffer.clear() # Start a new training round self._start_new_training_round(aggregated_num) @@ -340,10 +329,13 @@ def check_and_move_on(self, 'evaluation.') self.eval() - else: + elif stage == STAGE.EVAL: # Receiving enough feedback in the evaluation process self._merge_and_format_eval_results() + else: + pass + else: move_on_flag = False @@ -393,8 +385,8 @@ def check_and_save(self): # Clean the clients evaluation msg buffer if not self._cfg.federate.make_global_eval: - round = max(self.msg_buffer['eval'].keys()) - self.msg_buffer['eval'][round].clear() + round = max(self.msg_buffer[STAGE.EVAL].keys()) + self.msg_buffer[STAGE.EVAL][round].clear() if self.state == self.total_round_num: # break out the loop for distributed mode @@ -404,7 +396,7 @@ def _perform_federated_aggregation(self): """ Perform federated aggregation and update the global model """ - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -523,8 +515,8 @@ def save_client_eval_results(self): :return: """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + round = max(self.msg_buffer[STAGE.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[STAGE.EVAL][round] with open(os.path.join(self._cfg.outdir, "eval_results.log"), "a") as outfile: @@ -545,14 +537,14 @@ def merge_eval_results_from_all_clients(self): :returns: the formatted merged results """ - round = max(self.msg_buffer['eval'].keys()) - eval_msg_buffer = self.msg_buffer['eval'][round] + round = max(self.msg_buffer[STAGE.EVAL].keys()) + eval_msg_buffer = self.msg_buffer[STAGE.EVAL][round] eval_res_participated_clients = [] eval_res_unseen_clients = [] for client_id in eval_msg_buffer: if eval_msg_buffer[client_id] is None: continue - if client_id in self.unseen_clients_id: + if client_id in self.client_manager._id_client_unseen: eval_res_unseen_clients.append(eval_msg_buffer[client_id]) else: eval_res_participated_clients.append( @@ -611,35 +603,23 @@ def merge_eval_results_from_all_clients(self): def broadcast_model_para(self, msg_type='model_para', - sample_client_num=-1, - filter_unseen_clients=True): + sample_client_num=None): """ To broadcast the message to all clients or sampled clients Arguments: msg_type: 'model_para' or other user defined msg_type sample_client_num: the number of sampled clients in the broadcast - behavior. And sample_client_num = -1 denotes to broadcast to + behavior. And sample_client_num = None denotes to broadcast to all the clients. - filter_unseen_clients: whether filter out the unseen clients that - do not contribute to FL process by training on their local - data and uploading their local model update. The splitting is - useful to check participation generalization gap in [ICLR'22, - What Do We Mean by Generalization in Federated Learning?] - You may want to set it to be False when in evaluation stage - """ - if filter_unseen_clients: - # to filter out the unseen clients when sampling - self.sampler.change_state(self.unseen_clients_id, 'unseen') - - if sample_client_num > 0: - receiver = self.sampler.sample(size=sample_client_num) + """ + # Get the receivers + if sample_client_num: + receiver = self.client_manager.sample(size=sample_client_num) else: - # broadcast to all clients receiver = list(self.comm_manager.neighbors.keys()) - if msg_type == 'model_para': - self.sampler.change_state(receiver, 'working') + # Inject noise if self._noise_injector is not None and msg_type == 'model_para': # Inject noise only when broadcast parameters for model_idx_i in range(len(self.models)): @@ -649,6 +629,7 @@ def broadcast_model_para(self, self._noise_injector(self._cfg, num_sample_clients, self.models[model_idx_i]) + # Prepare model parameters skip_broadcast = self._cfg.federate.method in ["local", "global"] if self.model_num > 1: model_para = [{} if skip_broadcast else model.state_dict() @@ -656,6 +637,7 @@ def broadcast_model_para(self, else: model_para = {} if skip_broadcast else self.model.state_dict() + # Send model parameters self.comm_manager.send( Message(msg_type=msg_type, sender=self.ID, @@ -667,10 +649,6 @@ def broadcast_model_para(self, for idx in range(self.model_num): self.aggregators[idx].reset() - if filter_unseen_clients: - # restore the state of the unseen clients within sampler - self.sampler.change_state(self.unseen_clients_id, 'seen') - def broadcast_client_address(self): """ To broadcast the communication addresses of clients (used for @@ -688,33 +666,21 @@ def broadcast_client_address(self): def check_buffer(self, cur_round, min_received_num, - check_eval_result=False): - """ - To check the message buffer + buffer_type): + """Check if the message buffer receives enough messages Arguments: - cur_round (int): The current round number - min_received_num (int): The minimal number of the receiving messages - check_eval_result (bool): To check training results for evaluation - results - :returns: Whether enough messages have been received or not - :rtype: bool - """ - - if check_eval_result: - if 'eval' not in self.msg_buffer.keys() or len( - self.msg_buffer['eval'].keys()) == 0: - return False - - buffer = self.msg_buffer['eval'] - cur_round = max(buffer.keys()) - cur_buffer = buffer[cur_round] - return len(cur_buffer) >= min_received_num - else: - if cur_round not in self.msg_buffer['train']: - cur_buffer = dict() - else: - cur_buffer = self.msg_buffer['train'][cur_round] + cur_round (int): The current round number + min_received_num (int): The minimal number of the receiving messages + buffer_type (str): Which field to check, chosen from MSGBUFFER.TRAIN, MSGBUFFER.EVAL and MSGBUFFER.CONSULT + + Return: + Whether enough messages have been received or not + """ + buffer = self.msg_buffer.get(buffer_type, dict()) + + if buffer_type == STAGE.TRAIN: + cur_buffer = buffer.get(cur_round, dict()) if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': if self.cur_timestamp >= self.deadline_for_cur_round and len( cur_buffer) + len(self.staled_msg_buffer) == 0: @@ -738,15 +704,27 @@ def check_buffer(self, return len(cur_buffer)+len(self.staled_msg_buffer) >= \ min_received_num + elif buffer_type == STAGE.EVAL: + # Evaluation won't block the training process + cur_buffer = buffer.get(max(buffer.keys()), dict()) + return len(cur_buffer) >= min_received_num + + elif buffer_type == STAGE.CONSULT: + cur_buffer = buffer.get(cur_round, dict()) + return len(cur_buffer) >= min_received_num + + else: + raise NotImplementedError(f'Type of message buffer {buffer_type} is not implemented.') + def check_client_join_in(self): """ To check whether all the clients have joined in the FL course. """ if len(self._cfg.federate.join_in_info) != 0: - return len(self.join_in_info) == self.client_num + return self.client_manager.check_client_info() else: - return self.join_in_client_num == self.client_num + return self.client_manager.check_client_join_in() def trigger_for_start(self): """ @@ -757,26 +735,12 @@ def trigger_for_start(self): if self._cfg.federate.use_ss: self.broadcast_client_address() - # get sampler - if 'client_resource' in self._cfg.federate.join_in_info: - client_resource = [ - self.join_in_info[client_index]['client_resource'] - for client_index in np.arange(1, self.client_num + 1) - ] - else: - model_size = sys.getsizeof(pickle.dumps( - self.model)) / 1024.0 * 8. - client_resource = [ - model_size / float(x['communication']) + - float(x['computation']) / 1000. - for x in self.client_resource_info - ] if self.client_resource_info is not None else None - - if self.sampler is None: - self.sampler = get_sampler( - sample_strategy=self._cfg.federate.sampler, - client_num=self.client_num, - client_info=client_resource) + # Prepare for training + # init sampler within the client manager after finishing exchanging information + self.client_manager.init_sampler() + # Block unseen clients from training + + self.client_manager.block_unseen_client() # change the deadline if the asyn.aggregator is `time up` if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up': @@ -802,7 +766,7 @@ def trigger_for_time_up(self, check_timestamp=None): return False self.cur_timestamp = self.deadline_for_cur_round - self.check_and_move_on() + self.check_and_move_on(stage=STAGE.TRAIN) return True def terminate(self, msg_type='finish'): @@ -862,8 +826,7 @@ def eval(self): self.check_and_save() else: # Preform evaluation in clients - self.broadcast_model_para(msg_type='evaluate', - filter_unseen_clients=False) + self.broadcast_model_para(msg_type='evaluate') def callback_funcs_model_para(self, message: Message): """ @@ -884,17 +847,19 @@ def callback_funcs_model_para(self, message: Message): sender = message.sender timestamp = message.timestamp content = message.content - self.sampler.change_state(sender, 'idle') + + # After training, change the client status into idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) # update the currency timestamp according to the received message assert timestamp >= self.cur_timestamp # for test self.cur_timestamp = timestamp if round == self.state: - if round not in self.msg_buffer['train']: - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN]: + self.msg_buffer[STAGE.TRAIN][round] = dict() # Save the messages in this round - self.msg_buffer['train'][round][sender] = content + self.msg_buffer[STAGE.TRAIN][round][sender] = content elif round >= self.state - self.staleness_toleration: # Save the staled messages self.staled_msg_buffer.append((round, sender, content)) @@ -906,7 +871,7 @@ def callback_funcs_model_para(self, message: Message): if self._cfg.federate.online_aggr: self.aggregator.inc(content) - move_on_flag = self.check_and_move_on() + move_on_flag = self.check_and_move_on(stage=STAGE.TRAIN) if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \ 'after_receiving': self.broadcast_model_para(msg_type='model_para', @@ -932,13 +897,19 @@ def callback_funcs_for_join_in(self, message: Message): assert key in info self.join_in_info[sender] = info logger.info('Server: Client #{:d} has joined in !'.format(sender)) + + # Set the client status as idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) + else: - self.join_in_client_num += 1 sender, address = message.sender, message.content - if int(sender) == -1: # assign number to client - sender = self.join_in_client_num - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) + # Register client in client_manager + register_id = self.client_manager.register_client(sender) + + # TODO: maybe we shouldn't support user-defined ID + if register_id != sender: # assign number to client + sender = register_id + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) self.comm_manager.send( Message(msg_type='assign_client_id', sender=self.ID, @@ -947,8 +918,7 @@ def callback_funcs_for_join_in(self, message: Message): timestamp=self.cur_timestamp, content=str(sender))) else: - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) if len(self._cfg.federate.join_in_info) != 0: self.comm_manager.send( @@ -975,9 +945,9 @@ def callback_funcs_for_metrics(self, message: Message): sender = message.sender content = message.content - if round not in self.msg_buffer['eval'].keys(): - self.msg_buffer['eval'][round] = dict() + if round not in self.msg_buffer[STAGE.EVAL].keys(): + self.msg_buffer[STAGE.EVAL][round] = dict() - self.msg_buffer['eval'][round][sender] = content + self.msg_buffer[STAGE.EVAL][round][sender] = content - return self.check_and_move_on(check_eval_result=True) + return self.check_and_move_on(stage=STAGE.EVAL) diff --git a/federatedscope/gfl/fedsageplus/worker.py b/federatedscope/gfl/fedsageplus/worker.py index f1812598d..dd96fcb76 100644 --- a/federatedscope/gfl/fedsageplus/worker.py +++ b/federatedscope/gfl/fedsageplus/worker.py @@ -4,6 +4,7 @@ from torch_geometric.loader import NeighborSampler +from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -67,22 +68,27 @@ def callback_funcs_for_join_in(self, message: Message): assert key in info self.join_in_info[sender] = info logger.info('Server: Client #{:d} has joined in !'.format(sender)) + + # Set the client status as idle + self.client_manager.change_state(sender, CLIENT_STATE.IDLE) else: - self.join_in_client_num += 1 sender, address = message.sender, message.content - if int(sender) == -1: # assign number to client - sender = self.join_in_client_num - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) + + # Register client in client_manager + register_id = self.client_manager.register_client(sender) + + if register_id != sender: # assign number to client + sender = register_id + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) self.comm_manager.send( Message(msg_type='assign_client_id', sender=self.ID, receiver=[sender], state=self.state, + timestamp=self.cur_timestamp, content=str(sender))) else: - self.comm_manager.add_neighbors(neighbor_id=sender, - address=address) + self.comm_manager.add_neighbors(neighbor_id=sender, address=address) if len(self._cfg.federate.join_in_info) != 0: self.comm_manager.send( @@ -106,22 +112,22 @@ def callback_funcs_gradient(self, message: Message): round, _, content = message.state, message.sender, message.content gen_grad, ID = content # For a new round - if round not in self.msg_buffer['train'].keys(): - self.msg_buffer['train'][round] = dict() + if round not in self.msg_buffer[STAGE.TRAIN].keys(): + self.msg_buffer[STAGE.TRAIN][round] = dict() self.grad_cnt += 1 # Sum up all grad from other client - if ID not in self.msg_buffer['train'][round]: - self.msg_buffer['train'][round][ID] = dict() + if ID not in self.msg_buffer[STAGE.TRAIN][round]: + self.msg_buffer[STAGE.TRAIN][round][ID] = dict() for key in gen_grad.keys(): - self.msg_buffer['train'][round][ID][key] = torch.FloatTensor( + self.msg_buffer[STAGE.TRAIN][round][ID][key] = torch.FloatTensor( gen_grad[key].cpu()) else: for key in gen_grad.keys(): - self.msg_buffer['train'][round][ID][key] += torch.FloatTensor( + self.msg_buffer[STAGE.TRAIN][round][ID][key] += torch.FloatTensor( gen_grad[key].cpu()) self.check_and_move_on() - def check_and_move_on(self, check_eval_result=False): + def check_and_move_on(self, check_eval_result=False, **kwargs): client_IDs = [i for i in range(1, self.client_num + 1)] if check_eval_result: @@ -137,8 +143,8 @@ def check_and_move_on(self, check_eval_result=False): ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\ % 2 == 0: # FedGen: we should wait for all messages - for sender in self.msg_buffer['train'][self.state]: - content = self.msg_buffer['train'][self.state][sender] + for sender in self.msg_buffer[STAGE.TRAIN][self.state]: + content = self.msg_buffer[STAGE.TRAIN][self.state][sender] gen_para, embedding, label = content receiver_IDs = client_IDs[:sender - 1] + client_IDs[sender:] self.comm_manager.send( @@ -157,8 +163,8 @@ def check_and_move_on(self, check_eval_result=False): ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\ % 2 == 1 and self.grad_cnt == self.client_num * ( self.client_num - 1): - for ID in self.msg_buffer['train'][self.state]: - grad = self.msg_buffer['train'][self.state][ID] + for ID in self.msg_buffer[STAGE.TRAIN][self.state]: + grad = self.msg_buffer[STAGE.TRAIN][self.state][ID] self.comm_manager.send( Message(msg_type='gradient', sender=self.ID, @@ -186,7 +192,7 @@ def check_and_move_on(self, check_eval_result=False): if not check_eval_result: # in the training process # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] msg_list = list() for client_id in train_msg_buffer: msg_list.append(train_msg_buffer[client_id]) diff --git a/federatedscope/gfl/gcflplus/worker.py b/federatedscope/gfl/gcflplus/worker.py index 26a1d62df..5ecd4623a 100644 --- a/federatedscope/gfl/gcflplus/worker.py +++ b/federatedscope/gfl/gcflplus/worker.py @@ -3,6 +3,7 @@ import copy import numpy as np +from federatedscope.core.auxiliaries.enums import STAGE from federatedscope.core.message import Message from federatedscope.core.workers.server import Server from federatedscope.core.workers.client import Client @@ -44,7 +45,7 @@ def compute_update_norm(self, cluster): max_norm = -np.inf cluster_dWs = [] for key in cluster: - content = self.msg_buffer['train'][self.state][key] + content = self.msg_buffer[STAGE.TRAIN][self.state][key] _, model_para, client_dw, _ = content dW = {} for k in model_para.keys(): @@ -71,7 +72,7 @@ def check_and_move_on(self, check_eval_result=False): if not check_eval_result: # in the training process # Get all the message - train_msg_buffer = self.msg_buffer['train'][self.state] + train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state] for model_idx in range(self.model_num): model = self.models[model_idx] aggregator = self.aggregators[model_idx] @@ -135,7 +136,7 @@ def check_and_move_on(self, check_eval_result=False): for cluster in self.cluster_indices: msg_list = list() for key in cluster: - content = self.msg_buffer['train'][self.state - + content = self.msg_buffer[STAGE.TRAIN][self.state - 1][key] train_data_size, model_para, client_dw, \ convGradsNorm = content @@ -161,7 +162,7 @@ def check_and_move_on(self, check_eval_result=False): f'----------- Starting a new traininground(Round ' f'#{self.state}) -------------') # Clean the msg_buffer - self.msg_buffer['train'][self.state - 1].clear() + self.msg_buffer[STAGE.TRAIN][self.state - 1].clear() else: # Final Evaluate