[Feature] Add several differential privacy algorithms #369

Open · wants to merge 6 commits into base: master
Changes from 3 commits
6 changes: 6 additions & 0 deletions federatedscope/core/auxiliaries/model_builder.py
@@ -149,6 +149,12 @@ def get_model(model_config, local_data=None, backend='torch'):
    elif model_config.type.lower() in ['vmfnet', 'hmfnet']:
        from federatedscope.mf.model.model_builder import get_mfnet
        model = get_mfnet(model_config, input_shape)
    elif model_config.type.lower() == 'fmlinearregression':
        from federatedscope.differential_privacy.model.fm_linear_regression import FMLinearRegression
        model = FMLinearRegression(in_channels=input_shape[-1], epsilon=0.5)
    elif model_config.type.lower() == 'fmlogisticregression':
        from federatedscope.differential_privacy.model.fm_logistic_regression import FMLogisticRegression
        model = FMLogisticRegression(in_channels=input_shape[-1], epsilon=0.5)
    else:
        raise ValueError('Model {} is not provided'.format(model_config.type))

9 changes: 9 additions & 0 deletions federatedscope/core/auxiliaries/optimizer_builder.py
@@ -3,6 +3,9 @@
except ImportError:
    torch = None

from federatedscope.differential_privacy.optimizers import *
from federatedscope.register import optimizer_dict

import copy


@@ -25,6 +28,12 @@ def get_optimizer(model, type, lr, **kwargs):
            else:
                return getattr(torch.optim, type)(model, lr, **tmp_kwargs)
        else:
            # registered optimizers
            for func in optimizer_dict.values():
                optimizer = func(type)
                if optimizer is not None:
                    return optimizer(model.parameters(), lr, **tmp_kwargs)

            raise NotImplementedError(
                'Optimizer {} not implement'.format(type))
    else:
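For context, the loop over optimizer_dict above means a new differentially private optimizer only has to be registered once to become selectable by name. A minimal sketch of such a registration, assuming a register_optimizer helper analogous to the register_regularizer used later in this PR (the helper name is an assumption, since register.py is not shown in this diff):

from federatedscope.register import register_optimizer  # assumed helper


def call_dp_gaussian_sgd(type):
    # Return the optimizer class when the configured type matches.
    if type.lower() == 'dpgaussiansgd':
        from federatedscope.differential_privacy.optimizers.dp_optimizer import DPGaussianSGD
        return DPGaussianSGD


register_optimizer('dpgaussiansgd', call_dp_gaussian_sgd)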
2 changes: 1 addition & 1 deletion federatedscope/core/auxiliaries/regularizer_builder.py
@@ -1,5 +1,5 @@
from federatedscope.register import regularizer_dict
from federatedscope.core.regularizer.proximal_regularizer import *
from federatedscope.core.regularizer import *
try:
    from torch.nn import Module
except ImportError:
3 changes: 3 additions & 0 deletions federatedscope/core/auxiliaries/trainer_builder.py
@@ -25,6 +25,7 @@
    "fedvattrainer": "FedVATTrainer",
    "fedfocaltrainer": "FedFocalTrainer",
    "mftrainer": "MFTrainer",
    "fmtrainer": "FMTrainer"
}


@@ -80,6 +81,8 @@ def get_trainer(model=None,
        dict_path = "federatedscope.gfl.flitplus.trainer"
    elif config.trainer.type.lower() in ['mftrainer']:
        dict_path = "federatedscope.mf.trainer.trainer"
    elif config.trainer.type.lower() == 'fmtrainer':
        dict_path = "federatedscope.differential_privacy.trainers.fmtrainer"
    else:
        raise ValueError

5 changes: 4 additions & 1 deletion federatedscope/core/regularizer/__init__.py
@@ -1 +1,4 @@
from federatedscope.core.regularizer.proximal_regularizer import *
from federatedscope.core.regularizer.proximal_regularizer import ProximalRegularizer
from federatedscope.core.regularizer.l2_regularizer import L2Regularizer

__all__ = ['ProximalRegularizer', 'L2Regularizer']
36 changes: 36 additions & 0 deletions federatedscope/core/regularizer/l2_regularizer.py
@@ -0,0 +1,36 @@
from federatedscope.register import register_regularizer
from torch.nn import Module

import torch

REGULARIZER_NAME = "l2_regularizer"


class L2Regularizer(Module):
    """Returns the l2 norm of the model weights

    Arguments:
        p (int): The order of the norm.
        tensor_before: The original matrix or vector
        tensor_after: The updated matrix or vector

    Returns:
        Tensor: the norm of the given update.
    """
    def __init__(self):
        super(L2Regularizer, self).__init__()

    def forward(self, ctx):
        l2_norm = 0.
        for param in ctx.model.parameters():
            l2_norm += torch.sum(param**2)
        return l2_norm


def call_l2_regularizer(type):
    if type == REGULARIZER_NAME:
        regularizer = L2Regularizer
        return regularizer


register_regularizer(REGULARIZER_NAME, call_l2_regularizer)

Collaborator: 'p=2'?

Collaborator (Author): modified accordingly

Collaborator: In the original functional mechanism paper, the noise addition is performed only once for the objective function, while your implementation performs noise addition in each iteration; is this supported by the theory?

Collaborator: An interesting discussion about whether l2 regularization should be applied to the beta and gamma of BN layers: https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994. IMO, we can keep the implementation the same as that in torch.norm.

Collaborator (Author): The current torch.norm calculates the l2 norm for all parameters, including BN weight and bias. For now, we use the parameter skip_bn to avoid calculating the l2 norm for BN layers (those with 'bn' in their name).

Collaborator (Author): Conflicts are solved.
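For context, a regularizer resolved through call_l2_regularizer is weighted by the configured coefficient mu (see the regularizer section of the toy config below). A minimal sketch of how the penalty could be combined with the task loss, where ctx stands in for FederatedScope's trainer context (illustrative only, not the trainer code itself):

def regularized_loss(ctx, task_loss, mu=0.01):
    # Add the weighted l2 penalty over ctx.model.parameters() to the task loss.
    return task_loss + mu * L2Regularizer()(ctx)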
Empty file.
@@ -0,0 +1,39 @@
use_gpu: True
device: 0
early_stop:
  patience: 5
seed: 12345
federate:
  mode: standalone
  total_round_num: 300
  sample_client_rate: 0.2
data:
  root: data/
  type: femnist
  splits: [0.6,0.2,0.2]
  batch_size: 10
  subsample: 0.05
  num_workers: 0
  transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]
model:
  type: convnet2
  hidden: 2048
  out_channels: 62
  dropout: 0.0
train:
  local_update_steps: 1
  batch_or_epoch: epoch
  optimizer:
    type: DPGaussianSGD
    lr: 0.01
    l2_norm_sensitivity: 0.1
    noise_multiplier: 0.1
grad:
  grad_clip: 5.0
criterion:
  type: CrossEntropyLoss
trainer:
  type: cvtrainer
eval:
  freq: 10
  metrics: ['acc', 'correct']
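The optimizer section above selects DPGaussianSGD with an l2_norm_sensitivity and a noise_multiplier. The PR's actual implementation lives in dp_optimizer.py, which is not part of this view, so the following is only a rough sketch of what a Gaussian-mechanism SGD step with those hyperparameters could look like (class name and details are assumptions; production DP-SGD would clip per-example gradients rather than per-parameter ones):

import torch


class GaussianNoisySGD(torch.optim.SGD):
    # Illustrative sketch: clip each gradient to l2_norm_sensitivity and add
    # Gaussian noise scaled by noise_multiplier before the usual SGD update.
    def __init__(self, params, lr, l2_norm_sensitivity=0.1,
                 noise_multiplier=0.1, **kwargs):
        super().__init__(params, lr=lr, **kwargs)
        self.l2_norm_sensitivity = l2_norm_sensitivity
        self.noise_multiplier = noise_multiplier

    def step(self, closure=None):
        with torch.no_grad():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    norm = p.grad.norm(2)
                    scale = torch.clamp(self.l2_norm_sensitivity / (norm + 1e-12), max=1.0)
                    p.grad.mul_(scale)
                    p.grad.add_(torch.randn_like(p.grad) *
                                self.l2_norm_sensitivity * self.noise_multiplier)
        return super().step(closure)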
@@ -0,0 +1,27 @@
use_gpu: True
federate:
  mode: 'standalone'
  total_round_num: 500
  client_num: 10
seed: 12345
trainer:
  type: 'FMTrainer'
train:
  local_update_steps: 10
  batch_or_epoch: 'batch'
  optimizer:
    type: 'SGD'
    lr: 0.01
eval:
  freq: 20
  metrics: ['loss_regular']
  count_flops: False
model:
  type: 'FMLinearRegression'
data:
  type: 'toy'
criterion:
  type: MSELoss
regularizer:
  type: 'l2_regularizer'
  mu: 0.01
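The FM models configured here (defined later in this diff) return the loss directly from forward(), so FMTrainer does not need a separate criterion to compute it. A minimal sketch of a local update step under that assumption (FMTrainer itself is not shown in this view; the function below is illustrative):

def local_update_step(model, optimizer, x, y):
    # The FM model returns both the prediction and the (already perturbed) loss.
    optimizer.zero_grad()
    pred, loss = model(x, y)
    loss.backward()
    optimizer.step()
    return pred.detach(), loss.item()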
Empty file.
27 changes: 27 additions & 0 deletions federatedscope/differential_privacy/composition/compositor.py
@@ -0,0 +1,27 @@

import numpy as np


# For a client
class AdvancedComposition(object):
    pass


class PrivacyAccountantComposition(object):
    pass


class RenyiComposition(object):
    def __init__(self, sample_rate):
        self.orders = [1.5, 1.75, 2, 2.5, 3, 4, 5, 6, 8, 16, 32, 64]

        self.budgets = np.zeros_like(self.orders)

        self.epsilon = 0

        # sampling rate
        # alpha rate

    def compose(self, scale):
        for i, order in enumerate(self.orders):
            epsilon = order / scale ** 2
            self.budgets[i] += epsilon
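RenyiComposition only accumulates per-order budgets and never reports a final guarantee. If the budgets are interpreted as RDP values at the listed orders, a standard way to turn them into an (epsilon, delta)-DP statement is epsilon = min over alpha of (eps_alpha + log(1/delta) / (alpha - 1)); a minimal sketch of that conversion (the function name is illustrative, not part of this PR):

import numpy as np


def rdp_to_dp(orders, budgets, delta=1e-5):
    # Convert accumulated RDP budgets into a single (epsilon, delta) guarantee
    # by taking the tightest bound over all tracked orders.
    orders = np.asarray(orders, dtype=float)
    budgets = np.asarray(budgets, dtype=float)
    eps = budgets + np.log(1.0 / delta) / (orders - 1.0)
    best = int(np.argmin(eps))
    return float(eps[best]), float(orders[best])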
4 changes: 4 additions & 0 deletions federatedscope/differential_privacy/model/__init__.py
@@ -0,0 +1,4 @@
from federatedscope.differential_privacy.model.fm_linear_regression import FMLinearRegression
from federatedscope.differential_privacy.model.fm_logistic_regression import FMLogisticRegression

__all__ = ['FMLinearRegression', 'FMLogisticRegression']
52 changes: 52 additions & 0 deletions federatedscope/differential_privacy/model/fm_linear_regression.py
@@ -0,0 +1,52 @@
from torch.nn import Parameter
from torch.nn import Module
from torch.nn.init import kaiming_normal_

import numpy as np

import torch
import math


class FMLinearRegression(Module):
    """Implementation of the Functional Mechanism for linear regression, referring to
    `Functional Mechanism: Regression Analysis under Differential Privacy`
    [Jun Zhang, et al.](https://arxiv.org/abs/1208.0219)

    Args:
        in_channels (int): the number of input dimensions
        epsilon (float): the epsilon bound for differential privacy

    Note:
        The forward function returns the average loss directly, so that we
        don't need a criterion function for FM linear regression.
    """
    def __init__(self, in_channels, epsilon):
        super(FMLinearRegression, self).__init__()
        self.w = Parameter(torch.empty(in_channels, 1))
        kaiming_normal_(self.w, a=math.sqrt(5))

        sensitivity = float(2 * (1 + 2 * in_channels + in_channels**2))

        self.laplace = torch.distributions.laplace.Laplace(
            loc=0, scale=sensitivity / epsilon * np.sqrt(2))

    def forward(self, x, y):
        # J=0
        lambda0 = torch.matmul(y.t(), y)
        lambda0 += self.laplace.sample(sample_shape=lambda0.size()).to(lambda0.device)
        # J=1
        lambda1 = -2 * torch.matmul(y.t(), x)
        lambda1 += self.laplace.sample(sample_shape=lambda1.size()).to(lambda1.device)
        # J=2
        lambda2 = torch.matmul(x.t(), x)
        lambda2 += self.laplace.sample(sample_shape=lambda2.size()).to(lambda2.device)
        w2 = torch.matmul(self.w, self.w.t())

        loss_total = lambda0 + torch.sum(lambda1.t() * self.w) + torch.sum(lambda2 * w2)

        pred = torch.matmul(x, self.w)

        return pred, loss_total / x.size(0)
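As raised in the review thread above, forward() draws fresh Laplace noise on every batch, whereas the original functional mechanism perturbs the objective coefficients only once over the whole dataset. A minimal sketch of that one-shot variant, under the assumption that the client can precompute the coefficients from its full local data (helper names below are illustrative and not part of this PR):

import torch


def perturb_objective_once(x, y, laplace):
    # Compute the polynomial coefficients of the linear-regression objective
    # over the full local dataset and add Laplace noise a single time.
    lambda0 = torch.matmul(y.t(), y)
    lambda1 = -2 * torch.matmul(y.t(), x)
    lambda2 = torch.matmul(x.t(), x)
    lambda0 = lambda0 + laplace.sample(lambda0.size()).to(x.device)
    lambda1 = lambda1 + laplace.sample(lambda1.size()).to(x.device)
    lambda2 = lambda2 + laplace.sample(lambda2.size()).to(x.device)
    return lambda0, lambda1, lambda2


def perturbed_loss(w, lambdas, num_samples):
    # All subsequent iterations optimize this fixed noisy objective;
    # no fresh noise is drawn per step.
    lambda0, lambda1, lambda2 = lambdas
    loss = lambda0 + torch.sum(lambda1.t() * w) + torch.sum(lambda2 * torch.matmul(w, w.t()))
    return loss / num_samples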



@@ -0,0 +1,49 @@
from torch.nn import Parameter
from torch.nn import Module
from torch.nn.init import kaiming_normal_

import numpy as np

import torch


class FMLogisticRegression(Module):
    """Implementation of the Functional Mechanism for logistic regression, referring to
    `Functional Mechanism: Regression Analysis under Differential Privacy`
    [Jun Zhang, et al.](https://arxiv.org/abs/1208.0219)

    Args:
        in_channels (int): the number of input dimensions
        epsilon (float): the epsilon bound for differential privacy

    Note:
        The forward function returns the average loss directly, so that we
        don't need a criterion function for FM logistic regression.
    """
    def __init__(self, in_channels, epsilon):
        super(FMLogisticRegression, self).__init__()
        self.w = Parameter(torch.empty(in_channels, 1))
        kaiming_normal_(self.w)

        sensitivity = 0.25 * in_channels**2 + 3 * in_channels
        self.laplace = torch.distributions.laplace.Laplace(
            loc=0, scale=sensitivity / epsilon * np.sqrt(2))

    def forward(self, x, y):
        if len(y.size()) == 1:
            y = torch.unsqueeze(y, dim=-1)
        # J=0
        lambda0 = np.log(2)
        lambda0 += self.laplace.sample(sample_shape=[1]).to(x.device)
        # J=1
        lambda1 = 0.5 * x - y * x
        lambda1 += self.laplace.sample(sample_shape=lambda1.size()).to(lambda1.device)
        # J=2
        lambda2 = torch.matmul(x.t(), x)
        lambda2 += self.laplace.sample(sample_shape=lambda2.size()).to(lambda2.device)
        w2 = torch.matmul(self.w, self.w.t())

        loss_total = lambda0 * x.size(0) + torch.sum(lambda1.t() * self.w) + 0.125 * torch.sum(lambda2 * w2)

        pred = torch.matmul(x, self.w)

        return pred, loss_total / x.size(0)
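For reference, the constants log(2), 0.5 and 0.125 used in forward() match the degree-2 Taylor expansion of f(z) = log(1 + exp(z)) around z = 0, which is how the functional mechanism builds a polynomial surrogate of the logistic loss. A quick numerical sanity check (standalone, not part of the PR):

import numpy as np

# f(z) = log(1 + exp(z)) is approximately log(2) + 0.5 * z + 0.125 * z**2 near z = 0.
for z in (0.1, 0.5, 1.0):
    exact = np.log1p(np.exp(z))
    approx = np.log(2) + 0.5 * z + 0.125 * z ** 2
    print(z, exact, approx, abs(exact - approx))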
3 changes: 3 additions & 0 deletions federatedscope/differential_privacy/optimizers/__init__.py
@@ -0,0 +1,3 @@
from federatedscope.differential_privacy.optimizers.dp_optimizer import DPGaussianSGD, DPLaplaceSGD

__all__ = ['DPGaussianSGD', 'DPLaplaceSGD']