-
Notifications
You must be signed in to change notification settings - Fork 0
/
cnn_30K_binary_2_temp_scaling.py
54 lines (49 loc) · 1.78 KB
/
cnn_30K_binary_2_temp_scaling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
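'''
Evaluate a trained 1D CNN classifier with temperature scaling.

Loads the saved checkpoint for run `cnn_30K_binary_2`, fits a softmax
temperature on the validation split, and prints the test loss and accuracy.

cnn_org: 1D CNN with original data
data: 30K
objective: binary
include mid point: True
temp scaling: True
'''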
import sys
sys.path.append('/home/osman/src/temperature_scaling')
import torch
from xrd_analyzer.data.data_loader_30K import get_data_loader
from xrd_analyzer.models.cnn_classification import CNN
from xrd_analyzer.utils.utils import validate_model
from pathlib import Path
from temperature_scaling import ModelWithTemperature
# Run identity: names the output directory and the checkpoint files.
identifier = 'cnn_30K_binary_2'
objective = 'binary'
# All artifacts for this run live in <directory of this script>/outputs/<identifier>.
save_path = Path(__file__).resolve().parent.joinpath("outputs", identifier)

# Per-stage keyword arguments, grouped by the component that consumes them.
arg_dict = dict(
    # Forwarded to get_data_loader().
    dataloader=dict(
        data_ratio=[0.80, 0.10, 0.10],  # presumably train/val/test fractions — confirm in get_data_loader
        batch_size=256,
        objective=objective,
        include_mid_point=True,
        save_path=save_path,
        random_state=25,
        num_workers=42,
    ),
    # Forwarded to the CNN constructor.
    model=dict(objective=objective),
    # Training options; not read anywhere in this evaluation script.
    train=dict(
        objective=objective,
        save_model=True,
        save_path=save_path,
        model_id=identifier,
    ),
)
if __name__ == '__main__':
    # Build the train/val/test loaders. The main-module guard matters because
    # `num_workers` is presumably forwarded to torch DataLoader: under the
    # "spawn" start method, worker processes re-import this module and would
    # otherwise re-run the whole script.
    train_data_loader, val_data_loader, test_data_loader = get_data_loader(
        **arg_dict['dataloader'])

    # model + loss
    model = CNN(**arg_dict['model'])
    loss = torch.nn.CrossEntropyLoss()

    # Pick the device before loading so the checkpoint tensors can be mapped
    # onto it; without map_location a GPU-saved checkpoint fails to load on a
    # CPU-only machine.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load the checkpoint indexed `epochs - 1` (presumably the final epoch of a
    # 100-epoch training run) from this run's output directory.
    epochs = 100
    checkpoint_path = save_path / f'model_{identifier}_{epochs - 1}.pth'
    state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']
    model.load_state_dict(state_dict)
    model = model.to(device)
    # nn.Module starts in training mode; switch to eval so dropout/batch-norm
    # layers (if the CNN has any) behave as at inference time while the
    # temperature is fitted and the test set is scored.
    model.eval()

    # Tune the temperature on the validation split, then evaluate on test.
    model = ModelWithTemperature(model)
    model.set_temperature(val_data_loader)
    test_loss, test_acc = validate_model(model, test_data_loader, device, loss, objective)
    print(test_loss, test_acc)