In a hetero federated learning (vertically partitioned data) setting, multiple parties hold different feature sets for the same set of users (samples). Federated learning enables these parties to collaboratively train a model without sharing their raw data. In FATE-2.0 we introduce a brand-new Hetero-NN framework that lets you quickly set up a hetero federated NN learning task. Because the framework is built on PyTorch and transformers, you can easily integrate your existing datasets and models into it.
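To make the "same samples, different features" setting concrete, here is a minimal plain-Python sketch (no FATE involved) that splits one toy table by columns into a guest view and a host view. The column names, values and the vertical_split helper are hypothetical; in a real deployment each party only ever holds its own view.

```python
# A toy table: each row is one user, keyed by a shared sample id.
# Column names and values are made up for illustration only.
full_table = {
    "u1": {"age": 34, "income": 52.0, "clicks": 7, "y": 1},
    "u2": {"age": 45, "income": 61.5, "clicks": 2, "y": 0},
    "u3": {"age": 29, "income": 38.2, "clicks": 9, "y": 1},
}

GUEST_COLUMNS = ["age", "y"]          # the guest holds the label plus some features
HOST_COLUMNS = ["income", "clicks"]   # the host holds other features and no label


def vertical_split(table, columns):
    """Keep every sample id but only the given columns (one party's view)."""
    return {sid: {c: row[c] for c in columns} for sid, row in table.items()}


guest_view = vertical_split(full_table, GUEST_COLUMNS)
host_view = vertical_split(full_table, HOST_COLUMNS)

# Both parties see the same sample ids, but disjoint feature sets.
print(sorted(guest_view) == sorted(host_view))  # True
print(guest_view["u1"], host_view["u1"])
```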
In this tutorial, we will show you how to run a Hetero-NN task under FATE-2.0 locally without using FATE-Pipeline. You can refer to this example for local model experimentation, algorithm modification, and testing. In addition, FATE-2.0 provides two protection strategies, SSHE and FedPass, and we will show you how to use both of them in this tutorial.
To run a Hetero-NN task, several steps are needed:
- Import the required classes in a new Python script
- Prepare the data, datasets, models, loss functions, and optimizers for the guest and host sides
- Configure the training parameters, initialize a Hetero-NN model, and set the protection strategy
- Prepare the trainer
- Run the training script
In FATE-2.0, our neural network (NN) framework is built on top of the PyTorch and transformers libraries, which makes it straightforward to bring existing models and datasets into federated training. In the Hetero-NN module, HeteroNNTrainerGuest and HeteroNNTrainerHost train the model on the guest and host side, respectively. They are developed based on the Hugging Face Trainer, so you specify the training arguments in the same way, via the TrainingArguments class.
We also provide HeteroNNModelGuest and HeteroNNModelHost, which wrap the top/bottom models and the aggregate layer and provide a unified interface for the trainers. You can define your own bottom/top model structures and pass them to HeteroNNModelGuest and HeteroNNModelHost. We offer two protection strategies, SSHE and FedPass, which you configure in HeteroNNModelGuest and HeteroNNModelHost via SSHEArgument and FedPassArgument.
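The division of labour between the wrappers can be illustrated without FATE. The sketch below is a deliberately simplified, plaintext (unprotected) toy: each party has a one-weight "bottom model", the aggregation is a plain sum, and the guest's "top model" is a bias plus a sigmoid trained with binary cross-entropy. It shows only how the loss gradient flows from the guest back to the host; FATE's real aggregate layer is a learned layer protected by SSHE or FedPass, and all names here are hypothetical.

```python
import math

# Toy vertically partitioned data: the same 4 samples, one feature per party.
x_guest = [0.5, -1.0, 1.5, -0.5]   # guest feature
x_host = [1.0, -0.5, 2.0, -1.5]    # host feature
y = [1, 0, 1, 0]                   # labels live on the guest only

w_guest, w_host, bias = 0.0, 0.0, 0.0  # guest bottom, host bottom, guest top bias
lr = 0.5


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce(p, t):
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))


def train_step():
    """One full-batch step; returns the loss computed before the update."""
    global w_guest, w_host, bias
    n = len(y)
    # Host: bottom-model forward, then "send" the embeddings to the guest.
    h_host = [w_host * x for x in x_host]
    # Guest: bottom forward, aggregate (here a plain sum), top forward, loss.
    h_guest = [w_guest * x for x in x_guest]
    p = [sigmoid(hg + hh + bias) for hg, hh in zip(h_guest, h_host)]
    loss = sum(bce(pi, ti) for pi, ti in zip(p, y)) / n
    # For sigmoid + mean BCE, dLoss/dz is (p - y) / n for each sample.
    grad_z = [(pi - ti) / n for pi, ti in zip(p, y)]
    # Guest updates its own parameters ...
    w_guest -= lr * sum(g * x for g, x in zip(grad_z, x_guest))
    bias -= lr * sum(grad_z)
    # ... and "sends" dLoss/dh_host (equal to grad_z, since z is a sum) to the host,
    # which uses it to update its bottom model.
    w_host -= lr * sum(g * x for g, x in zip(grad_z, x_host))
    return loss


losses = [train_step() for _ in range(20)]
print(round(losses[0], 4), round(losses[-1], 4))
```

The host never sees labels or the loss; it only receives a gradient with respect to its own output, which is exactly the quantity the protection strategies are designed to guard.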
```python
import torch as t
from fate.arch import Context
from fate.ml.nn.hetero.hetero_nn import HeteroNNTrainerGuest, HeteroNNTrainerHost, TrainingArguments
from fate.ml.nn.model_zoo.hetero_nn_model import HeteroNNModelGuest, HeteroNNModelHost
from fate.ml.nn.model_zoo.hetero_nn_model import SSHEArgument, FedPassArgument, TopModelStrategyArguments
```
Here we show an example of using our NN framework: a binary classification task on tabular features. You can download our example data from:
and place the files in the same directory as your Python script.
In this example we use the SSHE strategy to protect the data: an SSHE aggregate layer is responsible for aggregating the forward outputs of the guest and host sides and propagating the gradients back to each side.
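SSHE combines secret sharing with homomorphic encryption. The sketch below illustrates only the additive secret-sharing ingredient, in plain Python: each party splits its (integer-encoded) value into two random-looking shares, and the parties can add shares locally so that only the sum is ever reconstructed. The modulus, the integer encoding, and the helper names are assumptions of this sketch; it is not FATE's SSHE protocol, which also uses homomorphic encryption and computes a learned linear aggregation.

```python
import secrets

PRIME = 2**61 - 1  # modulus for the shares (a large prime chosen for this sketch)


def share(value, modulus=PRIME):
    """Split an integer into two additive shares that sum to value mod modulus."""
    r = secrets.randbelow(modulus)
    return r, (value - r) % modulus


def reconstruct(s0, s1, modulus=PRIME):
    return (s0 + s1) % modulus


# Guest and host each secret-share a (fixed-point encoded) activation.
guest_value, host_value = 1234, 5678
g0, g1 = share(guest_value)
h0, h1 = share(host_value)

# The guest keeps g0 and receives h0; the host keeps h1 and receives g1.
# Each adds the shares it holds without seeing the other's raw value.
guest_side_sum = (g0 + h0) % PRIME
host_side_sum = (g1 + h1) % PRIME

print(reconstruct(guest_side_sum, host_side_sum))  # 6912
```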
FATE-2.0 uses a context object to configure the running environment, including the party settings (guest, host, and their party IDs). We can create a context object by calling the create_context function.
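```python
from fate.arch.context import create_context


def create_ctx(party, session_id='test_fate'):
    # all parties taking part in the task, as (role, party_id) tuples
    parties = [("guest", "9999"), ("host", "10000")]
    # the party this process plays
    if party == "guest":
        local_party = ("guest", "9999")
    else:
        local_party = ("host", "10000")
    context = create_context(local_party, parties=parties, federation_session_id=session_id)
    return context
```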
If we run the task with launch() (explained later), context creation is handled automatically; this section introduces the concept of a context and shows how to create one manually.
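The guest and the host are separate programs that depend on each other at every step: the host cannot finish a step until the guest returns a gradient, and the guest cannot compute its loss until the host sends its forward output. Both parties therefore have to be running at the same time (judging from the [Rank:0]/[Rank:1] prefixes in the training log later in this tutorial, launch() runs them as separate processes). The plain-Python sketch below mimics that lock-step exchange with two threads and two queue.Queue objects acting as a hypothetical federation channel; it is not how FATE implements launch() or its communication.

```python
import queue
import threading

# Hypothetical stand-ins for the federation channel between the two parties.
host_to_guest = queue.Queue()
guest_to_host = queue.Queue()
results = {"guest_seen": [], "host_grads": []}


def run_host(num_steps):
    for step in range(num_steps):
        embedding = [0.1 * step, 0.2 * step]   # pretend bottom-model output
        host_to_guest.put(embedding)           # forward: send to the guest
        grad = guest_to_host.get(timeout=5)    # backward: wait for the gradient
        results["host_grads"].append(grad)


def run_guest(num_steps):
    for _ in range(num_steps):
        embedding = host_to_guest.get(timeout=5)  # blocks until the host sends
        grad = [-e for e in embedding]            # pretend gradient w.r.t. embedding
        guest_to_host.put(grad)
        results["guest_seen"].append(embedding)


threads = [threading.Thread(target=run_guest, args=(3,)),
           threading.Thread(target=run_host, args=(3,))]
for th in threads:
    th.start()
for th in threads:
    th.join()

print(len(results["guest_seen"]), len(results["host_grads"]))  # 3 3
```

If only one of the two threads were started, it would block on its queue and time out, which is the single-process analogue of launching only one party.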
Before training, as in PyTorch, we first define the model structure, prepare the data, choose a loss function, and instantiate an optimizer. The following code prepares the data, datasets, models, loss, and optimizers. Unlike a homogeneous (homo) federated learning scenario, in a hetero-neural network (Hetero-NN) setting the features and the model are split, and each party manages its own part. The code uses 'ctx' to separate the guest and host logic: the guest holds the labels and 10 features, so it creates both a bottom and a top model, while the host, with 20 features and no labels, only creates a bottom model. When initializing HeteroNNModelGuest and HeteroNNModelHost, we pass an SSHEArgument so that a secret sharing and homomorphic encryption (SSHE) aggregate layer is built for training, protecting both the forward and the backward passes.
As with a Hugging Face Trainer, TrainingArguments is used to set the training parameters. Note that Hetero-NN currently does not support multi-GPU training, and the SSHE layer does not support GPU training.
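The two training arguments used below determine how many optimization steps are run. This small worked example reproduces the "Total optimization steps = 9" reported in the training log later in this tutorial, assuming the last partial batch is kept (the Hugging Face default, dataloader_drop_last=False) and no gradient accumulation.

```python
import math

num_examples = 569                # "Num examples" in the training log
per_device_train_batch_size = 256
num_train_epochs = 3
gradient_accumulation_steps = 1   # default, as reported in the log

# 569 / 256 = 2.22..., so each epoch has 2 full batches plus 1 partial batch.
steps_per_epoch = math.ceil(num_examples / per_device_train_batch_size) // gradient_accumulation_steps
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, total_steps)  # 3 9
```

This also explains why the per-epoch log lines appear at steps 3, 6 and 9.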
Once the models and datasets are prepared, we can start the training process.
```python
def get_setting(ctx):
    from fate.ml.nn.dataset.table import TableDataset

    # prepare data
    if ctx.is_on_guest:
        # guest: holds the labels and 10 features -> bottom model + top model
        ds = TableDataset(to_tensor=True)
        ds.load("./breast_hetero_guest.csv")
        bottom_model = t.nn.Sequential(
            t.nn.Linear(10, 8),
            t.nn.ReLU(),
        )
        top_model = t.nn.Sequential(
            t.nn.Linear(8, 1),
            t.nn.Sigmoid()
        )
        model = HeteroNNModelGuest(
            top_model=top_model,
            bottom_model=bottom_model,
            agglayer_arg=SSHEArgument(
                guest_in_features=8,
                host_in_features=8,
                out_features=8,
                layer_lr=0.01
            )
        )
        optimizer = t.optim.Adam(model.parameters(), lr=0.01)
        loss = t.nn.BCELoss()
    else:
        # host: holds 20 features and no labels -> bottom model only, no loss
        ds = TableDataset(to_tensor=True)
        ds.load("./breast_hetero_host.csv")
        bottom_model = t.nn.Sequential(
            t.nn.Linear(20, 8),
            t.nn.ReLU(),
        )
        model = HeteroNNModelHost(
            bottom_model=bottom_model,
            agglayer_arg=SSHEArgument(
                guest_in_features=8,
                host_in_features=8,
                out_features=8,
                layer_lr=0.01
            )
        )
        optimizer = t.optim.Adam(model.parameters(), lr=0.01)
        loss = None

    # training arguments shared by both sides
    args = TrainingArguments(
        num_train_epochs=3,
        per_device_train_batch_size=256
    )

    return ds, model, optimizer, loss, args
```
We add a train() function that initializes the trainer for the guest and the host separately, and a run() function as the entry point of the task. run() is passed to launch() at the end of the script. Below is the full code.
```python
import torch as t
from fate.arch import Context
from fate.ml.nn.hetero.hetero_nn import HeteroNNTrainerGuest, HeteroNNTrainerHost, TrainingArguments
from fate.ml.nn.model_zoo.hetero_nn_model import HeteroNNModelGuest, HeteroNNModelHost
from fate.ml.nn.model_zoo.hetero_nn_model import SSHEArgument, FedPassArgument, TopModelStrategyArguments


def train(ctx: Context,
          dataset=None,
          model=None,
          optimizer=None,
          loss_func=None,
          args: TrainingArguments = None,
          ):
    if ctx.is_on_guest:
        # the guest owns the labels, so it also receives the loss function
        trainer = HeteroNNTrainerGuest(ctx=ctx,
                                       model=model,
                                       train_set=dataset,
                                       optimizer=optimizer,
                                       loss_fn=loss_func,
                                       training_args=args
                                       )
    else:
        trainer = HeteroNNTrainerHost(ctx=ctx,
                                      model=model,
                                      train_set=dataset,
                                      optimizer=optimizer,
                                      training_args=args
                                      )
    trainer.train()
    return trainer


def predict(trainer, dataset):
    return trainer.predict(dataset)


def get_setting(ctx):
    from fate.ml.nn.dataset.table import TableDataset

    # prepare data
    if ctx.is_on_guest:
        ds = TableDataset(to_tensor=True)
        ds.load("./breast_hetero_guest.csv")
        bottom_model = t.nn.Sequential(
            t.nn.Linear(10, 8),
            t.nn.ReLU(),
        )
        top_model = t.nn.Sequential(
            t.nn.Linear(8, 1),
            t.nn.Sigmoid()
        )
        model = HeteroNNModelGuest(
            top_model=top_model,
            bottom_model=bottom_model,
            agglayer_arg=SSHEArgument(
                guest_in_features=8,
                host_in_features=8,
                out_features=8,
                layer_lr=0.01
            )
        )
        optimizer = t.optim.Adam(model.parameters(), lr=0.01)
        loss = t.nn.BCELoss()
    else:
        ds = TableDataset(to_tensor=True)
        ds.load("./breast_hetero_host.csv")
        bottom_model = t.nn.Sequential(
            t.nn.Linear(20, 8),
            t.nn.ReLU(),
        )
        model = HeteroNNModelHost(
            bottom_model=bottom_model,
            agglayer_arg=SSHEArgument(
                guest_in_features=8,
                host_in_features=8,
                out_features=8,
                layer_lr=0.01
            )
        )
        optimizer = t.optim.Adam(model.parameters(), lr=0.01)
        loss = None

    args = TrainingArguments(
        num_train_epochs=3,
        per_device_train_batch_size=256
    )

    return ds, model, optimizer, loss, args


def run(ctx):
    ds, model, optimizer, loss, args = get_setting(ctx)
    trainer = train(ctx, ds, model, optimizer, loss, args)
    pred = predict(trainer, ds)
    if ctx.is_on_guest:
        # print("pred:", pred)
        # compute auc here (only the guest has the labels)
        from sklearn.metrics import roc_auc_score
        print('auc is')
        print(roc_auc_score(pred.label_ids, pred.predictions))


if __name__ == '__main__':
    from fate.arch.launchers.multiprocess_launcher import launch
    launch(run)
```
Save the code as a python script named 'hetero_nn.py' and run it with the following command:
```bash
python hetero_nn.py --parties guest:9999 host:10000 --log_level INFO
```
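The --parties flag lists each participant as role:party_id, which corresponds to the (role, party_id) tuples used in create_ctx() earlier. The sketch below is a hypothetical argparse parser that only mirrors the flags in the command above to show that correspondence; it is not the parser used by FATE's launcher.

```python
import argparse


def parse_launch_args(argv):
    """Hypothetical parser mirroring the command-line flags used above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--parties", nargs="+", required=True)  # e.g. guest:9999 host:10000
    parser.add_argument("--log_level", default="INFO")
    args = parser.parse_args(argv)
    # Turn "role:party_id" strings into (role, party_id) tuples, as in create_ctx().
    parties = [tuple(p.split(":", 1)) for p in args.parties]
    return parties, args.log_level


parties, level = parse_launch_args(
    ["--parties", "guest:9999", "host:10000", "--log_level", "INFO"])
print(parties, level)  # [('guest', '9999'), ('host', '10000')] INFO
```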
Here is part of the console output:
```text
[15:16:49] INFO [Rank:0] disabled tracing _trace.py:31
INFO [Rank:0] sample id column not found, generate sample id from 0 to 569 table.py:139
label is None
INFO [Rank:0] use "y" as label column table.py:150
[15:16:49] INFO [Rank:1] disabled tracing _trace.py:31
INFO [Rank:1] sample id column not found, generate sample id from 0 to 569 table.py:139
label is None
INFO [Rank:1] found no "y"/"label"/"target" in input table, no label will be set table.py:153
INFO [Rank:0] ***** Running training ***** trainer.py:1706
INFO [Rank:0] Num examples = 569 trainer.py:1707
INFO [Rank:0] Num Epochs = 3 trainer.py:1708
INFO [Rank:0] Instantaneous batch size per device = 256 trainer.py:1709
INFO [Rank:0] Total train batch size (w. parallel, distributed & accumulation) = 256 trainer.py:1712
INFO [Rank:0] Gradient Accumulation steps = 1 trainer.py:1713
INFO [Rank:0] Total optimization steps = 9 trainer.py:1714
INFO [Rank:0] Number of trainable parameters = 97 trainer.py:1715
INFO [Rank:1] ***** Running training ***** trainer.py:1706
INFO [Rank:1] Num examples = 569 trainer.py:1707
INFO [Rank:1] Num Epochs = 3 trainer.py:1708
INFO [Rank:1] Instantaneous batch size per device = 256 trainer.py:1709
INFO [Rank:1] Total train batch size (w. parallel, distributed & accumulation) = 256 trainer.py:1712
INFO [Rank:1] Gradient Accumulation steps = 1 trainer.py:1713
INFO [Rank:1] Total optimization steps = 9 trainer.py:1714
INFO [Rank:1] Number of trainable parameters = 168 trainer.py:1715
{'loss': 0.7817, 'learning_rate': 0.01, 'epoch': 1.0}
[15:17:13] INFO [Rank:0] {'loss': 0.7817, 'learning_rate': 0.01, 'epoch': 1.0, 'step': 3} trainer_base.py:429
{'loss': 0.0, 'learning_rate': 0.01, 'epoch': 1.0}
[15:17:13] INFO [Rank:1] {'loss': 0.0, 'learning_rate': 0.01, 'epoch': 1.0, 'step': 3} trainer_base.py:429
{'loss': 0.5714, 'learning_rate': 0.01, 'epoch': 2.0}
[15:17:30] INFO [Rank:0] {'loss': 0.5714, 'learning_rate': 0.01, 'epoch': 2.0, 'step': 6} trainer_base.py:429
{'loss': 0.0, 'learning_rate': 0.01, 'epoch': 2.0}
[15:17:30] INFO [Rank:1] {'loss': 0.0, 'learning_rate': 0.01, 'epoch': 2.0, 'step': 6} trainer_base.py:429
{'loss': 0.4975, 'learning_rate': 0.01, 'epoch': 3.0}
[15:17:48] INFO [Rank:0] {'loss': 0.4975, 'learning_rate': 0.01, 'epoch': 3.0, 'step': 9} trainer_base.py:429
{'train_runtime': 58.4774, 'train_samples_per_second': 29.191, 'train_steps_per_second': 0.154, 'train_loss': 0.616881701681349, 'epoch': 3.0}
INFO [Rank:0] {'train_runtime': 58.4774, 'train_samples_per_second': 29.191, 'train_steps_per_second': 0.154, 'total_flos': 0.0, 'train_loss': 0.616881701681349, 'epoch': 3.0, 'step': 9} trainer_base.py:429
INFO [Rank:0] ***** Running Prediction ***** trainer.py:3154
INFO [Rank:0] Num examples = 569 trainer.py:3156
INFO [Rank:0] Batch size = 8 trainer.py:3159
{'loss': 0.0, 'learning_rate': 0.01, 'epoch': 3.0}
[15:17:48] INFO [Rank:1] {'loss': 0.0, 'learning_rate': 0.01, 'epoch': 3.0, 'step': 9} trainer_base.py:429
{'train_runtime': 58.5118, 'train_samples_per_second': 29.174, 'train_steps_per_second': 0.154, 'train_loss': 0.0, 'epoch': 3.0}
INFO [Rank:1] {'train_runtime': 58.5118, 'train_samples_per_second': 29.174, 'train_steps_per_second': 0.154, 'total_flos': 0.0, 'train_loss': 0.0, 'epoch': 3.0, 'step': 9} trainer_base.py:429
INFO [Rank:1] ***** Running Prediction ***** trainer.py:3154
INFO [Rank:1] Num examples = 569 trainer.py:3156
INFO [Rank:1] Batch size = 8 trainer.py:3159
[15:18:07] INFO [Rank:1] Total: 76.9601s, Driver: 18.8432s(24.48%), Federation: 57.9809s(75.34%), Computing: 0.1361s(0.18%) _profile.py:279
auc is
0.9712488769092542
```
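The "Number of trainable parameters" lines in the log (97 on Rank:0, 168 on Rank:1) can be checked by hand: a torch.nn.Linear(in, out) layer with bias has in * out + out parameters. The counts match the bottom and top models defined in get_setting() alone, which suggests (we have not verified this in the FATE source) that the SSHE aggregate layer's weights are not included in this count.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of trainable parameters in a torch.nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


guest_params = linear_params(10, 8) + linear_params(8, 1)  # guest bottom + top: 88 + 9
host_params = linear_params(20, 8)                         # host bottom: 160 + 8
print(guest_params, host_params)  # 97 168
```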
To execute an image classification task with the FedPass protection strategy, a few modifications to the settings are required. In our example, the guest possesses only the labels, while the host holds the image data. Consequently, the guest configures a top model (without a bottom model), and the host sets up a bottom model.
We employ the FedPass strategy, detailed in 'FedPass: Privacy-Preserving Vertical Federated Deep Learning with Adaptive Obfuscation'. This approach enhances privacy in neural networks by integrating private passports for adaptive obfuscation. It adds a 'passport layer' whose scale and bias are derived from these private passports, which, according to the paper, provides strong privacy protection without sacrificing model performance.
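The core idea of a passport layer is that a layer's output is rescaled and shifted by a scale and a bias computed from the layer's weights and a private, randomly generated passport: out = gamma * (W x) + beta. The sketch below is a deliberately simplified plain-Python version for a linear layer. The averaging rule for gamma and beta and the uniform passport distribution are assumptions made for illustration and omit any further learned transformations the paper applies; this is neither FATE's FedPass layer nor the exact formulation from the paper. (The scale-and-bias step in ConvBlock.forward in the code below follows the same out = scale * x + bias pattern.)

```python
import random


def passport_linear(x, weight, passport_gamma, passport_beta):
    """Simplified passport-modulated linear layer: out_j = gamma_j * (W x)_j + beta_j.

    gamma_j and beta_j are the averages of W[j] * passport over the input
    dimension; this averaging rule is an assumption made for the sketch.
    """
    out = []
    for row in weight:
        wx = sum(w * xi for w, xi in zip(row, x))
        gamma = sum(w * s for w, s in zip(row, passport_gamma)) / len(row)
        beta = sum(w * s for w, s in zip(row, passport_beta)) / len(row)
        out.append(gamma * wx + beta)
    return out


rng = random.Random(0)
in_features, out_features = 4, 3
weight = [[rng.uniform(-1, 1) for _ in range(in_features)] for _ in range(out_features)]
x = [0.5, -0.2, 0.1, 0.9]


def new_passport():
    # Assumed passport distribution for this sketch: uniform on [-10, 10].
    return [rng.uniform(-10, 10) for _ in range(in_features)]


# The same input and weights, obfuscated with two different private passports.
out_a = passport_linear(x, weight, new_passport(), new_passport())
out_b = passport_linear(x, weight, new_passport(), new_passport())
print(out_a)
print(out_b)
```

Without knowing the passport, an observer of the hidden features cannot tell which scale and bias were applied, which is the intuition behind the obfuscation.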
Let us replace the get_setting() function in the previous example with the following code:
```python
def get_setting(ctx):
    from fate.ml.nn.dataset.table import TableDataset
    import torchvision

    # define model
    from torch import nn
    from torch.nn import init

    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, norm_type=None,
                     relu=False):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
            self.norm_type = norm_type
            if self.norm_type:
                if self.norm_type == 'bn':
                    self.bn = nn.BatchNorm2d(out_channels)
                elif self.norm_type == 'gn':
                    self.bn = nn.GroupNorm(out_channels // 16, out_channels)
                elif self.norm_type == 'in':
                    self.bn = nn.InstanceNorm2d(out_channels)
                else:
                    raise ValueError("Wrong norm_type")
            else:
                self.bn = None
            if relu:
                self.relu = nn.ReLU(inplace=True)
            else:
                self.relu = None
            self.reset_parameters()

        def reset_parameters(self):
            init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

        def forward(self, x, scales=None, biases=None):
            x = self.conv(x)
            if self.norm_type is not None:
                x = self.bn(x)
            # optional passport-style modulation: scale and shift the feature map
            if scales is not None and biases is not None:
                x = scales[-1] * x + biases[-1]
            if self.relu is not None:
                x = self.relu(x)
            return x

    # host bottom model
    class LeNetBottom(nn.Module):
        def __init__(self):
            super(LeNetBottom, self).__init__()
            self.layer0 = nn.Sequential(
                ConvBlock(1, 8, kernel_size=5),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2)
            )

        def forward(self, x):
            x = self.layer0(x)
            return x

    # guest top model
    class LeNetTop(nn.Module):
        def __init__(self, out_feat=84):
            super(LeNetTop, self).__init__()
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(16 * 4 * 4, 120)
            self.fc1act = nn.ReLU(inplace=True)
            self.fc2 = nn.Linear(120, 84)
            self.fc2act = nn.ReLU(inplace=True)
            self.fc3 = nn.Linear(84, out_feat)

        def forward(self, x_a):
            x = x_a
            x = self.pool(x)
            x = x.view(x.size(0), -1)
            x = self.fc1(x)
            x = self.fc1act(x)
            x = self.fc2(x)
            x = self.fc2act(x)
            x = self.fc3(x)
            return x

    # fed simulate tool
    from torch.utils.data import Dataset

    class NoFeatureDataset(Dataset):
        # guest view: keep only the label of each (image, label) sample
        def __init__(self, ds):
            self.ds = ds

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, item):
            return [self.ds[item][1]]

    class NoLabelDataset(Dataset):
        # host view: keep only the image of each (image, label) sample
        def __init__(self, ds):
            self.ds = ds

        def __len__(self):
            return len(self.ds)

        def __getitem__(self, item):
            return [self.ds[item][0]]

    # prepare mnist data
    train_data = torchvision.datasets.MNIST(root='./',
                                            train=True, download=True, transform=torchvision.transforms.ToTensor())

    if ctx.is_on_guest:
        model = HeteroNNModelGuest(
            top_model=LeNetTop(),
            top_arg=TopModelStrategyArguments(
                protect_strategy='fedpass',
                fed_pass_arg=FedPassArgument(
                    layer_type='linear',
                    in_channels_or_features=84,
                    hidden_features=64,
                    out_channels_or_features=10,
                    passport_mode='multi',
                    activation='relu',
                    num_passport=1000,
                    low=-10
                )
            )
        )
        optimizer = t.optim.Adam(model.parameters(), lr=0.01)
        loss = t.nn.CrossEntropyLoss()
        ds = NoFeatureDataset(train_data)
    else:
        model = HeteroNNModelHost(
            bottom_model=LeNetBottom(),
            agglayer_arg=FedPassArgument(
                layer_type='conv',
                in_channels_or_features=8,
                out_channels_or_features=16,
                kernel_size=(5, 5),
                stride=(1, 1),
                passport_mode='multi',
                activation='relu',
                num_passport=1000
            )
        )
        optimizer = t.optim.Adam(model.parameters(), lr=0.01)
        loss = None
        ds = NoLabelDataset(train_data)

    args = TrainingArguments(
        num_train_epochs=3,
        per_device_train_batch_size=256,
        disable_tqdm=False
    )

    return ds, model, optimizer, loss, args


def run(ctx):
    ds, model, optimizer, loss, args = get_setting(ctx)
    trainer = train(ctx, ds, model, optimizer, loss, args)
    pred = predict(trainer, ds)
```
In this configuration, a LeNet model is split between the two parties: the host runs the bottom part (LeNetBottom) and the guest runs the top part (LeNetTop). The dataset is torchvision.datasets.MNIST. We use FedPassArgument to build the FedPass layers. Note that on the host (bottom model) side the FedPass argument is passed via agglayer_arg, while on the guest (top model) side it is passed via top_arg, wrapped in TopModelStrategyArguments. Both sides are protected by FedPass: during training, random passports are generated to obfuscate the forward hidden features and the backward gradients.
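The input size of fc1 in LeNetTop, 16 * 4 * 4, follows from tracing a 28x28 MNIST image through the split model. The worked example below applies the standard output-size formula for Conv2d/MaxPool2d, floor((n + 2p - k) / s) + 1, to each layer. It assumes the host-side FedPass conv layer uses no padding, since no padding argument is set in its FedPassArgument.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d / MaxPool2d layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                    # MNIST images are 28x28
size = conv_out(size, 5)     # host ConvBlock(1, 8, kernel_size=5)          -> 24
size = conv_out(size, 2, 2)  # host MaxPool2d(2, 2)                         -> 12
size = conv_out(size, 5)     # host FedPass conv layer, 8 -> 16, kernel 5x5 -> 8
size = conv_out(size, 2, 2)  # guest LeNetTop MaxPool2d(2, 2)               -> 4
flat = 16 * size * size      # 16 channels after the FedPass conv layer
print(size, flat)  # 4 256
```

The flattened size of 256 matches nn.Linear(16 * 4 * 4, 120) in LeNetTop.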
Another key point is the use of NoFeatureDataset and NoLabelDataset to wrap the dataset. They reflect the scenario where the guest holds only the labels and the host holds only the features, which is a simple way to simulate the federated setting on a single machine.
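The two wrappers only work because index i refers to the same underlying sample on both sides. The sketch below copies the wrappers from get_setting() without the torch Dataset base class, so it runs without PyTorch, and uses hypothetical strings as stand-ins for MNIST image tensors. It checks that the guest's labels and the host's features can be re-joined position by position into the original samples.

```python
samples = [("img0", 5), ("img1", 0), ("img2", 4)]  # hypothetical (image, label) pairs


class NoFeatureDataset:
    """Guest-side view: returns only the label of each sample."""

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, item):
        return [self.ds[item][1]]


class NoLabelDataset:
    """Host-side view: returns only the features (image) of each sample."""

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, item):
        return [self.ds[item][0]]


guest_ds = NoFeatureDataset(samples)
host_ds = NoLabelDataset(samples)

# Index i on both sides refers to the same sample, so the views can be re-joined.
rejoined = [(host_ds[i][0], guest_ds[i][0]) for i in range(len(guest_ds))]
print(rejoined == samples)  # True
```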
The task can be submitted using the same command as in the previous example:
```bash
python hetero_nn.py --parties guest:9999 host:10000 --log_level INFO
```