[Feature] Support more message types in server #379

Open · wants to merge 5 commits into base: master
23 changes: 12 additions & 11 deletions federatedscope/attack/worker_as_attacker/server_attacker.py
@@ -1,3 +1,4 @@
+from federatedscope.core.auxiliaries.enums import STAGE
 from federatedscope.core.workers import Server
 from federatedscope.core.message import Message

@@ -244,15 +245,15 @@ def _reconstruct(self, model_para, batch_size, state, sender):
     def run_reconstruct(self, state_list=None, sender_list=None):

         if state_list is None:
-            state_list = self.msg_buffer['train'].keys()
+            state_list = self.msg_buffer[STAGE.TRAIN].keys()

         # After FL running, using gradient based reconstruction method to
         # recover client's private training data
         for state in state_list:
             if sender_list is None:
-                sender_list = self.msg_buffer['train'][state].keys()
+                sender_list = self.msg_buffer[STAGE.TRAIN][state].keys()
             for sender in sender_list:
-                content = self.msg_buffer['train'][state][sender]
+                content = self.msg_buffer[STAGE.TRAIN][state][sender]
                 self._reconstruct(model_para=content[1],
                                   batch_size=content[0],
                                   state=state,
@@ -264,9 +265,9 @@ def callback_funcs_model_para(self, message: Message):

         round, sender, content = message.state, message.sender, message.content
         self.sampler.change_state(sender, 'idle')
-        if round not in self.msg_buffer['train']:
-            self.msg_buffer['train'][round] = dict()
-        self.msg_buffer['train'][round][sender] = content
+        if round not in self.msg_buffer[STAGE.TRAIN]:
+            self.msg_buffer[STAGE.TRAIN][round] = dict()
+        self.msg_buffer[STAGE.TRAIN][round][sender] = content

         # run reconstruction before the clear of self.msg_buffer

@@ -284,7 +285,7 @@ def callback_funcs_model_para(self, message: Message):
                 name='image_state_{}_client_{}.png'.format(
                     message.state, message.sender))

-        self.check_and_move_on()
+        self.check_and_move_on(stage=STAGE.TRAIN)


 class PassivePIAServer(Server):
@@ -342,9 +343,9 @@ def callback_funcs_model_para(self, message: Message):

         round, sender, content = message.state, message.sender, message.content
         self.sampler.change_state(sender, 'idle')
-        if round not in self.msg_buffer['train']:
-            self.msg_buffer['train'][round] = dict()
-        self.msg_buffer['train'][round][sender] = content
+        if round not in self.msg_buffer[STAGE.TRAIN]:
+            self.msg_buffer[STAGE.TRAIN][round] = dict()
+        self.msg_buffer[STAGE.TRAIN][round][sender] = content

         # collect the updates
         self.pia_attacker.collect_updates(
@@ -359,7 +360,7 @@ def callback_funcs_model_para(self, message: Message):
             # TODO: put this line to `check_and_move_on`
             # currently, no way to know the latest `sender`
             self.aggregator.inc(content)
-        self.check_and_move_on()
+        self.check_and_move_on(stage=STAGE.TRAIN)

         if self.state == self.total_round_num:
             self.pia_attacker.train_property_classifier()
21 changes: 11 additions & 10 deletions federatedscope/autotune/fedex/server.py
@@ -8,6 +8,7 @@
 from numpy.linalg import norm
 from scipy.special import logsumexp

+from federatedscope.core.auxiliaries.enums import STAGE
 from federatedscope.core.message import Message
 from federatedscope.core.workers import Server
 from federatedscope.core.auxiliaries.utils import merge_dict
@@ -204,15 +205,15 @@ def callback_funcs_model_para(self, message: Message):
         round, sender, content = message.state, message.sender, message.content
         self.sampler.change_state(sender, 'idle')
         # For a new round
-        if round not in self.msg_buffer['train'].keys():
-            self.msg_buffer['train'][round] = dict()
+        if round not in self.msg_buffer[STAGE.TRAIN].keys():
+            self.msg_buffer[STAGE.TRAIN][round] = dict()

-        self.msg_buffer['train'][round][sender] = content
+        self.msg_buffer[STAGE.TRAIN][round][sender] = content

         if self._cfg.federate.online_aggr:
             self.aggregator.inc(tuple(content[0:2]))

-        return self.check_and_move_on()
+        return self.check_and_move_on(stage=STAGE.TRAIN)

     def update_policy(self, feedbacks):
         """Update the policy. This implementation is borrowed from the
@@ -284,7 +285,7 @@ def update_policy(self, feedbacks):
                                              self._trace['mle'][-1]))

     def check_and_move_on(self,
-                          check_eval_result=False,
+                          stage,
                           min_received_num=None):
         """
         To check the message_buffer, when enough messages are receiving,
@@ -295,17 +296,17 @@ def check_and_move_on(self,
             min_received_num = self._cfg.federate.sample_client_num
         assert min_received_num <= self.sample_client_num

-        if check_eval_result:
+        if stage == STAGE.EVAL:
             min_received_num = len(list(self.comm_manager.neighbors.keys()))

         move_on_flag = True  # To record whether moving to a new training
         # round or finishing the evaluation
-        if self.check_buffer(self.state, min_received_num, check_eval_result):
+        if self.check_buffer(self.state, min_received_num, stage):

-            if not check_eval_result:  # in the training process
+            if stage == STAGE.TRAIN:  # in the training process
                 mab_feedbacks = list()
                 # Get all the message
-                train_msg_buffer = self.msg_buffer['train'][self.state]
+                train_msg_buffer = self.msg_buffer[STAGE.TRAIN][self.state]
                 for model_idx in range(self.model_num):
                     model = self.models[model_idx]
                     aggregator = self.aggregators[model_idx]
@@ -363,7 +364,7 @@ def check_and_move_on(self,
                     f'----------- Starting a new training round (Round '
                     f'#{self.state}) -------------')
                 # Clean the msg_buffer
-                self.msg_buffer['train'][self.state - 1].clear()
+                self.msg_buffer[STAGE.TRAIN][self.state - 1].clear()

                 self.broadcast_model_para(
                     msg_type='model_para',
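The enum-valued `stage` parameter is what lets `check_and_move_on` grow beyond the old boolean: `check_eval_result=False/True` could only distinguish two message kinds, while a stage constant leaves room for the third `STAGE.CONSULT` value added in enums.py. Below is a minimal, self-contained sketch of that dispatch pattern; the consult branch is an assumption based on the new constant, since this PR's hunks only exercise the TRAIN and EVAL paths.

```python
class STAGE:
    TRAIN = 'train'
    EVAL = 'eval'
    CONSULT = 'consult'


def check_and_move_on(stage, min_received_num=None):
    # With a boolean flag this would be a two-way if/else; string
    # constants allow dispatch over arbitrarily many message types.
    if stage == STAGE.TRAIN:
        return 'aggregate model updates and start the next round'
    elif stage == STAGE.EVAL:
        return 'merge evaluation metrics and maybe terminate'
    elif stage == STAGE.CONSULT:
        return 'handle consult-type messages (assumed usage)'
    raise ValueError(f'Unknown stage: {stage}')


print(check_and_move_on(STAGE.TRAIN))
```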
6 changes: 6 additions & 0 deletions federatedscope/core/auxiliaries/enums.py
@@ -11,6 +11,12 @@ class MODE:
     FINETUNE = 'finetune'


+class STAGE:
+    TRAIN = 'train'
+    EVAL = 'eval'
+    CONSULT = 'consult'
+
+
 class TRIGGER:
     ON_FIT_START = 'on_fit_start'
     ON_EPOCH_START = 'on_epoch_start'
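Note that `STAGE` is a plain constants class rather than an `enum.Enum`, so `STAGE.TRAIN` is the string `'train'` itself: old string-keyed lookups and the new constant-keyed lookups hit the same bucket in `msg_buffer`. A quick illustrative sketch of this equivalence (the toy buffer contents are made up):

```python
class STAGE:
    TRAIN = 'train'
    EVAL = 'eval'
    CONSULT = 'consult'


# Toy msg_buffer: stage -> round -> sender -> content
msg_buffer = {'train': {0: {1: 'model update'}}}

# The constant is the same string value, so both lookups succeed:
assert STAGE.TRAIN == 'train'
assert msg_buffer[STAGE.TRAIN] == msg_buffer['train']
```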
9 changes: 5 additions & 4 deletions federatedscope/core/workers/client.py
@@ -3,6 +3,7 @@
 import sys
 import pickle

+from federatedscope.core.auxiliaries.enums import STAGE
 from federatedscope.core.message import Message
 from federatedscope.core.communication import StandaloneCommManager, \
     gRPCCommManager
@@ -225,12 +226,12 @@ def callback_funcs_for_model_para(self, message: Message):
             # A fragment of the shared secret
             state, content, timestamp = message.state, message.content, \
                 message.timestamp
-            self.msg_buffer['train'][state].append(content)
+            self.msg_buffer[STAGE.TRAIN][state].append(content)
@rayrayraykk (Collaborator) commented on Oct 11, 2022:
I'm not sure whether 'stage' is an appropriate key name, since the key of msg_buffer should arguably be the message type (although it is not now: msg_buffer['train'] contains model parameters and msg_buffer['eval'] contains evaluation results).

Maybe we should use the message type as the key of msg_buffer (e.g., msg_buffer['model_param']).

What's more, might STAGE and state be confused with each other?
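For illustration only, a message-type-keyed buffer along the lines of this suggestion might look like the sketch below; the `MsgBuffer` name and the `'model_para'` key are hypothetical, not part of this PR:

```python
from collections import defaultdict


class MsgBuffer:
    """Hypothetical buffer keyed by message type rather than by stage.

    Layout: buffer[msg_type][round][sender] -> message content
    """

    def __init__(self):
        self._buffer = defaultdict(dict)

    def put(self, msg_type, round, sender, content):
        # Create the per-round dict lazily, mirroring the existing
        # `if round not in self.msg_buffer[...]` pattern in the servers.
        self._buffer[msg_type].setdefault(round, dict())[sender] = content

    def get(self, msg_type, round):
        return self._buffer[msg_type].get(round, dict())


# Usage sketch: the server would then index by the incoming message's
# type instead of by a STAGE constant, e.g.
# buffer.put(message.msg_type, message.state, message.sender, message.content)
```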


-            if len(self.msg_buffer['train']
+            if len(self.msg_buffer[STAGE.TRAIN]
                    [state]) == self._cfg.federate.client_num:
                 # Check whether the received fragments are enough
-                model_list = self.msg_buffer['train'][state]
+                model_list = self.msg_buffer[STAGE.TRAIN][state]
                 sample_size, first_aggregate_model_para = model_list[0]
                 single_model_case = True
                 if isinstance(first_aggregate_model_para, list):
@@ -355,7 +356,7 @@ def callback_funcs_for_model_para(self, message: Message):
                     single_model_case else \
                     [model_para_list[frame_idx] for model_para_list in
                      model_para_list_all]
-                self.msg_buffer['train'][self.state] = [(sample_size,
-                                                         content_frame)]
+                self.msg_buffer[STAGE.TRAIN][self.state] = [(sample_size,
+                                                             content_frame)]
             else:
                 if self._cfg.asyn.use: