From e4b50b62e05276123876169b8f2e83ed34fcf89c Mon Sep 17 00:00:00 2001 From: Cedric Ewen Date: Sun, 3 Sep 2023 22:26:43 +0200 Subject: [PATCH] fix indices --- src/data/lhco_datamodule.py | 12 ++++++------ src/utils/data_generation.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/data/lhco_datamodule.py b/src/data/lhco_datamodule.py index 13cce63..8eac554 100644 --- a/src/data/lhco_datamodule.py +++ b/src/data/lhco_datamodule.py @@ -342,15 +342,15 @@ def setup(self, stage: Optional[str] = None): pt_dataset_train_sr = dataset_train_sr.copy() pt_dataset_val_sr = dataset_val_sr.copy() if self.hparams.log_pt: - pt_dataset_train[:, :, 0] = np.ma.log(1.0 - pt_dataset_train[:, :, 0]).filled( + pt_dataset_train[:, :, 2] = np.ma.log(1.0 - pt_dataset_train[:, :, 2]).filled( 0 ) - pt_dataset_val[:, :, 0] = np.ma.log(1.0 - pt_dataset_val[:, :, 0]).filled(0) - pt_dataset_train_sr[:, :, 0] = np.ma.log( - 1.0 - pt_dataset_train_sr[:, :, 0] + pt_dataset_val[:, :, 2] = np.ma.log(1.0 - pt_dataset_val[:, :, 2]).filled(0) + pt_dataset_train_sr[:, :, 2] = np.ma.log( + 1.0 - pt_dataset_train_sr[:, :, 2] ).filled(0) - pt_dataset_val_sr[:, :, 0] = np.ma.log( - 1.0 - pt_dataset_val_sr[:, :, 0] + pt_dataset_val_sr[:, :, 2] = np.ma.log( + 1.0 - pt_dataset_val_sr[:, :, 2] ).filled(0) means = np.ma.mean(pt_dataset_train, axis=(0, 1)) diff --git a/src/utils/data_generation.py b/src/utils/data_generation.py index c34092e..e7317ae 100644 --- a/src/utils/data_generation.py +++ b/src/utils/data_generation.py @@ -105,7 +105,7 @@ def generate_data( jet_samples_batch, means, stds, sigma=normalize_sigma ) if log_pt: - jet_samples_batch[..., 0] = 1.0 - np.exp(jet_samples_batch[..., 0]) + jet_samples_batch[..., 2] = 1.0 - np.exp(jet_samples_batch[..., 2]) if variable_set_sizes: jet_samples_batch = jet_samples_batch * mask_batch particle_data_sampled = torch.cat((particle_data_sampled, jet_samples_batch)) @@ -144,7 +144,7 @@ def generate_data( jet_samples_batch, means, stds, sigma=normalize_sigma ) if log_pt: - jet_samples_batch[..., 0] = 1.0 - np.exp(jet_samples_batch[..., 0]) + jet_samples_batch[..., 2] = 1.0 - np.exp(jet_samples_batch[..., 2]) if variable_set_sizes: jet_samples_batch = jet_samples_batch * mask_batch particle_data_sampled = torch.cat((particle_data_sampled, jet_samples_batch))