forked from MinkaiXu/GeoDiff
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
158 lines (136 loc) · 5.91 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import argparse
import pickle
import yaml
import torch
from glob import glob
from tqdm.auto import tqdm
from easydict import EasyDict
from models.epsnet import *
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
def num_confs(num: str):
    """Parse the ``--num_confs`` CLI value into a sample-count function.

    The returned callable maps the number of reference conformers of a
    molecule to the number of conformers to generate for it.

    Args:
        num: Either ``"<k>x"`` (generate ``k`` times the number of reference
            conformers) or ``"<k>"`` (always generate exactly ``k``
            conformers). ``k`` must be a positive integer.

    Returns:
        A function ``f(num_refs) -> int``.

    Raises:
        ValueError: If ``num`` is not of one of the forms above, or ``k`` is
            not positive. argparse reports this as an invalid option value.
    """
    is_multiplier = num.endswith('x')
    text = num[:-1] if is_multiplier else num
    try:
        k = int(text)
    except ValueError:
        raise ValueError('invalid num_confs %r: expected "<k>x" or "<k>"' % num) from None
    # Zero or negative counts (previously accepted for the "<k>x" form) would
    # request no / a negative number of samples downstream.
    if k <= 0:
        raise ValueError('num_confs must be positive, got %r' % num)
    if is_multiplier:
        return lambda x: x * k
    return lambda x: k
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('ckpt', type=str, help='path for loading the checkpoint')
    parser.add_argument('--save_traj', action='store_true', default=False,
                        help='whether store the whole trajectory for sampling')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--tag', type=str, default='')
    parser.add_argument('--num_confs', type=num_confs, default=num_confs('2x'))
    parser.add_argument('--test_set', type=str, default=None)
    parser.add_argument('--start_idx', type=int, default=800)
    parser.add_argument('--end_idx', type=int, default=1000)
    parser.add_argument('--out_dir', type=str, default=None,
                        help='root directory for the sampling log dir (default: the checkpoint\'s log dir)')
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--clip', type=float, default=1000.0)
    parser.add_argument('--n_steps', type=int, default=5000,
                        help='sampling num steps; for DSM framework, this means num steps for each noise scale')
    parser.add_argument('--global_start_sigma', type=float, default=0.5,
                        help='enable global gradients only when noise is low')
    parser.add_argument('--w_global', type=float, default=1.0,
                        help='weight for global gradients')
    parser.add_argument('--step_lr', type=float, default=1e-6,
                        help='step size passed to the sampler')
    # Parameters for DDPM
    parser.add_argument('--sampling_type', type=str, default='ld',
                        help='generalized, ddpm_noisy, ld: sampling method for DDIM, DDPM or Langevin Dynamics')
    parser.add_argument('--eta', type=float, default=1.0,
                        help='weight for DDIM and DDPM: 0->DDIM, 1->DDPM')
    args = parser.parse_args()

    # Load checkpoint. map_location lets a checkpoint saved on one device
    # (e.g. a GPU) be loaded on whichever --device was requested.
    # NOTE(review): the checkpoint also holds a pickled config object
    # (ckpt['config']); torch versions whose torch.load defaults to
    # weights_only=True may need weights_only=False here.
    ckpt = torch.load(args.ckpt, map_location=args.device)

    # The training config (*.yml) is looked up two directories above the
    # checkpoint file; sorted() makes the choice deterministic if several exist.
    log_dir = os.path.dirname(os.path.dirname(args.ckpt))
    config_paths = sorted(glob(os.path.join(log_dir, '*.yml')))
    if not config_paths:
        raise FileNotFoundError('No *.yml config found in %s' % log_dir)
    with open(config_paths[0], 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    seed_all(config.train.seed)

    # Logging. --out_dir, when given, replaces the checkpoint's log dir as the
    # root under which the new sampling directory is created.
    output_root = args.out_dir if args.out_dir is not None else log_dir
    output_dir = get_new_log_dir(output_root, 'sample', tag=args.tag)
    logger = get_logger('test', output_dir)
    logger.info(args)

    # Datasets and loaders
    logger.info('Loading datasets...')
    transforms = Compose([
        CountNodesPerGraph(),
        AddHigherOrderEdges(order=config.model.edge_order),  # Offline edge augmentation
    ])
    test_path = config.dataset.test if args.test_set is None else args.test_set
    test_set = PackedConformationDataset(test_path, transform=transforms)

    # Model
    logger.info('Loading model...')
    model = get_model(ckpt['config'].model).to(args.device)
    model.load_state_dict(ckpt['model'])

    # Keep molecules with start_idx <= index < end_idx. Stop at end_idx so the
    # (unused) tail of the dataset is not loaded and transformed.
    test_set_selected = []
    for i, data in enumerate(test_set):
        if i >= args.end_idx:
            break
        if i >= args.start_idx:
            test_set_selected.append(data)

    done_smiles = set()
    results = []
    if args.resume is not None:
        # NOTE: pickle.load can execute arbitrary code; only resume from
        # result files you produced yourself.
        with open(args.resume, 'rb') as f:
            results = pickle.load(f)
        done_smiles.update(data.smiles for data in results)

    max_attempts = 2  # One normal attempt plus one retry with local clipping.
    for i, data in enumerate(tqdm(test_set_selected)):
        if data.smiles in done_smiles:
            logger.info('Molecule#%d is already done.' % i)
            continue

        # The code assumes pos_ref stacks all reference conformers along
        # dim 0, so rows / atoms = number of reference conformers.
        num_refs = data.pos_ref.size(0) // data.num_nodes
        num_samples = args.num_confs(num_refs)

        data_input = data.clone()
        data_input['pos_ref'] = None
        batch = repeat_data(data_input, num_samples).to(args.device)

        clip_local = None
        for attempt in range(max_attempts):
            try:
                # Noise is drawn on the CPU generator, then moved, so results
                # for a given seed do not depend on the target device.
                pos_init = torch.randn(batch.num_nodes, 3).to(args.device)
                pos_gen, pos_gen_traj = model.langevin_dynamics_sample(
                    atom_type=batch.atom_type,
                    pos_init=pos_init,
                    bond_index=batch.edge_index,
                    bond_type=batch.edge_type,
                    batch=batch.batch,
                    num_graphs=batch.num_graphs,
                    extend_order=False,  # Done in transforms.
                    n_steps=args.n_steps,
                    step_lr=args.step_lr,
                    w_global=args.w_global,
                    global_start_sigma=args.global_start_sigma,
                    clip=args.clip,
                    clip_local=clip_local,
                    sampling_type=args.sampling_type,
                    eta=args.eta
                )
            except FloatingPointError:
                if attempt == max_attempts - 1:
                    # Out of retries: the molecule is skipped; say so explicitly.
                    logger.error('Molecule#%d failed after %d attempts; skipped.' % (i, max_attempts))
                else:
                    clip_local = 20
                    logger.warning('Retrying with local clipping.')
                continue

            pos_gen = pos_gen.cpu()
            if args.save_traj:
                data.pos_gen = torch.stack(pos_gen_traj)
            else:
                data.pos_gen = pos_gen
            results.append(data)
            done_smiles.add(data.smiles)

            # Checkpoint everything produced so far (usable with --resume).
            save_path = os.path.join(output_dir, 'samples_%d.pkl' % i)
            logger.info('Saving samples to: %s' % save_path)
            with open(save_path, 'wb') as f:
                pickle.dump(results, f)
            break  # No errors occurred, break the retry loop

    save_path = os.path.join(output_dir, 'samples_all.pkl')
    logger.info('Saving samples to: %s' % save_path)

    # Order results like the selected test molecules. The SMILES -> index map
    # is built once (first occurrence wins, matching a front-to-back scan);
    # SMILES not in the selection (e.g. from a resumed file) get -1 and sort first.
    mol_index = {}
    for idx, d in enumerate(test_set_selected):
        mol_index.setdefault(d.smiles, idx)
    results.sort(key=lambda d: mol_index.get(d.smiles, -1))
    with open(save_path, 'wb') as f:
        pickle.dump(results, f)