-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
38 lines (31 loc) · 1.05 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
class EnantiomerDataset(Dataset):
    """Map-style dataset of coordinates stored in ``{root}/{split}.npz``.

    The ``.npz`` archive must contain three arrays:

    * ``z`` -- atom-type indices, converted once to a ``long`` tensor and
      attached unchanged to every sample (presumably all samples share the
      same atoms -- TODO confirm against the data-generation script).
    * ``pos`` -- per-sample coordinates, indexed along axis 0. Each
      ``pos[i]`` must have shape ``(3,)`` or ``(num_points, 3)`` so that
      ``scipy.spatial.transform.Rotation.apply`` accepts it.
    * ``label`` -- per-sample targets, converted to ``float32``; its length
      defines the dataset length.

    Args:
        root: Directory containing the ``.npz`` files.
        split: File stem to load, e.g. ``'train'`` or ``'test'``.
        rotate: If True, every ``__getitem__`` call applies a freshly sampled
            random 3-D rotation to that sample's coordinates.
    """

    def __init__(self, root, split, rotate=True):
        self.root = root
        self.split = split
        self.rotate = rotate
        # The NpzFile returned by np.load keeps an open file handle; the
        # context manager closes it deterministically instead of relying on
        # garbage collection. Indexing the NpzFile reads each array fully
        # into memory, so the attributes stay valid after the file closes.
        with np.load(f'{root}/{split}.npz') as data:
            self.z = torch.as_tensor(data['z'], dtype=torch.long)
            self.pos = data['pos']
            self.label = torch.as_tensor(data['label'], dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (the length of ``label``)."""
        return len(self.label)

    def __getitem__(self, item):
        """Return sample ``item`` as a ``Data(z=..., pos=..., y=...)`` object.

        ``pos`` is a new ``float32`` tensor (randomly rotated when
        ``self.rotate`` is set); ``z`` is the shared atom-type tensor and
        ``y`` is ``label[item]``.
        """
        pos = self.pos[item]
        if self.rotate:
            # Augmentation: a new random rotation is drawn on every access.
            pos = R.random().apply(pos)
        # torch.tensor always copies, so the returned sample never aliases
        # self.pos (the legacy torch.FloatTensor constructor may share memory
        # with a float32 NumPy array when no rotation was applied).
        return Data(
            z=self.z,
            pos=torch.tensor(pos, dtype=torch.float32),
            y=self.label[item],
        )
def get_dataset(cfg):
    """Build the ``(train, test)`` pair of ``EnantiomerDataset`` from a config.

    Top-level keys of ``cfg`` are keyword arguments shared by both splits.
    The optional ``'train'`` and ``'test'`` sub-mappings hold split-specific
    keyword arguments. When a key appears both at the top level and in a
    split section, the split-specific value takes precedence.

    Args:
        cfg: Mapping of ``EnantiomerDataset`` keyword arguments (e.g.
            ``root``, ``rotate``) plus optional ``'train'``/``'test'``
            sections. It is not modified.

    Returns:
        A tuple ``(train_dataset, test_dataset)``.
    """
    shared = dict(cfg)
    # ``or {}`` also covers a section that is present but empty/None
    # (e.g. a bare ``train:`` entry in a YAML config).
    train_overrides = shared.pop('train', None) or {}
    test_overrides = shared.pop('test', None) or {}
    # Merge into one dict instead of passing ``**a, **b`` to the call: a key
    # present in both would otherwise raise "got multiple values for keyword
    # argument" rather than letting the split section override it.
    return (
        EnantiomerDataset(split='train', **{**shared, **train_overrides}),
        EnantiomerDataset(split='test', **{**shared, **test_overrides}),
    )