Skip to content

Commit

Permalink
add conditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
ewencedr committed Aug 11, 2023
1 parent 1470891 commit ba172b5
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/data/lhco_jet_feature_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def setup(self, stage: Optional[str] = None):

# cut window
args_to_remove = (mjj >= self.hparams.window_left) & (mjj <= self.hparams.window_right)
mjj_cut = mjj[~args_to_remove]
conditioning = mjj[~args_to_remove]

jet_data_cut = jet_data[~args_to_remove]

Expand All @@ -145,13 +145,21 @@ def setup(self, stage: Optional[str] = None):
],
)

conditioning_train, conditioning_val, conditioning_test = np.split(
conditioning,
[
len(conditioning) - (n_samples_val + n_samples_test),
len(conditioning) - n_samples_test,
],
)

tensor_train = torch.tensor(dataset_train, dtype=torch.float)
tensor_val = torch.tensor(dataset_val, dtype=torch.float)
tensor_test = torch.tensor(dataset_test, dtype=torch.float)

tensor_conditioning_train = torch.zeros(len(dataset_train))
tensor_conditioning_val = torch.zeros(len(dataset_val))
tensor_conditioning_test = torch.zeros(len(dataset_test))
tensor_conditioning_train = torch.tensor(conditioning_train, dtype=torch.float)
tensor_conditioning_val = torch.tensor(conditioning_val, dtype=torch.float)
tensor_conditioning_test = torch.tensor(conditioning_test, dtype=torch.float)

self.data_train = TensorDataset(tensor_train, tensor_conditioning_train)
self.data_val = TensorDataset(tensor_val, tensor_conditioning_val)
Expand Down

0 comments on commit ba172b5

Please sign in to comment.