[WIP] Add client manager in server #383

Open · wants to merge 7 commits into master
49 changes: 20 additions & 29 deletions federatedscope/attack/worker_as_attacker/server_attacker.py
@@ -1,3 +1,4 @@
from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE
from federatedscope.core.workers import Server
from federatedscope.core.message import Message

@@ -46,8 +47,7 @@ def __init__(self,

def broadcast_model_para(self,
msg_type='model_para',
sample_client_num=-1,
filter_unseen_clients=True):
sample_client_num=-1):
"""
To broadcast the message to all clients or sampled clients

@@ -56,18 +56,8 @@ def broadcast_model_para(self,
sample_client_num: the number of clients sampled in the broadcast
behavior; sample_client_num = -1 denotes broadcasting to
all the clients.
filter_unseen_clients: whether filter out the unseen clients that
do not contribute to FL process by training on their local
data and uploading their local model update. The splitting is
useful to check participation generalization gap in [ICLR'22,
What Do We Mean by Generalization in Federated Learning?]
You may want to set it to be False when in evaluation stage
"""

if filter_unseen_clients:
# to filter out the unseen clients when sampling
self.sampler.change_state(self.unseen_clients_id, 'unseen')

if sample_client_num > 0: # only activated at training process
attacker_id = self._cfg.attack.attacker_id
setting = self._cfg.attack.setting
@@ -158,10 +148,6 @@ def broadcast_model_para(self,
for idx in range(self.model_num):
self.aggregators[idx].reset()

if filter_unseen_clients:
# restore the state of the unseen clients within sampler
self.sampler.change_state(self.unseen_clients_id, 'seen')


class PassiveServer(Server):
'''
@@ -244,15 +230,15 @@ def _reconstruct(self, model_para, batch_size, state, sender):
def run_reconstruct(self, state_list=None, sender_list=None):

if state_list is None:
state_list = self.msg_buffer['train'].keys()
state_list = self.msg_buffer[STAGE.TRAIN].keys()

# After the FL run, use a gradient-based reconstruction method to
# recover clients' private training data
for state in state_list:
if sender_list is None:
sender_list = self.msg_buffer['train'][state].keys()
sender_list = self.msg_buffer[STAGE.TRAIN][state].keys()
for sender in sender_list:
content = self.msg_buffer['train'][state][sender]
content = self.msg_buffer[STAGE.TRAIN][state][sender]
self._reconstruct(model_para=content[1],
batch_size=content[0],
state=state,
@@ -263,10 +249,12 @@ def callback_funcs_model_para(self, message: Message):
return 'finish'

round, sender, content = message.state, message.sender, message.content
self.sampler.change_state(sender, 'idle')
if round not in self.msg_buffer['train']:
self.msg_buffer['train'][round] = dict()
self.msg_buffer['train'][round][sender] = content
# After training, change the client state to idle
self.client_manager.change_state(sender, CLIENT_STATE.IDLE)

if round not in self.msg_buffer[STAGE.TRAIN]:
self.msg_buffer[STAGE.TRAIN][round] = dict()
self.msg_buffer[STAGE.TRAIN][round][sender] = content

# run reconstruction before self.msg_buffer is cleared

@@ -284,7 +272,7 @@ def callback_funcs_model_para(self, message: Message):
name='image_state_{}_client_{}.png'.format(
message.state, message.sender))

self.check_and_move_on()
self.check_and_move_on(stage=STAGE.TRAIN)


class PassivePIAServer(Server):
@@ -341,10 +329,13 @@ def callback_funcs_model_para(self, message: Message):
return 'finish'

round, sender, content = message.state, message.sender, message.content
self.sampler.change_state(sender, 'idle')
if round not in self.msg_buffer['train']:
self.msg_buffer['train'][round] = dict()
self.msg_buffer['train'][round][sender] = content

# After training, change the client state to idle
self.client_manager.change_state(sender, CLIENT_STATE.IDLE)

if round not in self.msg_buffer[STAGE.TRAIN]:
self.msg_buffer[STAGE.TRAIN][round] = dict()
self.msg_buffer[STAGE.TRAIN][round][sender] = content

# collect the updates
self.pia_attacker.collect_updates(
@@ -359,7 +350,7 @@ def callback_funcs_model_para(self, message: Message):
# TODO: put this line to `check_and_move_on`
# currently, no way to know the latest `sender`
self.aggregator.inc(content)
self.check_and_move_on()
self.check_and_move_on(stage=STAGE.TRAIN)

if self.state == self.total_round_num:
self.pia_attacker.train_property_classifier()
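Both attacker servers now follow the same bookkeeping pattern in their `model_para` callbacks. A minimal sketch of that pattern (assuming a server instance that already carries the `client_manager` and `msg_buffer` attributes this PR introduces):

```python
from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE
from federatedscope.core.message import Message


def callback_funcs_model_para(self, message: Message):
    round, sender, content = message.state, message.sender, message.content

    # After training, mark the client idle so it can be sampled again
    self.client_manager.change_state(sender, CLIENT_STATE.IDLE)

    # Key the buffer by the STAGE enum instead of the bare string 'train'
    self.msg_buffer[STAGE.TRAIN].setdefault(round, dict())[sender] = content

    # Tell check_and_move_on explicitly which stage this buffer belongs to
    self.check_and_move_on(stage=STAGE.TRAIN)
```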
42 changes: 20 additions & 22 deletions federatedscope/autotune/fedex/server.py
@@ -8,6 +8,7 @@
from numpy.linalg import norm
from scipy.special import logsumexp

from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE
from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict
@@ -146,23 +147,18 @@ def sample(self):

def broadcast_model_para(self,
msg_type='model_para',
sample_client_num=-1,
filter_unseen_clients=True):
sample_client_num=-1):
"""
To broadcast the message to all clients or sampled clients
"""
if filter_unseen_clients:
# to filter out the unseen clients when sampling
self.sampler.change_state(self.unseen_clients_id, 'unseen')

if sample_client_num > 0:
receiver = self.sampler.sample(size=sample_client_num)
receiver = self.client_manager.sample(size=sample_client_num)
else:
# broadcast to all clients
receiver = list(self.comm_manager.neighbors.keys())
if msg_type == 'model_para':
self.sampler.change_state(receiver, 'working')

# Inject noise
if self._noise_injector is not None and msg_type == 'model_para':
# Inject noise only when broadcasting parameters
for model_idx_i in range(len(self.models)):
@@ -196,18 +192,16 @@ def broadcast_model_para(self,
for idx in range(self.model_num):
self.aggregators[idx].reset()

if filter_unseen_clients:
# restore the state of the unseen clients within sampler
self.sampler.change_state(self.unseen_clients_id, 'seen')

def callback_funcs_model_para(self, message: Message):
round, sender, content = message.state, message.sender, message.content
self.sampler.change_state(sender, 'idle')
# After training, change the client state to idle
self.client_manager.change_state(sender, CLIENT_STATE.IDLE)
# For a new round
if round not in self.msg_buffer['train'].keys():
self.msg_buffer['train'][round] = dict()
if round not in self.msg_buffer[STAGE.TRAIN].keys():
self.msg_buffer[STAGE.TRAIN][round] = dict()

self.msg_buffer['train'][round][sender] = content
self.msg_buffer[STAGE.TRAIN][round][sender] = content

if self._cfg.federate.online_aggr:
self.aggregator.inc(tuple(content[0:2]))
@@ -284,7 +278,7 @@ def update_policy(self, feedbacks):
self._trace['mle'][-1]))

def check_and_move_on(self,
check_eval_result=False,
stage,
min_received_num=None):
"""
To check the message_buffer; when enough messages have been received,
@@ -295,17 +289,17 @@ def check_and_move_on(self,
min_received_num = self._cfg.federate.sample_client_num
assert min_received_num <= self.sample_client_num

if check_eval_result:
if stage == STAGE.EVAL:
min_received_num = len(list(self.comm_manager.neighbors.keys()))

move_on_flag = True # To record whether moving to a new training
# round or finishing the evaluation
if self.check_buffer(self.state, min_received_num, check_eval_result):
if self.check_buffer(self.state, min_received_num, stage):

if not check_eval_result: # in the training process
if stage == STAGE.TRAIN: # in the training process
mab_feedbacks = list()
# Get all the messages
train_msg_buffer = self.msg_buffer['train'][self.state]
train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state]
for model_idx in range(self.model_num):
model = self.models[model_idx]
aggregator = self.aggregators[model_idx]
@@ -363,7 +357,7 @@ def check_and_move_on(self,
f'----------- Starting a new training round (Round '
f'#{self.state}) -------------')
# Clean the msg_buffer
self.msg_buffer['train'][self.state - 1].clear()
self.msg_buffer[STAGE.TRAIN][self.state - 1].clear()

self.broadcast_model_para(
msg_type='model_para',
@@ -374,12 +368,16 @@ def callback_funcs_model_para(self, message: Message):
'evaluation.')
self.eval()

else: # in the evaluation process
elif stage == STAGE.EVAL: # in the evaluation process
# Get all the messages & aggregate
formatted_eval_res = self.merge_eval_results_from_all_clients()
self.history_results = merge_dict(self.history_results,
formatted_eval_res)
self.check_and_save()

else:
pass

else:
move_on_flag = False

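The boolean `check_eval_result` flag becomes an explicit `stage` argument, which leaves room for stages beyond train/eval (e.g., `STAGE.CONSULT`). A reduced sketch of the new dispatch, with the aggregation and evaluation bodies elided:

```python
from federatedscope.core.auxiliaries.enums import STAGE


def check_and_move_on(self, stage, min_received_num=None):
    if min_received_num is None:
        min_received_num = self._cfg.federate.sample_client_num

    # Evaluation waits for all neighbors; training only for the sampled set
    if stage == STAGE.EVAL:
        min_received_num = len(list(self.comm_manager.neighbors.keys()))

    move_on_flag = True
    if self.check_buffer(self.state, min_received_num, stage):
        if stage == STAGE.TRAIN:
            pass  # aggregate updates, then start a new round or call eval()
        elif stage == STAGE.EVAL:
            pass  # merge evaluation results and check whether to save
    else:
        move_on_flag = False
    return move_on_flag
```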
30 changes: 27 additions & 3 deletions federatedscope/core/auxiliaries/enums.py
@@ -1,4 +1,14 @@
class MODE:
class BasicEnum(object):
@classmethod
def assert_value(cls, value):
"""
Check whether the given ``value`` is legal for the class, i.e., whether it equals one of the class attributes.
"""
if value not in [v for k, v in cls.__dict__.items() if not k.startswith('__')]:
raise ValueError(f"Value {value} is not in {cls.__name__}.")


class MODE(BasicEnum):
"""

Note:
@@ -11,7 +21,13 @@ class MODE:
FINETUNE = 'finetune'


class TRIGGER:
class STAGE(BasicEnum):
TRAIN = 'train'
EVAL = 'eval'
CONSULT = 'consult'


class TRIGGER(BasicEnum):
ON_FIT_START = 'on_fit_start'
ON_EPOCH_START = 'on_epoch_start'
ON_BATCH_START = 'on_batch_start'
@@ -30,8 +46,16 @@ def contains(cls, item):
]


class LIFECYCLE:
class LIFECYCLE(BasicEnum):
ROUTINE = 'routine'
EPOCH = 'epoch'
BATCH = 'batch'
NONE = None


class CLIENT_STATE(BasicEnum):
OFFLINE = -1  # has not joined in
CONSULTING = 2  # joined in and is consulting
IDLE = 1  # joined in but not working; available for training
WORKING = 0  # joined in and is working
SIDELINE = 3  # joined in but will not participate in federated training
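A quick illustration (not part of the diff) of how the `BasicEnum.assert_value` guard behaves with the new enums:

```python
from federatedscope.core.auxiliaries.enums import STAGE, CLIENT_STATE

# Legal values pass silently
STAGE.assert_value('train')
CLIENT_STATE.assert_value(CLIENT_STATE.IDLE)

# Illegal values raise ValueError
try:
    STAGE.assert_value('test')
except ValueError as err:
    print(err)  # Value test is not in STAGE.
```

Since the raw values themselves are compared, each `CLIENT_STATE` member needs a distinct value for such checks to tell states like `WORKING` and `SIDELINE` apart.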
10 changes: 4 additions & 6 deletions federatedscope/core/auxiliaries/sampler_builder.py
@@ -5,16 +5,14 @@
logger = logging.getLogger(__name__)


def get_sampler(sample_strategy='uniform',
def get_sampler(sample_strategy,
client_num=None,
client_info=None,
bins=10):
if sample_strategy == 'uniform':
return UniformSampler(client_num=client_num)
return UniformSampler()
elif sample_strategy == 'group':
return GroupSampler(client_num=client_num,
client_info=client_info,
return GroupSampler(client_info=client_info['client_resource'],
bins=bins)
else:
raise ValueError(
f"The sample strategy {sample_strategy} has not been provided.")
return None
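A usage sketch for the updated builder (the `client_info` shape with a `'client_resource'` column is an assumption read off the diff):

```python
from federatedscope.core.auxiliaries.sampler_builder import get_sampler

# Uniform sampling no longer needs client_num at construction time;
# note that sample_strategy has no default anymore and must be passed
uniform_sampler = get_sampler(sample_strategy='uniform')

# Group sampling now receives only the resource column of client_info
group_sampler = get_sampler(sample_strategy='group',
                            client_info={'client_resource': [1.0, 0.5, 2.0]},
                            bins=10)
```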