-
Notifications
You must be signed in to change notification settings - Fork 5
/
config.py
100 lines (86 loc) · 2.79 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Experiment configuration file
Extended from config file from original PANet Repository
"""
import glob
import itertools
import os
import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
from utils import *
# Global sacred settings: hand captured functions a mutable config (sacred
# makes it read-only by default) and disable stdout/stderr capturing.
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'

ex = Experiment("QNet")
# Interpret '\b' and '\r' in any captured output (e.g. progress bars) before
# it is stored with the run.
ex.captured_out_filter = apply_backspaces_and_linefeeds

###### Set up source folder ######
# Register every top-level *.py file of these folders as a source file of the
# experiment. Folders are globbed relative to the current working directory,
# so the set of files depends on where the script is launched from.
source_folders = ['.', './dataloaders', './models', './utils']
sources_to_save = [
    py_file
    for folder in source_folders
    for py_file in glob.glob(f'{folder}/*.py')
]
for source_file in sources_to_save:
    ex.add_source_file(source_file)
@ex.config
def cfg():
    """Default experiment configuration.

    sacred turns every local variable assigned in this function into a
    config entry, so the variable names below are the config keys (and can
    be overridden from the command line with ``with key=value``). Do not add
    helper variables here: they would become config entries too.
    """
    seed = 2021
    gpu_id = 0
    num_workers = 0  # 0 for debugging.
    mode = 'train'

    ## dataset
    # 'CHAOST2' = abdominal MRI, 'CMR' = cardiac MRI ('SABS' also has a data dir in `path`).
    dataset = 'CHAOST2'
    # None: do not exclude any test label.
    # In 'CHAOST2': 1 = Liver, 2 = RK, 3 = LK, 4 = Spleen.
    exclude_label = None
    # 1000 for 'CMR', 5000 for any other dataset (presumably the supervoxel count; confirm).
    n_sv = 1000 if dataset == 'CMR' else 5000
    min_size = 200
    max_slices = 3
    use_gt = False  # True: train on ground-truth labels; False: train on supervoxel labels.
    eval_fold = 0  # fold index 0-4 for 5-fold cross-validation.
    test_label = [1, 4]  # labels used for evaluation.
    supp_idx = 0  # case used as the evaluation support set: 0-4 for 'CHAOST2', 0-7 for 'CMR'.
    n_part = 3  # number of chunks used during evaluation.

    ## training
    n_steps = 1000
    batch_size = 1
    n_shot = 1
    n_way = 1
    n_query = 1
    lr_step_gamma = 0.95
    bg_wt = 0.1
    t_loss_scaler = 0.0
    ignore_label = 255
    print_interval = 100
    save_snapshot_every = 1000
    max_iters_per_load = 1000  # epoch size, i.e. the interval between dataset reloads.
    alpha = 0.9  # dual-scale

    # Network
    # Snapshot to reload, e.g. '<log_dir>/<exp_name>/1/snapshots/1000.pth';
    # None presumably means no checkpoint is loaded.
    reload_model_path = None

    # Prototype Refinement
    n_iters = 7

    optim_type = 'sgd'
    optim = dict(lr=1e-3, momentum=0.9, weight_decay=0.0005)

    # Experiment tag '<mode>_<dataset>_cv<eval_fold>'; the observer hook uses it
    # to name the run's log directory.
    exp_str = '_'.join([mode, dataset, f'cv{eval_fold}'])

    path = {
        'log_dir': './runs',
        'CHAOST2': {'data_dir': './data/CHAOST2'},
        'SABS': {'data_dir': './data/SABS'},
        'CMR': {'data_dir': './data/CMR'},
    }
@ex.config_hook
def add_observer(config, command_name, logger):
    """Config hook that attaches a FileStorageObserver to the experiment.

    The observer stores runs under
    ``os.path.join(config['path']['log_dir'], f'{ex.path}_{config["exp_str"]}')``.
    ``command_name`` and ``logger`` belong to sacred's config-hook signature
    and are not used here.

    Returns the received ``config`` unchanged; sacred treats a hook's return
    value as config updates, so this merges back the same values.
    """
    run_name = f'{ex.path}_{config["exp_str"]}'
    run_dir = os.path.join(config['path']['log_dir'], run_name)
    # NOTE(review): the observer is appended unconditionally, so if the hook
    # fires more than once in one process, duplicate observers accumulate.
    ex.observers.append(FileStorageObserver.create(run_dir))
    return config