[WIP] federated knowledge distillation with heterogeneous model and molecule #440

Open · wants to merge 75 commits into base: master

Commits (75)
f46aedb
merge contrastive baseline with latest
xkxxfyf Aug 1, 2022
eaa50a1
delete other yamls
xkxxfyf Aug 1, 2022
3ac6f85
Merge branch 'alibaba:master' into master
xkxxfyf Aug 1, 2022
289942b
script debug
xkxxfyf Aug 1, 2022
c1c9bea
debug unit test error
xkxxfyf Aug 2, 2022
f303910
debug
xkxxfyf Aug 2, 2022
5f8ef6d
debug
xkxxfyf Aug 2, 2022
e3e2f64
debug
xkxxfyf Aug 3, 2022
6ac15aa
debug
xkxxfyf Aug 3, 2022
774419a
debug
xkxxfyf Aug 4, 2022
a55b844
debug
xkxxfyf Aug 4, 2022
88aa90a
debug
xkxxfyf Aug 5, 2022
c716db6
debug
xkxxfyf Aug 6, 2022
063e88f
debug
xkxxfyf Aug 6, 2022
8058ae6
debug
xkxxfyf Aug 6, 2022
e543b1b
debug
xkxxfyf Aug 8, 2022
2f8b958
Merge branch 'alibaba:master' into master
xkxxfyf Aug 8, 2022
79825ed
Merge branch 'master' of https://github.com/xkxxfyf/FederatedScope
xkxxfyf Aug 8, 2022
e68c480
Merge branch 'master' of https://github.com/xkxxfyf/FederatedScope
xkxxfyf Aug 21, 2022
00fe908
Merge branch 'alibaba:master' into master
xkxxfyf Aug 22, 2022
ee7ce3c
Merge branch 'alibaba:master' into master
xkxxfyf Aug 30, 2022
aad2038
FedGlobalContrast
xkxxfyf Aug 30, 2022
d03eef5
modify and add docstring
xkxxfyf Sep 4, 2022
a43ceef
create repro_exp shell and report exp result
xkxxfyf Sep 5, 2022
6b42eec
Accelerate computing global loss with GPU and repair load model problem
xkxxfyf Sep 14, 2022
747f8e8
add paper list
xkxxfyf Sep 16, 2022
a13560f
Merge branch 'alibaba:master' into paper-list
xkxxfyf Sep 19, 2022
65738c0
Merge branch 'paper-list'
xkxxfyf Sep 19, 2022
241b42e
add global loss grad and computed graph, and keep the same Non-IID di…
xkxxfyf Sep 19, 2022
a9aebf5
debug worker_builder
xkxxfyf Sep 19, 2022
cd171a0
modify the worker for global loss
xkxxfyf Sep 22, 2022
25be0dd
Merge branch 'alibaba:master' into master
xkxxfyf Oct 11, 2022
7e07c80
resolve review
xkxxfyf Oct 11, 2022
fea4627
Merge branch 'master' of https://github.com/xkxxfyf/FederatedScope
xkxxfyf Oct 11, 2022
6bde8df
create unittest of fedsimclr in cifar10
xkxxfyf Oct 12, 2022
7508570
Update test_simclr_cifar10.py
xkxxfyf Oct 12, 2022
ee46558
Update test_simclr_cifar10.py
xkxxfyf Oct 12, 2022
75526f4
Update test_simclr_cifar10.py
xkxxfyf Oct 13, 2022
7934b13
Update test_simclr_cifar10.py
xkxxfyf Oct 13, 2022
7b4a1a6
delete print and repair for unit-test failing
xkxxfyf Oct 13, 2022
3810fa5
re-try for unit-test timeout error after extending waiting time
xkxxfyf Oct 14, 2022
a474986
re-try for unit-test timeout error with adding shared memory
xkxxfyf Oct 14, 2022
1eed27e
delete never used
xkxxfyf Oct 17, 2022
8ee8f7c
Merge branch 'alibaba:master' into paper-list
xkxxfyf Oct 24, 2022
7298ce7
Merge branch 'master' into paper-list
xkxxfyf Oct 24, 2022
0a395c5
Update utils.py
xkxxfyf Oct 24, 2022
4cf9a96
Update SimCLR.py
xkxxfyf Oct 24, 2022
0999288
modify format
xkxxfyf Oct 24, 2022
a18aa3f
modify format
xkxxfyf Oct 25, 2022
34529d4
Merge branch 'master' of https://github.com/xkxxfyf/FederatedScope
xkxxfyf Oct 25, 2022
d213297
modify yapf format
xkxxfyf Oct 25, 2022
d6fadbc
debug for unit-test
xkxxfyf Oct 25, 2022
a250837
modify for unit-test
xkxxfyf Oct 26, 2022
25f2912
Merge branch 'alibaba:master' into master
xkxxfyf Oct 28, 2022
d306dc0
Update test_simclr_cifar10.py
xkxxfyf Oct 28, 2022
99915fa
Update Cifar10.py
xkxxfyf Oct 28, 2022
90bce1c
Merge branch 'master' into latest-branch
xkxxfyf Nov 21, 2022
2c8112f
Create dataloader_molecule.py
xkxxfyf Nov 23, 2022
c0b3605
Merge branch 'alibaba:master' into master
xkxxfyf Nov 23, 2022
1e63ddb
add models for different datasets
xkxxfyf Nov 27, 2022
0012794
Merge branch 'alibaba:master' into master
xkxxfyf Nov 28, 2022
87f5a54
define DimeNet++ for QM7b
xkxxfyf Nov 28, 2022
6032096
Merge branch 'alibaba:master' into master
xkxxfyf Dec 5, 2022
d1404f7
add trainer modelbuilder and generalmodel
xkxxfyf Dec 6, 2022
9cdccda
Merge branch 'alibaba:master' into master
xkxxfyf Dec 16, 2022
4382e07
Merge branch 'alibaba:master' into master
xkxxfyf Dec 20, 2022
20c3d60
Merge branch 'alibaba:master' into master
xkxxfyf Dec 26, 2022
e6cded5
Merge branch 'alibaba:master' into master
xkxxfyf Feb 21, 2023
877e79c
Merge branch 'alibaba:master' into master
xkxxfyf Mar 5, 2023
fd0735a
Merge branch 'alibaba:master' into master
xkxxfyf Mar 13, 2023
fa055b3
Merge branch 'alibaba:master' into master
xkxxfyf Mar 16, 2023
4fb1bc6
Merge branch 'alibaba:master' into master
xkxxfyf Mar 22, 2023
e5661e4
Merge branch 'alibaba:master' into master
xkxxfyf Apr 12, 2023
0a9dbe1
Merge branch 'alibaba:master' into master
xkxxfyf Apr 19, 2023
cc45e8c
Merge branch 'alibaba:master' into master
xkxxfyf May 1, 2023
87 changes: 87 additions & 0 deletions federatedscope/gfkd/dataloader/dataloader_molecule.py
@@ -0,0 +1,87 @@
from torch_geometric.datasets import TUDataset, MoleculeNet, QM7b
from rdkit import Chem
from rdkit.Chem import AllChem

from federatedscope.core.auxiliaries.transform_builder import get_transform


def load_heteromolecule_dataset(config=None):
    r"""Convert each dataset to a DataLoader.

    :returns:
        data_local_dict
    :rtype: Dict {
                  'client_id': {
                      'train': DataLoader(),
                      'val': DataLoader(),
                      'test': DataLoader()
                  }
                 }
    """
    splits = config.data.splits
    path = config.data.root
    name = config.data.type.upper()

    # Transforms
    transforms_funcs = get_transform(config, 'torch_geometric')

    if name.startswith('heterogeneous molecule dataset'.upper()):
        dataset = []

        TUDataset_names = ['BZR', 'ENZYMES', 'MUTAG']
        MoleculeNet_names = ['ESOL', 'FreeSolv', 'BACE']
        for dname in TUDataset_names:
            tmp_dataset = TUDataset(path, dname, **transforms_funcs)
            dataset.append(tmp_dataset)
        for dname in MoleculeNet_names:
            tmp_dataset = MoleculeNet(path, dname, **transforms_funcs)

            if dname in ['FreeSolv', 'BACE']:
                for i in range(len(tmp_dataset)):
                    smiles = tmp_dataset[i].smiles
                    mol = Chem.MolFromSmiles(smiles)
                    mol = AllChem.AddHs(mol)
                    # randomSeed=-1 would generate a random conformer;
                    # a fixed seed keeps the embedding deterministic.
                    res = AllChem.EmbedMolecule(mol, randomSeed=1)
                    if res == 0:
                        try:
                            # some conformers cannot be optimized with MMFF
                            AllChem.MMFFOptimizeMolecule(mol)
                        except Exception:
                            pass
                        mol = AllChem.RemoveHs(mol)
                        coordinates = mol.GetConformer().GetPositions()
                    elif res == -1:
                        # embedding failed: retry with more attempts
                        mol_tmp = Chem.MolFromSmiles(smiles)
                        AllChem.EmbedMolecule(mol_tmp,
                                              maxAttempts=5000,
                                              randomSeed=1)
                        mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)
                        try:
                            # some conformers cannot be optimized with MMFF
                            AllChem.MMFFOptimizeMolecule(mol_tmp)
                        except Exception:
                            pass
                        mol_tmp = AllChem.RemoveHs(mol_tmp)
                        coordinates = mol_tmp.GetConformer().GetPositions()

                    assert tmp_dataset[i].x.shape[0] == len(coordinates), \
                        "coordinates are not aligned with {}".format(smiles)
                    tmp_dataset[i] = [tmp_dataset[i], coordinates]
            dataset.append(tmp_dataset)
        # QM7b is a single dataset and takes no name argument
        dataset.append(QM7b(path, **transforms_funcs))
    else:
        raise ValueError(f'No dataset named: {name}!')

    client_num = min(len(dataset), config.federate.client_num
                     ) if config.federate.client_num > 0 else len(dataset)
    config.merge_from_list(['federate.client_num', client_num])

    # get local dataset: client i (1-indexed) receives dataset i - 1
    data_dict = dict()
    for client_idx in range(1, len(dataset) + 1):
        data_dict[client_idx] = dataset[client_idx - 1]
    return data_dict, config
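The tail of the loader assigns each loaded dataset to exactly one client, with 1-indexed client ids. A dependency-free sketch of that partition rule, using placeholder lists as hypothetical stand-ins for the seven PyG datasets:

```python
# Placeholder datasets standing in for the PyG datasets
# (BZR, ENZYMES, MUTAG, ESOL, FreeSolv, BACE, QM7b).
datasets = [['BZR'], ['ENZYMES'], ['MUTAG'],
            ['ESOL'], ['FreeSolv'], ['BACE'], ['QM7b']]

configured_client_num = 0  # <= 0 means "one client per dataset"

client_num = (min(len(datasets), configured_client_num)
              if configured_client_num > 0 else len(datasets))

# Client ids are 1-indexed; client i receives dataset i - 1.
data_dict = {idx: datasets[idx - 1]
             for idx in range(1, len(datasets) + 1)}

print(client_num)    # 7
print(data_dict[1])  # ['BZR']
print(data_dict[7])  # ['QM7b']
```

When the configured client count is positive but smaller than the number of datasets, `client_num` is clamped down; the loader then writes the resolved count back into the config via `merge_from_list`.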
71 changes: 71 additions & 0 deletions federatedscope/gfkd/model/SMILES_model.py
@@ -0,0 +1,71 @@
import torch


class SMILESTransformer(torch.nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.1):
        super(SMILESTransformer, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.ninp = ninp

        # learned positional embeddings for sequences up to 100 tokens
        self.pos_encoder = torch.nn.Embedding(100, ninp)
        self.encoder = torch.nn.Embedding(ntoken, ninp)

        self.layer_norm = torch.nn.LayerNorm([ninp])
        self.output_layer_norm = torch.nn.LayerNorm([ntoken])
        self.input_layer_norm = torch.nn.LayerNorm([ninp])

        encoder_layers = TransformerEncoderLayer(d_model=ninp,
                                                 nhead=nhead,
                                                 dim_feedforward=nhid,
                                                 dropout=dropout,
                                                 activation='gelu')

        self.transformer_encoder = TransformerEncoder(encoder_layers,
                                                      nlayers,
                                                      norm=self.layer_norm)

        self.dropout = torch.nn.Dropout(dropout)

        self.decoder = torch.nn.Linear(ninp, ntoken, bias=False)
        self.decoder_bias = torch.nn.Parameter(torch.zeros(ntoken))
        self.init_weights()

    def init_weights(self):
        self.encoder.weight.data.normal_(mean=0.0, std=1.0)
        self.decoder.weight.data.normal_(mean=0.0, std=1.0)
        self.decoder_bias.data.zero_()

        self.input_layer_norm.weight.data.fill_(1.0)
        self.input_layer_norm.bias.data.zero_()
        self.output_layer_norm.weight.data.fill_(1.0)
        self.output_layer_norm.bias.data.zero_()
        self.layer_norm.weight.data.fill_(1.0)
        self.layer_norm.bias.data.zero_()

    def forward(self, src, latent_out=False):
        # position indices follow the actual sequence length of src
        pos = torch.arange(src.size(1), device=src.device)

        mol_token_emb = self.encoder(src)
        pos_emb = self.pos_encoder(pos)
        input_emb = pos_emb + mol_token_emb
        input_emb = self.input_layer_norm(input_emb)
        input_emb = self.dropout(input_emb)
        input_emb = input_emb.transpose(0, 1)

        # True marks a padding position (token id 1) to be ignored
        attention_mask = torch.ones_like(src)
        attention_mask = attention_mask.masked_fill(src != 1, 0)
        attention_mask = attention_mask.bool()

        # pass the mask so attention actually skips padded positions
        output = self.transformer_encoder(
            input_emb, src_key_padding_mask=attention_mask)

        if latent_out:
            return output
        output = self.decoder(output) + self.decoder_bias

        return output
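`forward()` derives a key-padding mask from the token ids: a position is masked exactly when it holds the pad token. A dependency-free sketch of that rule (assuming, as the model's `src != 1` check does, that token id 1 is the pad token):

```python
PAD_ID = 1  # assumed pad-token id, matching the model's `src != 1` check


def key_padding_mask(batch):
    """True marks a padded position that attention should ignore."""
    return [[tok == PAD_ID for tok in seq] for seq in batch]


src = [[5, 7, 9, 1, 1],
       [3, 4, 1, 1, 1]]
mask = key_padding_mask(src)
print(mask[0])  # [False, False, False, True, True]
```

This matches the semantics of PyTorch's `src_key_padding_mask`, where `True` entries are excluded from attention.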