bugfixes, add dtype support
luigifvr committed Mar 29, 2023
1 parent 8445854 commit e539685
Showing 4 changed files with 26 additions and 11 deletions.
24 changes: 15 additions & 9 deletions params/calo_inn.yaml
@@ -1,36 +1,42 @@
run_name: calo_inn
dtype: float64


#Dataset
loader_module: calo_inn
loader_params:
geant_file: /remote/gpu06/favaro/calo_inn/datasets/cls_data/train_cls_piplus.hdf5
generated_file: /remote/gpu06/favaro/calo_inn/datasets/train_piplus.hdf5
geant_file: /remote/gpu06/favaro/discriminator-metric/data/calo_cls_geant/full_cls_eplus.hdf5
generated_file: /remote/gpu06/favaro/discriminator-metric/data/calo_bay_samples/samples_eplus.hdf5
add_log_energy: True
add_log_layer_ens: True
add_logit_step: False
add_cut: 0.0
train_split: 0.6
test_split: 0.2

# Model
activation: leaky_relu
negative_slope: 0.1
dropout: 0.1
layers: 5
hidden_size: 256
negative_slope: 0.2
dropout: 0.0
layers: 2
hidden_size: 512

# Training
bayesian: False
prior_prec: 0.01
std_init: -9.0

lr: 1.e-3
betas: [0.9, 0.99]
weight_decay: 0.0
epochs: 50
batch_size: 1024
epochs: 150
batch_size: 512
lr_scheduler: reduce_on_plateau
lr_decay_factor: 0.1
lr_patience: 5
checkpoint_interval: 5

# Evaluation
#bayesian_samples: 2
bayesian_samples: 5
#lower_cluster_thresholds: [0.01, 0.1]
#upper_cluster_thresholds: [0.9, 0.99]
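A run config like this is typically loaded with PyYAML before `src/__main__.py` reads the new `dtype` key; this is only a sketch under that assumption, not code from the repository:

```python
import yaml

# Load the run configuration; path taken from the diff above.
with open("params/calo_inn.yaml") as f:
    params = yaml.safe_load(f)

print(params["run_name"])              # calo_inn
print(params.get("dtype", "float32"))  # float64, consumed in src/__main__.py below
```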
9 changes: 9 additions & 0 deletions src/__main__.py
@@ -34,6 +34,15 @@ def main():
print("Using device " + ("GPU" if use_cuda else "CPU"))
device = torch.device("cuda:0" if use_cuda else "cpu")

dtype = params.get('dtype', 'float32')
if dtype=='float64':
torch.set_default_dtype(torch.float64)
elif dtype=='float16':
torch.set_default_dtype(torch.float16)
elif dtype=='float32':
torch.set_default_dtype(torch.float32)
print("Using dtype {}".format(dtype))

print("Loading data")
loader = import_module(f"src.loaders.{params['loader_module']}")
datasets = loader.load(params["loader_params"])
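The dtype selection added above can also be written as a lookup table. A minimal sketch of an equivalent pattern (the commit itself uses the explicit if/elif chain), assuming the same `params` dict:

```python
import torch

# Map the config string to a torch dtype, falling back to float32.
DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
}
dtype = params.get("dtype", "float32")
torch.set_default_dtype(DTYPES.get(dtype, torch.float32))
print(f"Using dtype {dtype}")
```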
2 changes: 1 addition & 1 deletion src/loaders/calo_inn.py
@@ -86,7 +86,7 @@ def create_data(data_path, dataset_list, **kwargs):

if kwargs['add_log_energy']:
data = np.concatenate((data, np.log10(en_test*10).reshape(-1, 1)), axis=1)
data = np.nan_to_num(data, posinf=0, neginf=0)
#data = np.nan_to_num(data, posinf=0, neginf=0)

en0_t = np.log10(en0_t + 1e-8) + 2.
en1_t = np.log10(en1_t + 1e-8) +2.
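The line commented out above previously replaced the non-finite values that the `log10` transform can produce (e.g. `log10(0)` is `-inf`). A standalone illustration of what it did, using toy energies rather than real calorimeter data:

```python
import numpy as np

en_test = np.array([0.0, 1.0, 10.0])
logged = np.log10(en_test * 10)     # log10(0) gives -inf (with a RuntimeWarning)
cleaned = np.nan_to_num(logged, posinf=0, neginf=0)

print(logged)   # [-inf   1.   2.]
print(cleaned)  # [ 0.  1.  2.]
```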
2 changes: 1 addition & 1 deletion src/train.py
@@ -64,7 +64,7 @@ def init_data_loaders(self):
"""
make_loader = lambda data, mode: torch.utils.data.DataLoader(
dataset = torch.utils.data.TensorDataset(
torch.tensor(data, device=self.device)
torch.tensor(data, device=self.device, dtype=torch.get_default_dtype())
),
batch_size = self.params["batch_size"],
shuffle = mode == "train",
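Why the explicit `dtype=` matters: `torch.tensor` infers `float64` from a `float64` NumPy array regardless of the default dtype, which would clash with `float32` model weights whenever `dtype: float32` is configured. A standalone sketch, not taken from the repository:

```python
import numpy as np
import torch

torch.set_default_dtype(torch.float32)
data = np.zeros((2, 3), dtype=np.float64)

print(torch.tensor(data).dtype)                                   # torch.float64 (inferred from NumPy)
print(torch.tensor(data, dtype=torch.get_default_dtype()).dtype)  # torch.float32 (matches the model)
```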
