Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ss evaluation for tree-based model #585

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion federatedscope/core/secret_sharing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from federatedscope.core.secret_sharing.secret_sharing import \
AdditiveSecretSharing
AdditiveSecretSharing, MultiplicativeSecretSharing
124 changes: 120 additions & 4 deletions federatedscope/core/secret_sharing/secret_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def secret_split(self, secret):

secret = self.float2fixedpoint(secret)
secret_seq = np.random.randint(low=0, high=self.mod_number, size=shape)
# last_seq = self.mod_funs(secret - self.mod_funs(np.sum(secret_seq,
# axis=0)))
last_seq = self.mod_funs(secret -
self.mod_funs(np.sum(secret_seq, axis=0)))
# last_seq =
# self.mod_funs(secret - self.mod_funs(np.sum(secret_seq, axis=0)))
last_seq = self.mod_funs(
secret - self.mod_funs(np.sum(secret_seq, axis=0))).astype(int)

secret_seq = np.append(secret_seq,
np.expand_dims(last_seq, axis=0),
Expand All @@ -82,6 +82,14 @@ def secret_reconstruct(self, secret_seq):
else:
merge_model[key] += secret_seq[idx][key]
merge_model[key] = self.fixedpoint2float(merge_model[key])
# if merge_model is an ndarray or a list
else:
for idx in range(len(secret_seq)):
if idx == 0:
merge_model = secret_seq[idx].copy()
else:
merge_model += secret_seq[idx]
merge_model = self.fixedpoint2float(merge_model)

return merge_model

Expand All @@ -96,3 +104,111 @@ def _fixedpoint2float(self, x):
return -1 * (self.mod_number - x) / self.epsilon
else:
return x / self.epsilon


class MultiplicativeSecretSharing(AdditiveSecretSharing):
"""
AdditiveSecretSharing class, which can split a number into frames and
recover it by summing up
"""
def __init__(self, shared_party_num, size=60):
super().__init__(shared_party_num, size)
self.maximum = 2**size
self.mod_number = 2 * self.maximum + 1
self.epsilon = 1e8

def secret_split(self, secret, cls=None):
"""
To split the secret into frames according to the shared_party_num
"""
if isinstance(secret, dict):
secret_list = [dict() for _ in range(self.shared_party_num)]
for key in secret:
for idx, each in enumerate(
self.secret_split(secret[key], cls=cls)):
secret_list[idx][key] = each
return secret_list

if isinstance(secret, list) or isinstance(secret, np.ndarray):
secret = np.asarray(secret).astype(int)
shape = [self.shared_party_num - 1] + list(secret.shape)
elif isinstance(secret, torch.Tensor):
secret = secret.numpy()
shape = [self.shared_party_num - 1] + list(secret.shape)
else:
shape = [self.shared_party_num - 1]

if cls is None:
secret = self.float2fixedpoint(secret)
secret_seq = np.random.randint(low=0, high=self.mod_number, size=shape)
# last_seq =
# self.mod_funs(secret - self.mod_funs(np.sum(secret_seq, axis=0)))
last_seq = self.mod_funs(
secret - self.mod_funs(np.sum(secret_seq, axis=0))).astype(int)

secret_seq = np.append(secret_seq,
np.expand_dims(last_seq, axis=0),
axis=0)
return secret_seq

def secret_add_lists(self, args):
# args is a list
# whose last element is a list consisting of secret pieces
# TODO: add the condition that all elements in args are numbers
for i in range(len(args) - 1):
# if isinstance(args[i], int) or isinstance(args[i], np.int64):
if not isinstance(args[i], list) and not isinstance(
args[i], np.ndarray):
args[i] = [args[i]] * len(args[-1])
return self.mod_funs(np.sum(args, axis=0))
# TODO: in the future, when involve large numbers, numpy may overflow,
# thus, the following would work
# n = len(args[0])
# num = len(args)
# res = [0] * n
# for i in range(n):
# for j in range(num):
# res[i] += args[j][i]
# res[i] = res[i] % self.mod_number
# return np.asarray(res)

def secret_ndarray_star_ndarray(self, arr1, arr2):
# return a list whose i-th elements equals to
# the product of the i-th elements of arr1 and arr2
# where arr1 and arr2 are both secret pieces
if isinstance(arr1, int) or isinstance(arr1, np.int64):
arr1 = [arr1] * len(arr2)
if isinstance(arr2, int) or isinstance(arr2, np.int64):
arr2 = [arr2] * len(arr1)
n = len(arr1)
res = [0] * n
for i in range(n):
res[i] = (arr1[i].item() * arr2[i].item()) % self.mod_number
return np.asarray(res)

def beaver_triple(self, *args):
a = np.random.randint(0, self.mod_number, args).astype(int)
b = np.random.randint(0, self.mod_number, args).astype(int)

a_list = []
b_list = []
c = [(a[i].item() * b[i].item()) % self.mod_number
for i in range(len(a))]
c_list = []
for i in range(self.shared_party_num - 1):
a_tmp = np.random.randint(0, self.mod_number, args)
a_list.append(a_tmp)
a -= a_tmp
a = a % self.mod_number
b_tmp = np.random.randint(0, self.mod_number, args)
b_list.append(b_tmp)
b -= b_tmp
b = b % self.mod_number
c_tmp = np.random.randint(0, self.mod_number, args)
c_list.append(c_tmp)
c -= c_tmp
c = c % self.mod_number
a_list.append(a)
b_list.append(b)
c_list.append(c)
return a_list, b_list, c_list
106 changes: 106 additions & 0 deletions federatedscope/core/secret_sharing/ss_multiplicative_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import types
import logging
from federatedscope.core.message import Message

logger = logging.getLogger(__name__)


def wrap_client_for_ss_multiplicative(client):
# TODO: this only works when one of the arguments is a secret piece of
# the indicator vector which we do not make it to be a fixed point.
# For general cases, we should add a truncation step at the end.
def ss_multiplicative(self,
secret1,
secret2,
shared_party_num,
behavior=None):
self.secret1 = secret1
self.secret2 = secret2
self.behavior = behavior
self.shared_party_num = shared_party_num
self.pe_dict = dict()
self.pf_dict = dict()
self.res = None
if self.own_label:
self.comm_manager.send(
Message(msg_type='random_numbers',
sender=self.ID,
state=self.state,
receiver=[self.server_id],
content=(shared_party_num, len(secret2))))

def callback_fun_for_beaver_triplets(self, message: Message):
pa, pb, self.pc = message.content
pe = self.ss.secret_add_lists([self.secret1, -pa])
pf = self.ss.secret_add_lists([self.secret2, -pb])

self.pe_dict[self.ID] = pe
self.pf_dict[self.ID] = pf
for i in range(self.shared_party_num):
if i + 1 != self.ID:
self.comm_manager.send(
Message(msg_type='part_e_and_f',
sender=self.ID,
state=self.state,
receiver=[i + 1],
content=(pe, pf)))

def callback_func_for_part_e_and_f(self, message: Message):
pe, pf = message.content
self.pe_dict[message.sender] = pe
self.pf_dict[message.sender] = pf
if len(self.pe_dict) == self.shared_party_num:
e = self.ss.secret_add_lists([x for x in self.pe_dict.values()])
f = self.ss.secret_add_lists([x for x in self.pf_dict.values()])
self.pe_dict = {}
self.pf_dict = {}
t1 = self.ss.secret_ndarray_star_ndarray(f, self.secret1)
t2 = self.ss.secret_ndarray_star_ndarray(e, self.secret2)
if not self.own_label:
self.res = self.ss.secret_add_lists([t1, t2, self.pc])
else:
t3 = self.ss.secret_ndarray_star_ndarray(e, f)
self.res = self.ss.secret_add_lists([-t3, t1, t2, self.pc])
self.continue_next()

def continue_next(self):
if self.behavior == 'left_child':
self.set_left_child()
elif self.behavior == 'right_child':
self.set_right_child()
elif self.behavior == 'weight':
self.set_weight()

client.ss_multiplicative = types.MethodType(ss_multiplicative, client)
client.continue_next = types.MethodType(continue_next, client)
client.callback_fun_for_beaver_triplets = types.MethodType(
callback_fun_for_beaver_triplets, client)
client.callback_fun_for_part_e_and_f = types.MethodType(
callback_func_for_part_e_and_f, client)

client.register_handlers('beaver_triplets',
client.callback_fun_for_beaver_triplets)
client.register_handlers('part_e_and_f',
client.callback_fun_for_part_e_and_f)

return client


def wrap_server_for_ss_multiplicative(server):
def callback_func_for_random_numbers(self, message: Message):
shared_party_num, size = message.content
a_list, b_list, c_list = self.ss.beaver_triple(size)
for i in range(shared_party_num):
self.comm_manager.send(
Message(msg_type='beaver_triplets',
sender=self.ID,
receiver=[i + 1],
state=self.state,
content=(a_list[i], b_list[i], c_list[i])))

server.callback_func_for_random_numbers = types.MethodType(
callback_func_for_random_numbers, server)
server.register_handlers('random_numbers',
server.callback_func_for_random_numbers)

return server
2 changes: 1 addition & 1 deletion federatedscope/vertical_fl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ For label-scattering model, we provide privacy protection algorithms proposed by

```
vertical:
mode: 'label_based'
mode: 'label_scattering'
protect_object: 'grad_and_hess'
protect_method: 'he'
key_size: ks
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use_gpu: False
device: 0
backend: torch
federate:
mode: standalone
client_num: 2
model:
type: xgb_tree
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
data:
root: data/
type: adult
splits: [1.0, 0.0]
dataloader:
type: raw
batch_size: 2000
criterion:
type: CrossEntropyLoss
trainer:
type: verticaltrainer
train:
optimizer:
# learning rate for xgb model
eta: 0.5
vertical:
use: True
dims: [7, 14]
algo: 'xgb'
eval_protection: 'ss'
data_size_for_debug: 2000
eval:
freq: 3
best_res_update_round_wise_key: test_loss
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import \
abstract_paillier
from federatedscope.core.secret_sharing import MultiplicativeSecretSharing

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -36,6 +37,9 @@ def __init__(self,
keys = abstract_paillier.generate_paillier_keypair(
n_length=self._cfg.vertical.key_size)
self.public_key, self.private_key = keys
elif self._cfg.vertical.eval_protection == 'ss':
self.ss = MultiplicativeSecretSharing(
shared_party_num=self.client_num)

self.feature_order = None
self.merged_feature_order = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from federatedscope.core.workers import Server
from federatedscope.core.message import Message
from federatedscope.core.secret_sharing import MultiplicativeSecretSharing

import logging

Expand Down Expand Up @@ -29,6 +30,10 @@ def __init__(self,
self.total_num_of_feature = self._cfg.vertical.dims[-1]
self._init_data_related_var()

if self._cfg.vertical.eval_protection == 'ss':
self.ss = MultiplicativeSecretSharing(
shared_party_num=self.client_num)

def _init_data_related_var(self):
pass

Expand Down
Loading