[Feature] Add several differential privacy algorithms #369
@@ -1 +1,4 @@
-from federatedscope.core.regularizer.proximal_regularizer import *
+from federatedscope.core.regularizer.proximal_regularizer import ProximalRegularizer
+from federatedscope.core.regularizer.l2_regularizer import L2Regularizer
+
+__all__ = ['ProximalRegularizer', 'L2Regularizer']
@@ -0,0 +1,36 @@
from federatedscope.register import register_regularizer
from torch.nn import Module

import torch

REGULARIZER_NAME = "l2_regularizer"


class L2Regularizer(Module):
    """Returns the l2 norm of weight

    Arguments:
        p (int): The order of norm.
        tensor_before: The original matrix or vector
        tensor_after: The updated matrix or vector

    Returns:
        Tensor: the norm of the given update.
    """
    def __init__(self):
        super(L2Regularizer, self).__init__()

    def forward(self, ctx):
        l2_norm = 0.
        for param in ctx.model.parameters():
            l2_norm += torch.sum(param**2)
        return l2_norm


def call_l2_regularizer(type):
    if type == REGULARIZER_NAME:
        regularizer = L2Regularizer
        return regularizer


register_regularizer(REGULARIZER_NAME, call_l2_regularizer)

Review comments on the parameter loop:

An interesting discussion about whether l2 regularization should be applied to beta and gamma of BN layers: https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994.

The current

Conflicts are solved
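For reference, a minimal usage sketch of the L2Regularizer above. The `_Ctx` stand-in and the weight `0.01` are illustrative assumptions; in FederatedScope the trainer supplies its own context object and reads the weight from the `regularizer.mu` config entry:

```python
import torch
from torch import nn

# Minimal stand-in for the trainer context: L2Regularizer.forward only
# touches ctx.model here (the real FederatedScope context carries much more).
class _Ctx:
    def __init__(self, model):
        self.model = model

model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
task_loss = nn.functional.mse_loss(model(x), y)

reg = L2Regularizer()
penalty = reg(_Ctx(model))           # sum of squared parameter entries
loss = task_loss + 0.01 * penalty    # 0.01 plays the role of mu
loss.backward()
```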
@@ -0,0 +1,39 @@
use_gpu: True
device: 0
early_stop:
  patience: 5
seed: 12345
federate:
  mode: standalone
  total_round_num: 300
  sample_client_rate: 0.2
data:
  root: data/
  type: femnist
  splits: [0.6,0.2,0.2]
  batch_size: 10
  subsample: 0.05
  num_workers: 0
  transform: [['ToTensor'], ['Normalize', {'mean': [0.1307], 'std': [0.3081]}]]
model:
  type: convnet2
  hidden: 2048
  out_channels: 62
  dropout: 0.0
train:
  local_update_steps: 1
  batch_or_epoch: epoch
  optimizer:
    type: DPGaussianSGD
    lr: 0.01
    l2_norm_sensitivity: 0.1
    noise_multiplier: 0.1
grad:
  grad_clip: 5.0
criterion:
  type: CrossEntropyLoss
trainer:
  type: cvtrainer
eval:
  freq: 10
  metrics: ['acc', 'correct']
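The `DPGaussianSGD` optimizer named above is not included in this commit range. As a rough, hedged sketch of what a Gaussian-mechanism SGD step typically looks like given the two config knobs above (an assumed simplification, not the PR's implementation; it clips the whole batch gradient rather than per-example gradients as standard DP-SGD does):

```python
import torch
from torch.optim import SGD

class GaussianNoisySGD(SGD):
    """Illustrative Gaussian-mechanism SGD, not FederatedScope's DPGaussianSGD."""
    def __init__(self, params, lr, l2_norm_sensitivity=0.1,
                 noise_multiplier=0.1, **kwargs):
        super().__init__(params, lr=lr, **kwargs)
        self.l2_norm_sensitivity = l2_norm_sensitivity
        self.noise_multiplier = noise_multiplier

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # clip the gradient so its l2 norm is at most the sensitivity
                grad_norm = p.grad.norm(2).clamp(min=1e-12)
                p.grad.mul_(torch.clamp(self.l2_norm_sensitivity / grad_norm, max=1.0))
                # add Gaussian noise scaled by sensitivity * noise_multiplier
                p.grad.add_(torch.randn_like(p.grad)
                            * self.l2_norm_sensitivity * self.noise_multiplier)
        return super().step(closure)
```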
@@ -0,0 +1,27 @@
use_gpu: True
federate:
  mode: 'standalone'
  total_round_num: 500
  client_num: 10
seed: 12345
trainer:
  type: 'FMTrainer'
train:
  local_update_steps: 10
  batch_or_epoch: 'batch'
  optimizer:
    type: 'SGD'
    lr: 0.01
eval:
  freq: 20
  metrics: ['loss_regular']
  count_flops: False
model:
  type: 'FMLinearRegression'
data:
  type: 'toy'
criterion:
  type: MSELoss
regularizer:
  type: 'l2_regularizer'
  mu: 0.01
@@ -0,0 +1,27 @@
import numpy as np


# For a client
class AdvancedComposition(object):
    pass


class PrivacyAccountantComposition(object):
    pass


class RenyiComposition(object):
    def __init__(self, sample_rate):
        # Renyi orders (alpha) at which the privacy budget is tracked
        self.orders = [1.5, 1.75, 2, 2.5, 3, 4, 5, 6, 8, 16, 32, 64]

        # accumulated budget per order
        self.budgets = np.zeros_like(self.orders)

        self.epsilon = 0

        # sampling rate
        # alpha rate

    def compose(self, scale):
        # accumulate a per-step cost at every tracked order
        # (taken here as alpha / scale**2)
        for i, order in enumerate(self.orders):
            epsilon = order / scale ** 2
            self.budgets[i] += epsilon
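The step from the accumulated per-order budgets back to an (ε, δ) guarantee is not implemented in this file. A minimal sketch of the standard Rényi-DP to (ε, δ)-DP conversion, assuming `self.budgets` holds the accumulated RDP cost at each order in `self.orders` (the `rdp_to_dp` helper and the numbers below are illustrative, not part of the PR):

```python
import numpy as np

def rdp_to_dp(orders, budgets, delta=1e-5):
    # Standard conversion: eps(delta) = min over alpha of rdp(alpha) + log(1/delta) / (alpha - 1)
    orders = np.asarray(orders, dtype=float)
    budgets = np.asarray(budgets, dtype=float)
    eps = budgets + np.log(1.0 / delta) / (orders - 1.0)
    best = int(np.argmin(eps))
    return float(eps[best]), float(orders[best])

accountant = RenyiComposition(sample_rate=0.2)
for _ in range(300):               # e.g. one noisy step per federated round
    accountant.compose(scale=1.0)  # `scale` plays the role of the noise sigma
epsilon, best_order = rdp_to_dp(accountant.orders, accountant.budgets)
```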
@@ -0,0 +1,4 @@
from federatedscope.differential_privacy.model.fm_linear_regression import FMLinearRegression
from federatedscope.differential_privacy.model.fm_logistic_regression import FMLogisticRegression

__all__ = ['FMLinearRegression', 'FMLogisticRegression']
@@ -0,0 +1,52 @@
from torch.nn import Parameter
from torch.nn import Module
from torch.nn.init import kaiming_normal_

import numpy as np

import torch
import math


class FMLinearRegression(Module):
    """Implementation of the Functional Mechanism for linear regression,
    following `Functional Mechanism: Regression Analysis under Differential
    Privacy` [Jun Zhang, et al.](https://arxiv.org/abs/1208.0219)

    Args:
        in_channels (int): the number of input dimensions
        epsilon (float): the epsilon bound for differential privacy

    Note:
        The forward function returns the average loss directly, so we
        don't need a criterion function for FM linear regression.
    """
    def __init__(self, in_channels, epsilon):
        super(FMLinearRegression, self).__init__()
        self.w = Parameter(torch.empty(in_channels, 1))
        kaiming_normal_(self.w, a=math.sqrt(5))

        sensitivity = float(2 * (1 + 2 * in_channels + in_channels**2))

        self.laplace = torch.distributions.laplace.Laplace(
            loc=0, scale=sensitivity / epsilon * np.sqrt(2))

    def forward(self, x, y):
        # J=0: constant coefficient of the expanded objective
        lambda0 = torch.matmul(y.t(), y)
        lambda0 += self.laplace.sample(sample_shape=lambda0.size()).to(lambda0.device)
        # J=1: linear coefficients
        lambda1 = -2 * torch.matmul(y.t(), x)
        lambda1 += self.laplace.sample(sample_shape=lambda1.size()).to(lambda1.device)
        # J=2: quadratic coefficients
        lambda2 = torch.matmul(x.t(), x)
        lambda2 += self.laplace.sample(sample_shape=lambda2.size()).to(lambda2.device)
        w2 = torch.matmul(self.w, self.w.t())

        loss_total = lambda0 + torch.sum(lambda1.t() * self.w) + torch.sum(lambda2 * w2)

        pred = torch.matmul(x, self.w)

        return pred, loss_total / x.size(0)
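A minimal sketch of driving the module in a plain training loop (the synthetic data and loop below are illustrative; the PR's actual `FMTrainer` wiring is not shown in this commit range):

```python
import torch

# Illustrative synthetic regression data, not the PR's 'toy' dataset.
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1 * torch.randn(64, 1)

model = FMLinearRegression(in_channels=3, epsilon=1.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(100):
    optimizer.zero_grad()
    pred, loss = model(x, y)   # loss is already the perturbed FM objective
    loss.backward()
    optimizer.step()
```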
@@ -0,0 +1,49 @@
from torch.nn import Parameter
from torch.nn import Module
from torch.nn.init import kaiming_normal_

import numpy as np

import torch


class FMLogisticRegression(Module):
    """Implementation of the Functional Mechanism for logistic regression,
    following `Functional Mechanism: Regression Analysis under Differential
    Privacy` [Jun Zhang, et al.](https://arxiv.org/abs/1208.0219)

    Args:
        in_channels (int): the number of input dimensions
        epsilon (float): the epsilon bound for differential privacy

    Note:
        The forward function returns the average loss directly, so we
        don't need a criterion function for FM logistic regression.
    """
    def __init__(self, in_channels, epsilon):
        super(FMLogisticRegression, self).__init__()
        self.w = Parameter(torch.empty(in_channels, 1))
        kaiming_normal_(self.w)

        sensitivity = 0.25 * in_channels ** 2 + 3 * in_channels
        self.laplace = torch.distributions.laplace.Laplace(
            loc=0, scale=sensitivity / epsilon * np.sqrt(2))

    def forward(self, x, y):
        if len(y.size()) == 1:
            y = torch.unsqueeze(y, dim=-1)
        # J=0: constant coefficient of the second-order Taylor expansion
        lambda0 = np.log(2)
        lambda0 += self.laplace.sample(sample_shape=[1]).to(x.device)
        # J=1: linear coefficients
        lambda1 = 0.5 * x - y * x
        lambda1 += self.laplace.sample(sample_shape=lambda1.size()).to(lambda1.device)
        # J=2: quadratic coefficients
        lambda2 = torch.matmul(x.t(), x)
        lambda2 += self.laplace.sample(sample_shape=lambda2.size()).to(lambda2.device)
        w2 = torch.matmul(self.w, self.w.t())

        loss_total = lambda0 * x.size(0) + torch.sum(lambda1.t() * self.w) + 0.125 * torch.sum(lambda2 * w2)

        pred = torch.matmul(x, self.w)

        return pred, loss_total / x.size(0)
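For reference, a minimal usage sketch (synthetic data and names are illustrative, not from the PR): the module returns the raw scores x·w together with the perturbed objective, so probabilities come from applying a sigmoid outside the module:

```python
import torch

# Illustrative synthetic binary labels, not the PR's 'toy' dataset.
torch.manual_seed(0)
x = torch.randn(32, 5)
y = (x.sum(dim=1, keepdim=True) > 0).float()

model = FMLogisticRegression(in_channels=5, epsilon=1.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

pred, loss = model(x, y)           # loss is the noise-perturbed FM objective
loss.backward()
optimizer.step()

prob = torch.sigmoid(pred)         # raw scores -> probabilities
acc = ((prob > 0.5).float() == y).float().mean()
```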
@@ -0,0 +1,3 @@
from federatedscope.differential_privacy.optimizers.dp_optimizer import DPGaussianSGD, DPLaplaceSGD

__all__ = ['DPGaussianSGD', 'DPLaplaceSGD']
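`DPLaplaceSGD` is likewise exported, but its implementation is not part of this commit range. For contrast with the Gaussian sketch earlier, a hedged sketch of how a Laplace-mechanism step might perturb a clipped gradient (assumed, not the PR's code):

```python
import torch

def add_laplace_noise(grad, l1_sensitivity=0.1, epsilon=1.0):
    # Laplace mechanism on a (clipped) gradient tensor: noise scale b = sensitivity / epsilon.
    # Uses the l1 sensitivity, whereas the Gaussian mechanism uses the l2 sensitivity.
    laplace = torch.distributions.Laplace(loc=0.0, scale=l1_sensitivity / epsilon)
    return grad + laplace.sample(grad.size()).to(grad.device)
```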
Review comments:

'p=2'?

modified accordingly

In the functional mechanism original paper, the noise addition is performed only once on the objective function, while your implementation performs noise addition in each iteration; is this supported by the theory?