-
Notifications
You must be signed in to change notification settings - Fork 44
/
config.py
60 lines (49 loc) · 2.85 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import configparser
class Config:
    """Typed view of the settings stored in an INI configuration file.

    The file is parsed with ``configparser`` and every option is exposed as
    an instance attribute converted to ``int``, ``float``, ``bool`` or left
    as ``str``.  Options are read from the ``experiment``, ``training``,
    ``network``, ``validation`` and ``test`` sections; all of them are
    required.
    """

    def __init__(self, config_path):
        """Parse *config_path* and populate the configuration attributes.

        Args:
            config_path: Location of the INI file, handed to
                ``ConfigParser.read``.

        Raises:
            configparser.NoSectionError: A required section is absent.
                ``ConfigParser.read`` skips files it cannot open, so a
                missing or unreadable file is reported this way as well.
            configparser.NoOptionError: A required option is absent.
            ValueError: A numeric option cannot be converted.
        """
        parser = configparser.ConfigParser()
        parser.read(config_path)

        # experiment
        self.seed = parser.getint('experiment', 'seed')

        # training
        self.dataset_path = parser.get('training', 'dataset_path')
        self.save_dir = parser.get('training', 'save_dir')
        self.stage = parser.getint('training', 'stage')
        # The configured log_dir is only a name; the attribute holds the
        # stage-prefixed directory placed under save_dir.
        log_name = parser.get('training', 'log_dir')
        self.log_dir = os.path.join(
            self.save_dir, f'stage{self.stage}_{log_name}')
        self.nThreads = parser.getint("training", "nThreads")
        self.num_epochs = parser.getint("training", "num_epochs")
        self.lr = parser.getfloat("training", "lr")
        self.batch_size = parser.getint('training', 'batch_size')
        self.patch_size = parser.getint('training', 'patch_size')
        self.finetuning = self._read_flag(parser, 'training', 'finetuning')
        self.save_train_img = self._read_flag(
            parser, 'training', 'save_train_img')
        self.scale = parser.getint('training', 'scale')
        self.num_seq = parser.getint('training', 'num_seq')
        self.lr_warping_loss_weight = parser.getfloat(
            "training", "lr_warping_loss_weight")
        self.hr_warping_loss_weight = parser.getfloat(
            "training", "hr_warping_loss_weight")
        self.flow_loss_weight = parser.getfloat(
            "training", "flow_loss_weight")
        self.D_TA_loss_weight = parser.getfloat(
            "training", "D_TA_loss_weight")
        self.R_TA_loss_weight = parser.getfloat(
            "training", "R_TA_loss_weight")
        self.Net_D_weight = parser.getfloat("training", "Net_D_weight")
        # Kept as the raw string from the file (no numeric conversion).
        self.gpu = parser.get("training", "gpu")

        # Network
        self.in_channels = parser.getint('network', 'in_channels')
        self.dim = parser.getint('network', 'dim')
        self.ds_kernel_size = parser.getint('network', 'ds_kernel_size')
        self.us_kernel_size = parser.getint('network', 'us_kernel_size')
        self.num_RDB = parser.getint('network', 'num_RDB')
        self.growth_rate = parser.getint('network', 'growth_rate')
        self.num_dense_layer = parser.getint('network', 'num_dense_layer')
        self.num_flow = parser.getint('network', 'num_flow')
        self.num_FRMA = parser.getint('network', 'num_FRMA')
        self.num_transformer_block = parser.getint(
            'network', 'num_transformer_block')
        self.num_heads = parser.getint('network', 'num_heads')
        self.LayerNorm_type = parser.get('network', 'LayerNorm_type')
        self.ffn_expansion_factor = parser.getfloat(
            'network', 'ffn_expansion_factor')
        self.bias = self._read_flag(parser, 'network', 'bias')

        # validation
        self.val_period = parser.getint('validation', 'val_period')

        # test
        self.custom_path = parser.get('test', 'custom_path')

    @staticmethod
    def _read_flag(parser, section, option):
        """Return the boolean stored under *option* in *section*.

        The value is matched case-insensitively against the spellings
        ``configparser`` itself recognises (``true``/``yes``/``on``/``1``
        and their negative counterparts), so ``true`` is no longer silently
        read as ``False``.  Unrecognised text yields ``False`` rather than
        raising, so values that already evaluated to ``False`` keep doing so.
        """
        raw = parser.get(section, option)
        return parser.BOOLEAN_STATES.get(raw.lower(), False)