-
Notifications
You must be signed in to change notification settings - Fork 0
/
gru_30K_binary_0.8.py
58 lines (53 loc) · 1.88 KB
/
gru_30K_binary_0.8.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
'''
gru: GRU with original data
data: 30K
objective: binary
include mid point: True
'''
import torch
from xrd_analyzer.data.data_loader_30K import get_data_loader
from xrd_analyzer.models.gru_classification import GRU
from xrd_analyzer.training.train import Trainer
from pathlib import Path
import json
identifier = 'gru_30K_binary_0.8train'
objective = 'binary'

# Per-run output directory next to this script. parents=True also creates the
# shared "outputs" directory on a fresh checkout (a bare mkdir() raises
# FileNotFoundError there), and exist_ok=True replaces the separate exists()
# check, so re-running the script is safe.
save_path = Path(__file__).resolve().parent / "outputs" / identifier
save_path.mkdir(parents=True, exist_ok=True)

# Run configuration, one section per consumer: 'dataloader' is splatted into
# get_data_loader, 'model' into GRU and 'train' into Trainer (as **kwargs).
arg_dict = {
    'dataloader': {
        # Split fractions; presumably train / val / test, matching the
        # '0.8train' identifier -- confirm against get_data_loader.
        'data_ratio': [0.80, 0.10, 0.10],
        'batch_size': 256,
        'objective': objective,
        'include_mid_point': True,
        'save_path': save_path,
        'random_state': 42,
        # NOTE(review): 42 worker processes is unusually high and mirrors
        # random_state -- confirm it is intended for the target machine.
        'num_workers': 42,
    },
    'model': {
        'input_size': 250,
        'hidden_size': 128,
        'num_layers': 2,
        'objective': objective,
    },
    'train': {
        'objective': objective,
        'save_model': True,
        'save_path': save_path,
        'model_id': identifier,
    },
}
# Persist the run configuration before any data loading or training, so a
# crashed or interrupted run still records its settings. Path values are not
# JSON-serializable, so they are written as strings into a copy; arg_dict
# itself keeps the Path objects that the components below receive.
serializable_args = {
    section: {key: str(value) if isinstance(value, Path) else value
              for key, value in params.items()}
    for section, params in arg_dict.items()
}
with open(save_path / "args_dict.json", 'w', encoding='utf-8') as f:
    json.dump(serializable_args, f, indent=2)

# dataloader
train_data_loader, val_data_loader, test_data_loader = get_data_loader(
    **arg_dict['dataloader'])

# model + loss + optimizer
model = GRU(**arg_dict['model'])
# NOTE(review): CrossEntropyLoss expects one logit per class, so this assumes
# GRU emits two outputs for the 'binary' objective -- confirm.
loss = torch.nn.CrossEntropyLoss()
# Alternative optimizer (disabled): torch.optim.Adam(model.parameters(), lr=0.001)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# training: first CUDA device when available, otherwise CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
trainer = Trainer(model, optimizer, loss, device, **arg_dict['train'])
trainer.train(train_data_loader, val_data_loader,
              test_data_loader, epochs=50)