Skip to content

Commit

Permalink
fix indices
Browse files Browse the repository at this point in the history
  • Loading branch information
ewencedr committed Sep 3, 2023
1 parent 67983f6 commit e4b50b6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/data/lhco_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,15 @@ def setup(self, stage: Optional[str] = None):
pt_dataset_train_sr = dataset_train_sr.copy()
pt_dataset_val_sr = dataset_val_sr.copy()
if self.hparams.log_pt:
pt_dataset_train[:, :, 0] = np.ma.log(1.0 - pt_dataset_train[:, :, 0]).filled(
pt_dataset_train[:, :, 2] = np.ma.log(1.0 - pt_dataset_train[:, :, 2]).filled(
0
)
pt_dataset_val[:, :, 0] = np.ma.log(1.0 - pt_dataset_val[:, :, 0]).filled(0)
pt_dataset_train_sr[:, :, 0] = np.ma.log(
1.0 - pt_dataset_train_sr[:, :, 0]
pt_dataset_val[:, :, 2] = np.ma.log(1.0 - pt_dataset_val[:, :, 2]).filled(0)
pt_dataset_train_sr[:, :, 2] = np.ma.log(
1.0 - pt_dataset_train_sr[:, :, 2]
).filled(0)
pt_dataset_val_sr[:, :, 0] = np.ma.log(
1.0 - pt_dataset_val_sr[:, :, 0]
pt_dataset_val_sr[:, :, 2] = np.ma.log(
1.0 - pt_dataset_val_sr[:, :, 2]
).filled(0)

means = np.ma.mean(pt_dataset_train, axis=(0, 1))
Expand Down
4 changes: 2 additions & 2 deletions src/utils/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def generate_data(
jet_samples_batch, means, stds, sigma=normalize_sigma
)
if log_pt:
jet_samples_batch[..., 0] = 1.0 - np.exp(jet_samples_batch[..., 0])
jet_samples_batch[..., 2] = 1.0 - np.exp(jet_samples_batch[..., 2])
if variable_set_sizes:
jet_samples_batch = jet_samples_batch * mask_batch
particle_data_sampled = torch.cat((particle_data_sampled, jet_samples_batch))
Expand Down Expand Up @@ -144,7 +144,7 @@ def generate_data(
jet_samples_batch, means, stds, sigma=normalize_sigma
)
if log_pt:
jet_samples_batch[..., 0] = 1.0 - np.exp(jet_samples_batch[..., 0])
jet_samples_batch[..., 2] = 1.0 - np.exp(jet_samples_batch[..., 2])
if variable_set_sizes:
jet_samples_batch = jet_samples_batch * mask_batch
particle_data_sampled = torch.cat((particle_data_sampled, jet_samples_batch))
Expand Down

0 comments on commit e4b50b6

Please sign in to comment.