-
Notifications
You must be signed in to change notification settings - Fork 3
/
generate_dataset.py
46 lines (39 loc) · 1.52 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import argparse
from pathlib import Path
import numpy as np
def read_xyz(file):
    """Parse every frame of a multi-frame XYZ stream into one float array.

    Each frame is read as: one line holding the atom count, one comment
    line (ignored), then that many atom lines. On each atom line the first
    whitespace-separated field is skipped and the next three fields are
    parsed as the x, y, z coordinates; any further fields are ignored.
    Blank lines between frames and after the last frame are skipped.

    Args:
        file: An open file-like object exposing ``readline()``.

    Returns:
        A ``numpy.ndarray`` of dtype float with shape
        ``(n_frames, n_atoms, 3)``. Empty input yields an empty array of
        shape ``(0,)``. The shape is also printed to stdout.

    Raises:
        ValueError: If an atom-count line is not an integer, if an atom
            line does not carry three numeric coordinates (including a
            frame cut short by end of file), or if frames have differing
            atom counts so NumPy cannot build a rectangular array.
    """
    frames = []
    while True:
        header = file.readline()
        if not header:
            # readline() returns an empty string only at end of file.
            break
        if not header.strip():
            # Tolerate blank separator / trailing lines.
            continue
        try:
            n_atom = int(header)
        except ValueError as err:
            # Fail loudly instead of silently dropping the rest of the file.
            raise ValueError(
                f'frame {len(frames)}: expected an atom count, '
                f'got {header!r}'
            ) from err

        file.readline()  # comment line, contents are not used

        coords = []
        for atom_idx in range(n_atom):
            line = file.readline()
            fields = line.split()[1:4]
            if len(fields) != 3:
                raise ValueError(
                    f'frame {len(frames)}, atom {atom_idx}: expected a '
                    f'label followed by 3 coordinates, got {line!r}'
                )
            try:
                coords.append([float(x) for x in fields])
            except ValueError as err:
                raise ValueError(
                    f'frame {len(frames)}, atom {atom_idx}: non-numeric '
                    f'coordinate in {line!r}'
                ) from err
        frames.append(coords)

    all_coords = np.array(frames, dtype=float)
    print(all_coords.shape)
    return all_coords
if __name__ == '__main__':
    # Convert the raw per-molecule text dumps (densities.txt and
    # structures.xyz) into .npy files, one train and one test directory per
    # molecule.
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./data',
                        help='directory containing the <mol>_300K[-test] '
                             'input folders')
    parser.add_argument('--out', type=str, default='./data',
                        help='directory under which <mol>/<mol>_train and '
                             '<mol>/<mol>_test are written')
    args = parser.parse_args()
    root = Path(args.root)
    out = Path(args.out)

    for mol in ['ethane', 'malonaldehyde']:
        # (suffix of the input folder, name of the output split)
        for src_suffix, split in (('', 'train'), ('-test', 'test')):
            src_dir = root / f'{mol}_300K{src_suffix}'
            # Outputs go under --out; previously that option was parsed but
            # ignored and everything was written under --root.
            dst_dir = out / mol / f'{mol}_{split}'

            den = np.loadtxt(src_dir / 'densities.txt')
            os.makedirs(dst_dir, exist_ok=True)
            np.save(dst_dir / 'dft_densities.npy', den)

            with open(src_dir / 'structures.xyz') as f:
                np.save(dst_dir / 'structures.npy', read_xyz(f))

    print('Done')