From ad7b6357f85f40272eaa87a6648c321435ebc5ee Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Thu, 3 Aug 2023 10:54:17 +0200 Subject: [PATCH 01/50] Update jetclass_cond.yaml --- configs/experiment/jetclass_cond.yaml | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 632c5d3e..0a65a780 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,11 +16,13 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] +run_note: "" + seed: 12345 trainer: - min_epochs: 10 - max_epochs: 10000 + min_epochs: 1 + max_epochs: 1_500 gradient_clip_val: 0.5 model: @@ -29,9 +31,13 @@ model: local_cond_dim: 0 data: - number_of_used_jets: 200000 + # preprocessing + number_of_used_jets: 1_000_000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 + normalize: True + normalize_sigma: 5 + # conditioning conditioning_pt: True conditioning_eta: False conditioning_mass: True @@ -44,18 +50,25 @@ callbacks: start_step: 0 save_ema_weights_in_callback_state: True evaluate_ema_weights_instead: True + jetclass_eval: + every_n_epochs: 100 # evaluate every n epochs + additional_eval_epochs: [1, 5, 10, 30, 50, 75] # evaluate at these epochs as well + num_jet_samples: 50_000 # jet samples to generate + jetclass_eval_test: + num_jet_samples: 200_000 # jet samples to generate #early_stopping: # monitor: "val/loss" # patience: 2000 # mode: "min" -task_name: "jetclass_flow_matching" +task_name: "jetclass_cond" logger: wandb: tags: ${tags} - group: "jetclass_flow_matching" + group: "flow_matching_jetclass" name: ${task_name} comet: experiment_name: ${task_name} + project_name: "flow-matching" From fcbf07e1790dab429081b3b1de0a75ccb837c26c Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 12:27:52 +0200 Subject: [PATCH 02/50] 
Add first draft of multi-jettype loading --- configs/data/jetclass.yaml | 19 ++++++- src/callbacks/jetclass_eval.py | 3 + src/data/jetclass_datamodule.py | 99 ++++++++++++++++++++++++++++++--- 3 files changed, 112 insertions(+), 9 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 0ec2b448..4f37deaa 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -1,6 +1,23 @@ _target_: src.data.jetclass_datamodule.JetClassDataModule data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz -data_filename: jetclass_TTBar_2_000_000.npz + +jet_types: + ttbar: + latex_label: $t(\\rightarrow qq')$ + files: + - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_100_000.npz + # - ${data.data_dir}/jetclass_TTBar_100_000.npz + qcd: + latex_label: QCD + files: + - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_100_000.npz + # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz + hbb: + latex_label: $H(\\rightarrow b\\bar{b})$ + files: + - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_100_000.npz + # - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + batch_size: 1024 num_workers: 32 pin_memory: False diff --git a/src/callbacks/jetclass_eval.py b/src/callbacks/jetclass_eval.py index db94edae..a867b3b4 100644 --- a/src/callbacks/jetclass_eval.py +++ b/src/callbacks/jetclass_eval.py @@ -155,6 +155,9 @@ def on_train_start(self, trainer, pl_module) -> None: elif self.ema_callback is not None and self.use_ema: pylogger.info("Using EMA weights for evaluation.") + # TODO: maybe add here crosscheck plots (e.g. 
the jet mass of different + # jet types to ensure the labels are not messed up etc (+ other variables)) + def on_train_epoch_end(self, trainer, pl_module): if self.fix_seed: # fix seed for better reproducibility and comparable results diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 859f8a57..f038b93f 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -65,8 +65,8 @@ def teardown(self): def __init__( self, - data_dir: str = "data/", - data_filename: str = "jetclass.npz", + data_dir: str, + jet_types: dict = None, number_of_used_jets: int = None, val_fraction: float = 0.15, test_fraction: float = 0.15, @@ -91,6 +91,9 @@ def __init__( ): super().__init__() + if jet_types is None: + raise ValueError("`jet_types` must be specified in the datamodule.") + # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt self.save_hyperparameters(logger=False) @@ -140,21 +143,88 @@ def setup(self, stage: Optional[str] = None): # load and split datasets only if not loaded already if not self.data_train and not self.data_val and not self.data_test: + particle_features_list = [] + jet_features_list = [] + labels_list = [] + + names_particle_features_previous = None + names_jet_features_previous = None + names_labels_previous = None + filename_previous = None + + print(self.hparams.jet_types) # data loading - path = f"{self.hparams.data_dir}/{self.hparams.data_filename}" - npfile = np.load(path, allow_pickle=True) + for jet_type, jet_type_dict in self.hparams.jet_types.items(): + for filename in jet_type_dict["files"]: + print(f"Loading {filename}") + npfile = np.load(filename, allow_pickle=True) + particle_features_list.append(npfile["part_features"]) + jet_features_list.append(npfile["jet_features"]) + labels_list.append(npfile["labels"]) + + # Check that the labels are in the same order for all files + names_particle_features = 
npfile["names_part_features"] + names_jet_features = npfile["names_jet_features"] + names_labels = npfile["names_labels"] + + if ( + names_particle_features_previous is None + and names_jet_features_previous is None + and names_labels_previous is None + ): + # first file + pass + else: + if not np.all(names_particle_features == names_particle_features_previous): + raise ValueError( + "Names of particle features are not the same for all files." + f"\n{filename_previous}: {names_particle_features_previous}" + f"\n{filename}: {names_particle_features}" + ) + if not np.all(names_jet_features == names_jet_features_previous): + raise ValueError( + "Names of jet features are not the same for all files." + f"\n{filename_previous}: {names_jet_features_previous}" + f"\n{filename}: {names_jet_features}" + ) + if not np.all(names_labels == names_labels_previous): + raise ValueError( + "Names of labels are not the same for all files." + f"\n{filename_previous}: {names_labels_previous}" + f"\n{filename}: {names_labels}" + ) + names_particle_features_previous = names_particle_features + names_jet_features_previous = names_jet_features + names_labels_previous = names_labels + filename_previous = filename + + particle_features = np.concatenate(particle_features_list) + jet_features = np.concatenate(jet_features_list) + labels = np.concatenate(labels_list) + + # shuffle data + np.random.seed(42) + permutation = np.random.permutation(len(labels)) + particle_features = particle_features[permutation] + jet_features = jet_features[permutation] + labels = labels[permutation] + + print("Loaded data.") + print(f"particle_features.shape = {particle_features.shape}") + print(f"jet_features.shape = {jet_features.shape}") + print(f"labels.shape = {labels.shape}") - particle_features = npfile["part_features"] - jet_features = npfile["jet_features"] if self.hparams.number_of_used_jets is not None: if self.hparams.number_of_used_jets < len(jet_features): particle_features = particle_features[: 
self.hparams.number_of_used_jets] jet_features = jet_features[: self.hparams.number_of_used_jets] + labels = labels[: self.hparams.number_of_used_jets] + # TODO: check that these are consistent over all loaded files! + # --> raise an error otherwise names_part_features = npfile["names_part_features"] names_jet_features = npfile["names_jet_features"] - # TODO: anything to do with labels? - # labels = npfile["labels"] + names_labels = npfile["names_labels"] # NOTE: everything below here assumes that the particle features # array after preprocessing stores the features [eta_rel, phi_rel, pt_rel] @@ -225,6 +295,13 @@ def setup(self, stage: Optional[str] = None): len(ma_particle_features) - n_samples_test, ], ) + labels_train, labels_val, labels_test = np.split( + labels, + [ + len(labels) - (n_samples_val + n_samples_test), + len(labels) - n_samples_test, + ], + ) if self.num_cond_features == 0: self.tensor_conditioning_train = torch.zeros(len(dataset_train)) self.tensor_conditioning_val = torch.zeros(len(dataset_val)) @@ -253,6 +330,12 @@ def setup(self, stage: Optional[str] = None): self.tensor_train = torch.tensor(dataset_train[:, :, :3], dtype=torch.float32) self.tensor_test = torch.tensor(dataset_test[:, :, :3], dtype=torch.float32) self.tensor_val = torch.tensor(dataset_val[:, :, :3], dtype=torch.float32) + self.labels_train = torch.tensor(labels_train, dtype=torch.float32) + self.labels_test = torch.tensor(labels_test, dtype=torch.float32) + self.labels_val = torch.tensor(labels_val, dtype=torch.float32) + self.names_part_features = names_part_features + self.names_jet_features = names_jet_features + self.names_labels = names_labels if self.hparams.normalize: # calculate means and stds only based on the training data From 8db4d43369ebc69a340ae1ef332c40bf5ab0bde5 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 13:51:00 +0200 Subject: [PATCH 03/50] Switch to flat labels (not one-hot encoded) --- src/data/jetclass_datamodule.py | 24 
+++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index f038b93f..ade27009 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -80,6 +80,7 @@ def __init__( conditioning_eta: bool = True, conditioning_mass: bool = True, conditioning_num_particles: bool = True, + # conditioning_jet_type: bool = True, num_particles: int = 128, # preprocessing normalize: bool = True, @@ -131,6 +132,7 @@ def num_cond_features(self): self.hparams.conditioning_eta, self.hparams.conditioning_mass, self.hparams.conditioning_num_particles, + # self.hparams.conditioning_jet_type, ] ) @@ -200,7 +202,9 @@ def setup(self, stage: Optional[str] = None): particle_features = np.concatenate(particle_features_list) jet_features = np.concatenate(jet_features_list) - labels = np.concatenate(labels_list) + labels_one_hot = np.concatenate(labels_list) + + labels = np.argmax(labels_one_hot, axis=1) # shuffle data np.random.seed(42) @@ -220,21 +224,15 @@ def setup(self, stage: Optional[str] = None): jet_features = jet_features[: self.hparams.number_of_used_jets] labels = labels[: self.hparams.number_of_used_jets] - # TODO: check that these are consistent over all loaded files! 
- # --> raise an error otherwise - names_part_features = npfile["names_part_features"] - names_jet_features = npfile["names_jet_features"] - names_labels = npfile["names_labels"] - # NOTE: everything below here assumes that the particle features # array after preprocessing stores the features [eta_rel, phi_rel, pt_rel] # check if the particle features are in the correct order - index_part_deta = get_feat_index(names_part_features, "part_deta") + index_part_deta = get_feat_index(names_particle_features, "part_deta") assert index_part_deta == 0, "part_deta is not the first feature" - index_part_dphi = get_feat_index(names_part_features, "part_dphi") + index_part_dphi = get_feat_index(names_particle_features, "part_dphi") assert index_part_dphi == 1, "part_dphi is not the second feature" - index_part_pt = get_feat_index(names_part_features, "part_pt") + index_part_pt = get_feat_index(names_particle_features, "part_pt") assert index_part_pt == 2, "part_pt is not the third feature" # divide particle pt by jet pt @@ -245,7 +243,7 @@ def setup(self, stage: Optional[str] = None): # instead of using the part_deta variable, use part_eta - jet_eta if self.hparams.use_custom_eta_centering: - if "part_eta" not in names_part_features: + if "part_eta" not in names_particle_features: raise ValueError( "`use_custom_eta_centering` is True, but `part_eta` is not in " "in the dataset --> check the dataset" @@ -259,7 +257,7 @@ def setup(self, stage: Optional[str] = None): jet_eta_repeat = jet_features[:, index_jet_eta][:, np.newaxis].repeat( particle_features.shape[1], 1 ) - index_part_eta = get_feat_index(names_part_features, "part_eta") + index_part_eta = get_feat_index(names_particle_features, "part_eta") particle_eta_minus_jet_eta = ( particle_features[:, :, index_part_eta] - jet_eta_repeat ) @@ -333,7 +331,7 @@ def setup(self, stage: Optional[str] = None): self.labels_train = torch.tensor(labels_train, dtype=torch.float32) self.labels_test = torch.tensor(labels_test, 
dtype=torch.float32) self.labels_val = torch.tensor(labels_val, dtype=torch.float32) - self.names_part_features = names_part_features + self.names_particle_features = names_particle_features self.names_jet_features = names_jet_features self.names_labels = names_labels From cb5e0052a305001948c05cb5675dea279341e283 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 14:26:37 +0200 Subject: [PATCH 04/50] add jetclass support for conditioning on jet type --- src/data/jetclass_datamodule.py | 46 +++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index ade27009..cbcab042 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -80,7 +80,7 @@ def __init__( conditioning_eta: bool = True, conditioning_mass: bool = True, conditioning_num_particles: bool = True, - # conditioning_jet_type: bool = True, + conditioning_jet_type: bool = True, num_particles: int = 128, # preprocessing normalize: bool = True, @@ -132,7 +132,7 @@ def num_cond_features(self): self.hparams.conditioning_eta, self.hparams.conditioning_mass, self.hparams.conditioning_num_particles, - # self.hparams.conditioning_jet_type, + self.hparams.conditioning_jet_type, ] ) @@ -304,13 +304,16 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_train = torch.zeros(len(dataset_train)) self.tensor_conditioning_val = torch.zeros(len(dataset_val)) self.tensor_conditioning_test = torch.zeros(len(dataset_test)) + self.names_conditioning = None else: - jet_features = self._handle_conditioning(jet_features, names_jet_features) + conditioning_features, self.names_conditioning = self._handle_conditioning( + jet_features, names_jet_features, labels + ) (conditioning_train, conditioning_val, conditioning_test) = np.split( - jet_features, + conditioning_features, [ - len(jet_features) - (n_samples_val + n_samples_test), - len(jet_features) - 
n_samples_test, + len(conditioning_features) - (n_samples_val + n_samples_test), + len(conditioning_features) - n_samples_test, ], ) self.tensor_conditioning_train = torch.tensor( @@ -421,12 +424,18 @@ def load_state_dict(self, state_dict: Dict[str, Any]): """Things to do when loading checkpoint.""" pass - def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array): + def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, labels: np.array): """Select the conditioning variables. - jet_data: np.array of shape (n_jets, n_features) - names_jet_data: np.array of shape (n_features,) which contains the names of - the features + Args: + jet_data: np.array of shape (n_jets, n_features) + names_jet_data: np.array of shape (n_features,) which contains the names of + the features + labels: np.array of shape (n_jets,) which contains the labels / jet-types + Returns: + conditioning_data: np.array of shape (n_jets, n_conditioning_features) + names_conditioning_data: np.array of shape (n_conditioning_features,) which + contains the names of the conditioning features """ if ( @@ -434,6 +443,7 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array): and not self.hparams.conditioning_eta and not self.hparams.conditioning_mass and not self.hparams.conditioning_num_particles + and not self.hparams.conditioning_jet_type ): return None @@ -449,7 +459,21 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array): if self.hparams.conditioning_num_particles: keep_col.append(get_feat_index(names_jet_data, "jet_nparticles")) - return jet_data[:, keep_col] + if self.hparams.conditioning_jet_type: + # in this case also add the jet type to the conditioning variables + # and to the names array + return ( + np.concatenate( + ( + jet_data[:, keep_col], + labels[:, np.newaxis], + ), + axis=1, + ), + list(names_jet_data[keep_col]) + ["jet_type"], + ) + + return jet_data[:, keep_col], names_jet_data[keep_col] if 
__name__ == "__main__": From ef6866ae83e6907abe421ba38e7a16c8193239f0 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 14:33:42 +0200 Subject: [PATCH 05/50] add comet folder to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df93fedd..8c3cceab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ lightning_logs/ logs/ +.cometml-runs # Byte-compiled / optimized / DLL files __pycache__/ From 7e9c6b1b2110546128d29ba20b7af8b981d35d47 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 14:37:27 +0200 Subject: [PATCH 06/50] Update configs to try out conditioning on jet type --- configs/experiment/jetclass_cond.yaml | 11 ++++++----- configs/experiment/jetclass_dev.yaml | 7 ++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 0a65a780..807896db 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -22,7 +22,7 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 1_500 + max_epochs: 500 gradient_clip_val: 0.5 model: @@ -32,16 +32,17 @@ model: data: # preprocessing - number_of_used_jets: 1_000_000 + number_of_used_jets: null use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True normalize_sigma: 5 # conditioning - conditioning_pt: True + conditioning_pt: False conditioning_eta: False - conditioning_mass: True + conditioning_mass: False conditioning_num_particles: False + conditioning_jet_type: True callbacks: ema: @@ -62,7 +63,7 @@ callbacks: # patience: 2000 # mode: "min" -task_name: "jetclass_cond" +task_name: "jetclass_cond_jettype" logger: wandb: diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index 55fba924..0ffb79b2 100644 --- a/configs/experiment/jetclass_dev.yaml +++ 
b/configs/experiment/jetclass_dev.yaml @@ -17,8 +17,8 @@ defaults: -tags: ["flow_matching", "JetClass", "uncond", "dev", "debug"] -run_note: "This is a note" +tags: ["flow_matching", "JetClass", "dev", "debug"] +run_note: "Test run with conditioning on jet type" seed: 12345 @@ -29,7 +29,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 0 # needs to be calculated when using conditioning + global_cond_dim: 1 # needs to be calculated when using conditioning local_cond_dim: 0 data: @@ -40,6 +40,7 @@ data: conditioning_eta: False conditioning_mass: False conditioning_num_particles: False + conditioning_jet_type: True callbacks: ema: From b6796235692f4edc4e692f2443f02fcb5d22c8cd Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 14:44:25 +0200 Subject: [PATCH 07/50] Raise error when using files corresponding to multiple jet types without conditioning on jet type --- src/data/jetclass_datamodule.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index cbcab042..3d7d77cf 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -154,7 +154,11 @@ def setup(self, stage: Optional[str] = None): names_labels_previous = None filename_previous = None - print(self.hparams.jet_types) + if len(self.hparams.jet_types) > 1 and not self.hparams.conditioning_jet_type: + raise ValueError( + "Multiple jet types are specified in the datamodule, but " + "`conditioning_jet_type` is set to False." 
+ ) # data loading for jet_type, jet_type_dict in self.hparams.jet_types.items(): for filename in jet_type_dict["files"]: From 2557e418e5186a20b91a8daaa3c568d0dc0be869 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 14:47:15 +0200 Subject: [PATCH 08/50] Fix wrong number of cond variables --- configs/experiment/jetclass_cond.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 807896db..0c94adc9 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -27,7 +27,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 2 # needs to be calculated when using conditioning + global_cond_dim: 1 # needs to be calculated when using conditioning local_cond_dim: 0 data: From fb9ee302995b8c434457e2b2da122317e9787d49 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 16:17:33 +0200 Subject: [PATCH 09/50] This seems to work --- configs/experiment/jetclass_dev.yaml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index 0ffb79b2..1d59c5ee 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -18,6 +18,7 @@ defaults: tags: ["flow_matching", "JetClass", "dev", "debug"] + run_note: "Test run with conditioning on jet type" seed: 12345 @@ -33,9 +34,13 @@ model: local_cond_dim: 0 data: - number_of_used_jets: 200 + # preprocessing + number_of_used_jets: 100000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 + normalize: True + normalize_sigma: 5 + # conditioning conditioning_pt: False conditioning_eta: False conditioning_mass: False From 6c1367451cc36caa368af037417e4b5f995c872b Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 16:27:13 +0200 Subject: 
[PATCH 10/50] Ok no but this works now (without conditioning) --- configs/data/jetclass.yaml | 18 +++++++++--------- configs/experiment/jetclass_dev.yaml | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 4f37deaa..8ded1305 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -7,16 +7,16 @@ jet_types: files: - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_100_000.npz # - ${data.data_dir}/jetclass_TTBar_100_000.npz - qcd: - latex_label: QCD - files: - - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_100_000.npz + # qcd: + # latex_label: QCD + # files: + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_100_000.npz # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz - hbb: - latex_label: $H(\\rightarrow b\\bar{b})$ - files: - - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_100_000.npz - # - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + # hbb: + # latex_label: $H(\\rightarrow b\\bar{b})$ + # files: + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_100_000.npz + # # - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz batch_size: 1024 num_workers: 32 diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index 1d59c5ee..a6199d9b 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -30,7 +30,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 1 # needs to be calculated when using conditioning + global_cond_dim: 0 # needs to be calculated when using conditioning local_cond_dim: 0 data: @@ -45,7 +45,7 @@ data: conditioning_eta: False conditioning_mass: False conditioning_num_particles: False - conditioning_jet_type: True + conditioning_jet_type: False callbacks: ema: From d6e7dc33d9c722b08a39088add63d9968d6bef93 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 
22:19:25 +0200 Subject: [PATCH 11/50] This configuration make jetclass_dev crash - which doesn't make sense because it uses only one jet type --> with conditioning on jet type this should only be a useless node, but apparently isn't... there must be something wrong here --- configs/data/jetclass.yaml | 7 +- configs/experiment/jetclass_cond.yaml | 2 +- configs/experiment/jetclass_dev.yaml | 8 +- src/data/jetclass_datamodule.py | 114 +++++++++++++++++--------- 4 files changed, 87 insertions(+), 44 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 8ded1305..2fbe63ec 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -4,15 +4,18 @@ data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz jet_types: ttbar: latex_label: $t(\\rightarrow qq')$ + n_jets_per_file: 300_000 files: - - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_100_000.npz + - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz # - ${data.data_dir}/jetclass_TTBar_100_000.npz # qcd: # latex_label: QCD + # n_jets_per_file: 100_000 # files: # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_100_000.npz - # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz + # # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz # hbb: + # n_jets_per_file: 100_000 # latex_label: $H(\\rightarrow b\\bar{b})$ # files: # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_100_000.npz diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 0c94adc9..02528e26 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] -run_note: "" +run_note: "Test run with conditioning on jet type" seed: 12345 diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index a6199d9b..a1474934 100644 --- a/configs/experiment/jetclass_dev.yaml 
+++ b/configs/experiment/jetclass_dev.yaml @@ -25,17 +25,17 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 2 + max_epochs: 200 gradient_clip_val: 0.5 model: num_particles: 128 - global_cond_dim: 0 # needs to be calculated when using conditioning + global_cond_dim: 1 # needs to be calculated when using conditioning local_cond_dim: 0 data: # preprocessing - number_of_used_jets: 100000 + number_of_used_jets: 300000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True @@ -45,7 +45,7 @@ data: conditioning_eta: False conditioning_mass: False conditioning_num_particles: False - conditioning_jet_type: False + conditioning_jet_type: True callbacks: ema: diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 3d7d77cf..287bd96c 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -8,9 +8,9 @@ from src.utils.pylogger import get_pylogger -from .components import normalize_tensor +from .components import normalize_tensor, one_hot_encode -log = get_pylogger("JetClassDataModule") +pylogger = get_pylogger("JetClassDataModule") def get_feat_index(names_array: np.array, name: str): @@ -154,19 +154,22 @@ def setup(self, stage: Optional[str] = None): names_labels_previous = None filename_previous = None - if len(self.hparams.jet_types) > 1 and not self.hparams.conditioning_jet_type: - raise ValueError( - "Multiple jet types are specified in the datamodule, but " - "`conditioning_jet_type` is set to False." - ) + # TODO: the error raise below should be used + # if len(self.hparams.jet_types) > 1 and not self.hparams.conditioning_jet_type: + # raise ValueError( + # "Multiple jet types are specified in the datamodule, but " + # "`conditioning_jet_type` is set to False." 
+ # ) # data loading for jet_type, jet_type_dict in self.hparams.jet_types.items(): for filename in jet_type_dict["files"]: - print(f"Loading {filename}") + pylogger.info(f"Loading {filename}") npfile = np.load(filename, allow_pickle=True) - particle_features_list.append(npfile["part_features"]) - jet_features_list.append(npfile["jet_features"]) - labels_list.append(npfile["labels"]) + # TODO: remove the n_jets_per_file parameter + stop = jet_type_dict["n_jets_per_file"] + particle_features_list.append(npfile["part_features"][:stop]) + jet_features_list.append(npfile["jet_features"][:stop]) + labels_list.append(npfile["labels"][:stop]) # Check that the labels are in the same order for all files names_particle_features = npfile["names_part_features"] @@ -217,10 +220,11 @@ def setup(self, stage: Optional[str] = None): jet_features = jet_features[permutation] labels = labels[permutation] - print("Loaded data.") - print(f"particle_features.shape = {particle_features.shape}") - print(f"jet_features.shape = {jet_features.shape}") - print(f"labels.shape = {labels.shape}") + pylogger.info("Loaded data.") + pylogger.info("Shapes of arrays as available in files:") + pylogger.info(f"particle_features.shape = {particle_features.shape}") + pylogger.info(f"jet_features.shape = {jet_features.shape}") + pylogger.info(f"labels.shape = {labels.shape}") if self.hparams.number_of_used_jets is not None: if self.hparams.number_of_used_jets < len(jet_features): @@ -342,6 +346,11 @@ def setup(self, stage: Optional[str] = None): self.names_jet_features = names_jet_features self.names_labels = names_labels + # the tensors below will be overwritten if normalization is used + self.tensor_conditioning_train_dl = torch.zeros(len(dataset_train)) + self.tensor_conditioning_val_dl = torch.zeros(len(dataset_val)) + self.tensor_conditioning_test_dl = torch.zeros(len(dataset_test)) + if self.hparams.normalize: # calculate means and stds only based on the training data self.means = 
np.ma.mean(dataset_train, axis=(0, 1)) @@ -364,26 +373,55 @@ def setup(self, stage: Optional[str] = None): self.tensor_test_dl = torch.tensor( norm_dataset_test[:, :, :3], dtype=torch.float32 ) + if self.num_cond_features > 0: + means_cond = torch.mean(self.tensor_conditioning_train, axis=0) + stds_cond = torch.std(self.tensor_conditioning_train, axis=0) + # Train + self.tensor_conditioning_train_dl = normalize_tensor( + self.tensor_conditioning_train, + means_cond, + stds_cond, + sigma=self.hparams.normalize_sigma, + ) + + # Validation + self.tensor_conditioning_val_dl = normalize_tensor( + self.tensor_conditioning_val, + means_cond, + stds_cond, + sigma=self.hparams.normalize_sigma, + ) + + # Test + self.tensor_conditioning_test_dl = normalize_tensor( + self.tensor_conditioning_test, + means_cond, + stds_cond, + sigma=self.hparams.normalize_sigma, + ) else: self.tensor_train_dl = torch.tensor(dataset_train[:, :, :3], dtype=torch.float32) self.tensor_test_dl = torch.tensor(dataset_test[:, :, :3], dtype=torch.float32) self.tensor_val_dl = torch.tensor(dataset_val[:, :, :3], dtype=torch.float32) + self.tensor_conditioning_train_dl = self.tensor_conditioning_train + self.tensor_conditioning_val_dl = self.tensor_conditioning_val + self.tensor_conditioning_test_dl = self.tensor_conditioning_test self.data_train = TensorDataset( self.tensor_train_dl, self.mask_train, - self.tensor_conditioning_train, + self.tensor_conditioning_train_dl, ) self.data_val = TensorDataset( self.tensor_val_dl, self.mask_val, - self.tensor_conditioning_val, + self.tensor_conditioning_val_dl, ) self.data_test = TensorDataset( self.tensor_test_dl, self.mask_test, - self.tensor_conditioning_test, + self.tensor_conditioning_test_dl, ) def train_dataloader(self): @@ -441,7 +479,12 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, lab names_conditioning_data: np.array of shape (n_conditioning_features,) which contains the names of the conditioning features """ + 
categories = np.unique(jet_data[:, 0]) + jet_data_one_hot = one_hot_encode( + jet_data, categories=[categories], num_other_features=jet_data.shape[1] - 1 + ) + one_hot_len = len(categories) if ( not self.hparams.conditioning_pt and not self.hparams.conditioning_eta @@ -454,30 +497,27 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, lab # select the columns which correspond to the conditioning variables that should be used keep_col = [] + names_conditioning_data = [] + + print(jet_data.shape) + if self.hparams.conditioning_jet_type: + keep_col += list(np.arange(one_hot_len)) + names_conditioning_data += [f"jet_type_{i:.0f}" for i in categories] if self.hparams.conditioning_pt: - keep_col.append(get_feat_index(names_jet_data, "jet_pt")) + keep_col.append(get_feat_index(names_jet_data, "jet_pt") + one_hot_len - 1) + names_conditioning_data.append("jet_pt") if self.hparams.conditioning_eta: - keep_col.append(get_feat_index(names_jet_data, "jet_eta")) + keep_col.append(get_feat_index(names_jet_data, "jet_eta") + one_hot_len - 1) + names_conditioning_data.append("jet_eta") if self.hparams.conditioning_mass: - keep_col.append(get_feat_index(names_jet_data, "jet_sdmass")) + keep_col.append(get_feat_index(names_jet_data, "jet_sdmass") + one_hot_len - 1) + names_conditioning_data.append("jet_sdmass") if self.hparams.conditioning_num_particles: - keep_col.append(get_feat_index(names_jet_data, "jet_nparticles")) - - if self.hparams.conditioning_jet_type: - # in this case also add the jet type to the conditioning variables - # and to the names array - return ( - np.concatenate( - ( - jet_data[:, keep_col], - labels[:, np.newaxis], - ), - axis=1, - ), - list(names_jet_data[keep_col]) + ["jet_type"], - ) + keep_col.append(get_feat_index(names_jet_data, "jet_nparticles") + one_hot_len - 1) + names_conditioning_data.append("jet_nparticles") + print(keep_col) - return jet_data[:, keep_col], names_jet_data[keep_col] + return jet_data_one_hot[:, 
keep_col], names_conditioning_data if __name__ == "__main__": From ed0882307c9f234e6896ebc94f96b5ceaaa52727 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 22:21:47 +0200 Subject: [PATCH 12/50] In contrast to previous commit: this configuration doesn't lead to nan loss. It's the same configuration, just without conditioning on the jet type (which should be redundant since only using one jet type) --- configs/experiment/jetclass_dev.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index a1474934..543556ef 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -30,7 +30,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 1 # needs to be calculated when using conditioning + global_cond_dim: 0 # needs to be calculated when using conditioning local_cond_dim: 0 data: @@ -45,7 +45,7 @@ data: conditioning_eta: False conditioning_mass: False conditioning_num_particles: False - conditioning_jet_type: True + conditioning_jet_type: False callbacks: ema: From 8ab05bd73da9d68da0d14648da01c106070ba382 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 22:36:21 +0200 Subject: [PATCH 13/50] Revert "In contrast to previous commit: this configuration doesn't lead to nan loss. It's the same configuration, just without conditioning on the jet type (which should be redundant since only using one jet type)" This reverts commit ed0882307c9f234e6896ebc94f96b5ceaaa52727. 
--- configs/experiment/jetclass_dev.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index 543556ef..a1474934 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -30,7 +30,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 0 # needs to be calculated when using conditioning + global_cond_dim: 1 # needs to be calculated when using conditioning local_cond_dim: 0 data: @@ -45,7 +45,7 @@ data: conditioning_eta: False conditioning_mass: False conditioning_num_particles: False - conditioning_jet_type: False + conditioning_jet_type: True callbacks: ema: From 63ced68bb475284b5e6eaaf759003ad8980cd44a Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Fri, 4 Aug 2023 22:49:26 +0200 Subject: [PATCH 14/50] Comment-out conditional normalization for now. This is wild. Conditioning on jet type runs with two types of jets when using 30k jets, but crashes when changing to 50k jets --> what is going on here? 
--- configs/data/jetclass.yaml | 12 ++--- configs/experiment/jetclass_dev.yaml | 4 +- src/data/jetclass_datamodule.py | 81 +++++++++++++++------------- 3 files changed, 53 insertions(+), 44 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 2fbe63ec..7c8192cc 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -8,12 +8,12 @@ jet_types: files: - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz # - ${data.data_dir}/jetclass_TTBar_100_000.npz - # qcd: - # latex_label: QCD - # n_jets_per_file: 100_000 - # files: - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_100_000.npz - # # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz + qcd: + latex_label: QCD + n_jets_per_file: 300_000 + files: + - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz + # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz # hbb: # n_jets_per_file: 100_000 # latex_label: $H(\\rightarrow b\\bar{b})$ diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index a1474934..6cabc20e 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -30,12 +30,12 @@ trainer: model: num_particles: 128 - global_cond_dim: 1 # needs to be calculated when using conditioning + global_cond_dim: 2 # needs to be calculated when using conditioning local_cond_dim: 0 data: # preprocessing - number_of_used_jets: 300000 + number_of_used_jets: 30000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 287bd96c..1095e4f1 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -314,6 +314,7 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_test = 
torch.zeros(len(dataset_test)) self.names_conditioning = None else: + print("At least one conditioning feature is used.") conditioning_features, self.names_conditioning = self._handle_conditioning( jet_features, names_jet_features, labels ) @@ -331,6 +332,7 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_test = torch.tensor( conditioning_test, dtype=torch.float32 ) + # nan-fine until here # invert the masks from the masked arrays (numpy ma masks are True for masked values) self.mask_train = torch.tensor(~dataset_train.mask[:, :, :1], dtype=torch.float32) @@ -346,11 +348,6 @@ def setup(self, stage: Optional[str] = None): self.names_jet_features = names_jet_features self.names_labels = names_labels - # the tensors below will be overwritten if normalization is used - self.tensor_conditioning_train_dl = torch.zeros(len(dataset_train)) - self.tensor_conditioning_val_dl = torch.zeros(len(dataset_val)) - self.tensor_conditioning_test_dl = torch.zeros(len(dataset_test)) - if self.hparams.normalize: # calculate means and stds only based on the training data self.means = np.ma.mean(dataset_train, axis=(0, 1)) @@ -373,40 +370,54 @@ def setup(self, stage: Optional[str] = None): self.tensor_test_dl = torch.tensor( norm_dataset_test[:, :, :3], dtype=torch.float32 ) - if self.num_cond_features > 0: - means_cond = torch.mean(self.tensor_conditioning_train, axis=0) - stds_cond = torch.std(self.tensor_conditioning_train, axis=0) - # Train - self.tensor_conditioning_train_dl = normalize_tensor( - self.tensor_conditioning_train, - means_cond, - stds_cond, - sigma=self.hparams.normalize_sigma, - ) - - # Validation - self.tensor_conditioning_val_dl = normalize_tensor( - self.tensor_conditioning_val, - means_cond, - stds_cond, - sigma=self.hparams.normalize_sigma, - ) - - # Test - self.tensor_conditioning_test_dl = normalize_tensor( - self.tensor_conditioning_test, - means_cond, - stds_cond, - sigma=self.hparams.normalize_sigma, - ) + # if self.num_cond_features 
> 0: + # means_cond = torch.mean(self.tensor_conditioning_train, axis=0) + # stds_cond = torch.std(self.tensor_conditioning_train, axis=0) + # # Train + # self.tensor_conditioning_train_dl = normalize_tensor( + # self.tensor_conditioning_train, + # means_cond, + # stds_cond, + # sigma=self.hparams.normalize_sigma, + # ) + + # # Validation + # self.tensor_conditioning_val_dl = normalize_tensor( + # self.tensor_conditioning_val, + # means_cond, + # stds_cond, + # sigma=self.hparams.normalize_sigma, + # ) + + # # Test + # self.tensor_conditioning_test_dl = normalize_tensor( + # self.tensor_conditioning_test, + # means_cond, + # stds_cond, + # sigma=self.hparams.normalize_sigma, + # ) else: self.tensor_train_dl = torch.tensor(dataset_train[:, :, :3], dtype=torch.float32) self.tensor_test_dl = torch.tensor(dataset_test[:, :, :3], dtype=torch.float32) self.tensor_val_dl = torch.tensor(dataset_val[:, :, :3], dtype=torch.float32) - self.tensor_conditioning_train_dl = self.tensor_conditioning_train - self.tensor_conditioning_val_dl = self.tensor_conditioning_val - self.tensor_conditioning_test_dl = self.tensor_conditioning_test + + self.tensor_conditioning_train_dl = self.tensor_conditioning_train + self.tensor_conditioning_val_dl = self.tensor_conditioning_val + self.tensor_conditioning_test_dl = self.tensor_conditioning_test + print("First 10 values of conditioning data:") + print(self.tensor_conditioning_train_dl[:10]) + + # check if conditioning data contains nan values + print("Checking conditioning data for NaNs...") + if ( + torch.isnan(self.tensor_conditioning_train_dl).any() + or torch.isnan(self.tensor_conditioning_val_dl).any() + or torch.isnan(self.tensor_conditioning_test_dl).any() + ): + print("NaNs found in conditioning data!") + else: + print("GOOD NEWS: No NaNs found in conditioning data.") self.data_train = TensorDataset( self.tensor_train_dl, @@ -499,7 +510,6 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, lab keep_col = [] 
names_conditioning_data = [] - print(jet_data.shape) if self.hparams.conditioning_jet_type: keep_col += list(np.arange(one_hot_len)) names_conditioning_data += [f"jet_type_{i:.0f}" for i in categories] @@ -515,7 +525,6 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, lab if self.hparams.conditioning_num_particles: keep_col.append(get_feat_index(names_jet_data, "jet_nparticles") + one_hot_len - 1) names_conditioning_data.append("jet_nparticles") - print(keep_col) return jet_data_one_hot[:, keep_col], names_conditioning_data From 39104cdffd76440b5f4d3528a9929479bf3486e8 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sat, 5 Aug 2023 00:17:19 +0200 Subject: [PATCH 15/50] The good news: this works (with ttbar and Hbb jets) --> maybe the problem is really just the quark jets? --- configs/data/jetclass.yaml | 20 ++++++++++---------- configs/experiment/jetclass_dev.yaml | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 7c8192cc..db1888e2 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -8,18 +8,18 @@ jet_types: files: - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz # - ${data.data_dir}/jetclass_TTBar_100_000.npz - qcd: - latex_label: QCD + # qcd: + # latex_label: QCD + # n_jets_per_file: 300_000 + # files: + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz + # # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz + hbb: n_jets_per_file: 300_000 + latex_label: $H(\\rightarrow b\\bar{b})$ files: - - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz - # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz - # hbb: - # n_jets_per_file: 100_000 - # latex_label: $H(\\rightarrow b\\bar{b})$ - # files: - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_100_000.npz - # # - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + - 
/beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz + # - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz batch_size: 1024 num_workers: 32 diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index 6cabc20e..a77df384 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -25,7 +25,7 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 200 + max_epochs: 2 gradient_clip_val: 0.5 model: @@ -35,11 +35,11 @@ model: data: # preprocessing - number_of_used_jets: 30000 + number_of_used_jets: 300000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True - normalize_sigma: 5 + normalize_sigma: 1 # conditioning conditioning_pt: False conditioning_eta: False From 3fae81391fb61d0f7f8a7a0e3e4b2044f56f7d8e Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sat, 5 Aug 2023 00:22:45 +0200 Subject: [PATCH 16/50] Commit config that was used for first actual training with ttbar and Hbb jets --- configs/experiment/jetclass_cond.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 02528e26..46e54d03 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -22,17 +22,17 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 500 + max_epochs: 2000 gradient_clip_val: 0.5 model: num_particles: 128 - global_cond_dim: 1 # needs to be calculated when using conditioning + global_cond_dim: 2 # needs to be calculated when using conditioning local_cond_dim: 0 data: # preprocessing - number_of_used_jets: null + number_of_used_jets: 300000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True From 
7e8c9a6cb90c22894982acc74a51a55e55ab5a60 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sat, 5 Aug 2023 00:25:29 +0200 Subject: [PATCH 17/50] Update jetclass.yaml experiment --- configs/experiment/jetclass.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/experiment/jetclass.yaml b/configs/experiment/jetclass.yaml index c1c3cbf1..dfc00502 100644 --- a/configs/experiment/jetclass.yaml +++ b/configs/experiment/jetclass.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "uncond"] -run_note: "This is a note" +run_note: "Crosscheck that without conditioning everything still works" seed: 12345 @@ -42,6 +42,7 @@ data: conditioning_eta: False conditioning_mass: False conditioning_num_particles: False + conditioning_jet_type: False callbacks: ema: From 58aae36183a79ffbbc56bc8c35c41e3a2911ff3b Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 17:01:58 +0200 Subject: [PATCH 18/50] Important bugfix: this corrects the mask calculation (before this fix, the number of particles from the jet variable jet_nparticles did not agree with the number calculated from the mask - now it does!) 
--- src/data/jetclass_datamodule.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 1095e4f1..5a9b6437 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -286,10 +286,15 @@ def setup(self, stage: Optional[str] = None): # convert to masked array (more convenient for normalization later on, because # the mask is unaffected) # Note: numpy masks are True for masked values + particle_mask_zero_entries = (particle_features[:, :, 2] == 0)[..., np.newaxis] ma_particle_features = np.ma.masked_array( particle_features, - mask=np.ma.make_mask(particle_features == 0), + mask=np.repeat( + particle_mask_zero_entries, repeats=particle_features.shape[2], axis=2 + ), ) + # TODO: add check that no jets without particles are allowed + # --> either raise an error or remove the jet from the dataset # data splitting n_samples_val = int(self.hparams.val_fraction * len(particle_features)) From 55ca9ac7238156a48a946aeea54b958378cc909f Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 18:43:42 +0200 Subject: [PATCH 19/50] Add comment about using pt for masking + adding error raise in case there are jets without any tracks --- src/data/jetclass_datamodule.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 5a9b6437..1898146b 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -286,6 +286,8 @@ def setup(self, stage: Optional[str] = None): # convert to masked array (more convenient for normalization later on, because # the mask is unaffected) # Note: numpy masks are True for masked values + # Important: use pt_rel for masking, because eta_rel and phi_rel can be zero + # even though it is a valid track particle_mask_zero_entries = (particle_features[:, :, 2] == 0)[..., np.newaxis] ma_particle_features = 
np.ma.masked_array( particle_features, @@ -293,8 +295,12 @@ def setup(self, stage: Optional[str] = None): particle_mask_zero_entries, repeats=particle_features.shape[2], axis=2 ), ) - # TODO: add check that no jets without particles are allowed - # --> either raise an error or remove the jet from the dataset + n_jets_without_particles = np.sum(np.sum(~particle_mask_zero_entries, axis=1) == 0) + if n_jets_without_particles > 0: + raise NotImplementedError( + f"There are {n_jets_without_particles} jets without particles in " + "the dataset. This is not allowed, since the model cannot handle this case." + ) # data splitting n_samples_val = int(self.hparams.val_fraction * len(particle_features)) From 77e204ccf55867a9a5ee8a73bf33d7771eab0c26 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 19:01:01 +0200 Subject: [PATCH 20/50] Removed the "n_jets_per_file" parameter that was just used for debugging --- configs/data/jetclass.yaml | 21 +++++++++------------ src/data/jetclass_datamodule.py | 8 +++----- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index db1888e2..d1026cb6 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -4,22 +4,19 @@ data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz jet_types: ttbar: latex_label: $t(\\rightarrow qq')$ - n_jets_per_file: 300_000 files: - - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz - # - ${data.data_dir}/jetclass_TTBar_100_000.npz - # qcd: - # latex_label: QCD - # n_jets_per_file: 300_000 - # files: - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz - # # - ${data_dir}/jetclass_ZJetsToNuNu_100_000.npz + - ${data.data_dir}/jetclass_TTBar_100_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz + qcd: + latex_label: QCD + files: + - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + # - 
/beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz hbb: - n_jets_per_file: 300_000 latex_label: $H(\\rightarrow b\\bar{b})$ files: - - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz - # - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz batch_size: 1024 num_workers: 32 diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 1898146b..6d5c7963 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -165,11 +165,9 @@ def setup(self, stage: Optional[str] = None): for filename in jet_type_dict["files"]: pylogger.info(f"Loading {filename}") npfile = np.load(filename, allow_pickle=True) - # TODO: remove the n_jets_per_file parameter - stop = jet_type_dict["n_jets_per_file"] - particle_features_list.append(npfile["part_features"][:stop]) - jet_features_list.append(npfile["jet_features"][:stop]) - labels_list.append(npfile["labels"][:stop]) + particle_features_list.append(npfile["part_features"]) + jet_features_list.append(npfile["jet_features"]) + labels_list.append(npfile["labels"]) # Check that the labels are in the same order for all files names_particle_features = npfile["names_part_features"] From 541209425183191ee2e84d5514684049697957cf Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 19:02:11 +0200 Subject: [PATCH 21/50] Add again the error raise for cases where conditioning on jet type is false but multiple jettypes are in the dataset --- src/data/jetclass_datamodule.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 6d5c7963..83d39ebf 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -154,12 +154,11 @@ def setup(self, stage: Optional[str] = None): 
names_labels_previous = None filename_previous = None - # TODO: the error raise below should be used - # if len(self.hparams.jet_types) > 1 and not self.hparams.conditioning_jet_type: - # raise ValueError( - # "Multiple jet types are specified in the datamodule, but " - # "`conditioning_jet_type` is set to False." - # ) + if len(self.hparams.jet_types) > 1 and not self.hparams.conditioning_jet_type: + raise ValueError( + "Multiple jet types are specified in the datamodule, but " + "`conditioning_jet_type` is set to False." + ) # data loading for jet_type, jet_type_dict in self.hparams.jet_types.items(): for filename in jet_type_dict["files"]: From 64394956449dc1bcd2afbc25a27fe95fab305bd2 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 19:02:42 +0200 Subject: [PATCH 22/50] Remove "latex_label" entries from data config (since not used so far) --- configs/data/jetclass.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index d1026cb6..53625af6 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -3,17 +3,14 @@ data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz jet_types: ttbar: - latex_label: $t(\\rightarrow qq')$ files: - ${data.data_dir}/jetclass_TTBar_100_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz qcd: - latex_label: QCD files: - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz hbb: - latex_label: $H(\\rightarrow b\\bar{b})$ files: - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz From 563c7ee9c5b7a661aceb3ab640613b35d4aa66de Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 19:04:57 +0200 Subject: [PATCH 23/50] Fix typos in data config --- configs/data/jetclass.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff 
--git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 53625af6..e6803326 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -4,15 +4,15 @@ data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz jet_types: ttbar: files: - - ${data.data_dir}/jetclass_TTBar_100_000.npz + - ${data.data_dir}/jetclass_TTBar_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz qcd: files: - - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz hbb: files: - - ${data.data_dir}/jetclass_ZJetsToNuNu_100_000.npz + - ${data.data_dir}/jetclass_HToBB_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz batch_size: 1024 From 61707a6ab3c609ac57304cf54812187f64610ee2 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 19:16:13 +0200 Subject: [PATCH 24/50] Adding all jet classes to config --- configs/data/jetclass.yaml | 50 +++++++++++++++++++++++----- configs/experiment/jetclass_dev.yaml | 2 +- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index e6803326..49033235 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -1,4 +1,18 @@ _target_: src.data.jetclass_datamodule.JetClassDataModule + +# overall parameters +batch_size: 1024 +num_workers: 32 +pin_memory: False +val_fraction: 0.10 +test_fraction: 0.30 + +# preprocessing +normalize: True +normalize_sigma: 5 + + +# files and jet types to use data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz jet_types: @@ -14,11 +28,31 @@ jet_types: files: - ${data.data_dir}/jetclass_HToBB_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz - -batch_size: 1024 -num_workers: 32 -pin_memory: False -val_fraction: 0.10 -test_fraction: 0.30 -normalize: True 
-normalize_sigma: 5 + hcc: + files: + - ${data.data_dir}/jetclass_HToCC_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz + hgg: + files: + - ${data.data_dir}/jetclass_HToGG_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz + h4q: + files: + - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz + hqql: + files: + - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz + zqq: + files: + - ${data.data_dir}/jetclass_ZToQQ_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz + wqq: + files: + - ${data.data_dir}/jetclass_WToQQ_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz + ttbarlep: + files: + - ${data.data_dir}/jetclass_TTBarLep_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index a77df384..3afd565f 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -30,7 +30,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 2 # needs to be calculated when using conditioning + global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: From 4e143599f6fc59cb77f306b0328ca4b7667d7c17 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 19:19:53 +0200 Subject: [PATCH 25/50] Update jetclass_cond.yaml to run one training with all 10 jet types --- configs/experiment/jetclass_cond.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 46e54d03..46dc9f7b 100644 
--- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] -run_note: "Test run with conditioning on jet type" +run_note: "jet-type conditioning, using all 10 jet types" seed: 12345 @@ -27,16 +27,16 @@ trainer: model: num_particles: 128 - global_cond_dim: 2 # needs to be calculated when using conditioning + global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: # preprocessing - number_of_used_jets: 300000 + number_of_used_jets: 3_000_000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True - normalize_sigma: 5 + normalize_sigma: 1 # conditioning conditioning_pt: False conditioning_eta: False @@ -52,7 +52,7 @@ callbacks: save_ema_weights_in_callback_state: True evaluate_ema_weights_instead: True jetclass_eval: - every_n_epochs: 100 # evaluate every n epochs + every_n_epochs: 5 # evaluate every n epochs additional_eval_epochs: [1, 5, 10, 30, 50, 75] # evaluate at these epochs as well num_jet_samples: 50_000 # jet samples to generate jetclass_eval_test: From bdf0fc75e4c017059f988d69bd8f51bf06cf5294 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 20:37:14 +0200 Subject: [PATCH 26/50] One training with only top and qcd jets --- configs/data/jetclass.yaml | 64 +++++++++++++-------------- configs/experiment/jetclass_cond.yaml | 6 +-- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 49033235..1908442c 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -24,35 +24,35 @@ jet_types: files: - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz - hbb: - files: - - 
${data.data_dir}/jetclass_HToBB_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz - hcc: - files: - - ${data.data_dir}/jetclass_HToCC_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz - hgg: - files: - - ${data.data_dir}/jetclass_HToGG_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz - h4q: - files: - - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz - hqql: - files: - - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz - zqq: - files: - - ${data.data_dir}/jetclass_ZToQQ_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz - wqq: - files: - - ${data.data_dir}/jetclass_WToQQ_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz - ttbarlep: - files: - - ${data.data_dir}/jetclass_TTBarLep_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz + # hbb: + # files: + # - ${data.data_dir}/jetclass_HToBB_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz + # hcc: + # files: + # - ${data.data_dir}/jetclass_HToCC_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz + # hgg: + # files: + # - ${data.data_dir}/jetclass_HToGG_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz + # h4q: + # files: + # - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz + # hqql: + # files: + # - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz + # zqq: + # files: + # - 
${data.data_dir}/jetclass_ZToQQ_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz + # wqq: + # files: + # - ${data.data_dir}/jetclass_WToQQ_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz + # ttbarlep: + # files: + # - ${data.data_dir}/jetclass_TTBarLep_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 46dc9f7b..7d481abb 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,18 +16,18 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] -run_note: "jet-type conditioning, using all 10 jet types" +run_note: "jet-type conditioning, using only top and qcd jets" seed: 12345 trainer: min_epochs: 1 - max_epochs: 2000 + max_epochs: 10_000 gradient_clip_val: 0.5 model: num_particles: 128 - global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) + global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: From e592983d895e0739ca88436ec527028d4d1fcd3d Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 22:24:32 +0200 Subject: [PATCH 27/50] Log the used jet types as hyperparameters --- src/data/jetclass_datamodule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 83d39ebf..72695522 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -95,6 +95,7 @@ def __init__( if jet_types is None: raise ValueError("`jet_types` must be specified in the datamodule.") + self.hparams["jet_types_list"] = list(jet_types.keys()) # this line allows to access init params with 'self.hparams' attribute # also ensures init params will 
be stored in ckpt self.save_hyperparameters(logger=False) From 8534c6dd40ef0b1da9b32ef8d793ce1b7898a5ff Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 22:26:28 +0200 Subject: [PATCH 28/50] Make one run with ttbar-had and ttbar-lep jets --- configs/data/jetclass.yaml | 16 ++++++++-------- configs/experiment/jetclass_cond.yaml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 1908442c..da4f2e57 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -20,10 +20,10 @@ jet_types: files: - ${data.data_dir}/jetclass_TTBar_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz - qcd: - files: - - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz + # qcd: + # files: + # - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz # hbb: # files: # - ${data.data_dir}/jetclass_HToBB_300_000.npz @@ -52,7 +52,7 @@ jet_types: # files: # - ${data.data_dir}/jetclass_WToQQ_300_000.npz # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz - # ttbarlep: - # files: - # - ${data.data_dir}/jetclass_TTBarLep_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz + ttbarlep: + files: + - ${data.data_dir}/jetclass_TTBarLep_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 7d481abb..701a8b31 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] -run_note: "jet-type conditioning, using only top and qcd jets" +run_note: "jet-type conditioning, using 
only top (had) and top (lep) jets" seed: 12345 From 3d70eb4b61646b23990f8f67cc9cbe955f9ce6e3 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 22:27:28 +0200 Subject: [PATCH 29/50] Add todo --- src/data/jetclass_datamodule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 72695522..1e6607c5 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -95,6 +95,7 @@ def __init__( if jet_types is None: raise ValueError("`jet_types` must be specified in the datamodule.") + # TODO: this doesn't work yet... self.hparams["jet_types_list"] = list(jet_types.keys()) # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt From e1daa252c2858f6b84783384bcf565b32bc6db51 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 22:29:35 +0200 Subject: [PATCH 30/50] 10 jet types but with larger sigma --- configs/data/jetclass.yaml | 64 +++++++++++++-------------- configs/experiment/jetclass_cond.yaml | 6 +-- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index da4f2e57..49033235 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -20,38 +20,38 @@ jet_types: files: - ${data.data_dir}/jetclass_TTBar_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz - # qcd: - # files: - # - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz - # hbb: - # files: - # - ${data.data_dir}/jetclass_HToBB_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz - # hcc: - # files: - # - ${data.data_dir}/jetclass_HToCC_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz - # hgg: - # files: - # - 
${data.data_dir}/jetclass_HToGG_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz - # h4q: - # files: - # - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz - # hqql: - # files: - # - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz - # zqq: - # files: - # - ${data.data_dir}/jetclass_ZToQQ_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz - # wqq: - # files: - # - ${data.data_dir}/jetclass_WToQQ_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz + qcd: + files: + - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz + hbb: + files: + - ${data.data_dir}/jetclass_HToBB_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz + hcc: + files: + - ${data.data_dir}/jetclass_HToCC_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz + hgg: + files: + - ${data.data_dir}/jetclass_HToGG_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz + h4q: + files: + - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz + hqql: + files: + - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz + zqq: + files: + - ${data.data_dir}/jetclass_ZToQQ_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz + wqq: + files: + - ${data.data_dir}/jetclass_WToQQ_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz ttbarlep: files: - 
${data.data_dir}/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index 701a8b31..cd448851 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] -run_note: "jet-type conditioning, using only top (had) and top (lep) jets" +run_note: "jet-type conditioning, using all 10 jet types, but with standardization sigma=5" seed: 12345 @@ -27,7 +27,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) + global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: @@ -36,7 +36,7 @@ data: use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True - normalize_sigma: 1 + normalize_sigma: 5 # conditioning conditioning_pt: False conditioning_eta: False From c52ef7e9900229b418231ae04e4942ccfaa3636e Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 23:15:15 +0200 Subject: [PATCH 31/50] Remove unused mnist datamodules tests --- tests/test_datamodules.py | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 tests/test_datamodules.py diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py deleted file mode 100644 index 407b949a..00000000 --- a/tests/test_datamodules.py +++ /dev/null @@ -1,32 +0,0 @@ -from pathlib import Path - -import pytest -import torch - -from src.data.mnist_datamodule import MNISTDataModule - - -@pytest.mark.parametrize("batch_size", [32, 128]) -def test_mnist_datamodule(batch_size): - data_dir = "data/" - - dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) - dm.prepare_data() - - assert not 
dm.data_train and not dm.data_val and not dm.data_test - assert Path(data_dir, "MNIST").exists() - assert Path(data_dir, "MNIST", "raw").exists() - - dm.setup() - assert dm.data_train and dm.data_val and dm.data_test - assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() - - num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) - assert num_datapoints == 70_000 - - batch = next(iter(dm.train_dataloader())) - x, y = batch - assert len(x) == batch_size - assert len(y) == batch_size - assert x.dtype == torch.float32 - assert y.dtype == torch.int64 From 7e1ab284c4398f4b41aaa28e344ca999ca81fd47 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Sun, 6 Aug 2023 23:23:03 +0200 Subject: [PATCH 32/50] Use fewer epochs when using all 10 jet classes --- configs/experiment/jetclass_cond.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index cd448851..e8659c6c 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -22,7 +22,7 @@ seed: 12345 trainer: min_epochs: 1 - max_epochs: 10_000 + max_epochs: 2_000 gradient_clip_val: 0.5 model: From 24aaaa524cca9ae8b32c8035c194384c80b67a66 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 09:06:38 +0200 Subject: [PATCH 33/50] Adding jet-level spectator variables --- configs/data/jetclass.yaml | 3 +++ src/data/jetclass_datamodule.py | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 49033235..80dd1b66 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -11,6 +11,9 @@ test_fraction: 0.30 normalize: True normalize_sigma: 5 +spectator_jet_features: + - jet_pt + # files and jet types to use data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py 
index 1e6607c5..c47dd8c5 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -87,6 +87,7 @@ def __init__( normalize_sigma: int = 5, use_custom_eta_centering: bool = True, remove_etadiff_tails: bool = True, + spectator_jet_features: list = None, # centering: bool = False, # use_calculated_base_distribution: bool = True, ): @@ -118,6 +119,9 @@ def __init__( self.tensor_conditioning_train: Optional[torch.Tensor] = None self.tensor_conditioning_val: Optional[torch.Tensor] = None self.tensor_conditioning_test: Optional[torch.Tensor] = None + self.tensor_spectator_jet_train: Optional[torch.Tensor] = None + self.tensor_spectator_jet_val: Optional[torch.Tensor] = None + self.tensor_spectator_jet_test: Optional[torch.Tensor] = None def prepare_data(self): """Download data if needed. @@ -318,6 +322,30 @@ def setup(self, stage: Optional[str] = None): len(labels) - n_samples_test, ], ) + + if self.hparams.spectator_jet_features is not None: + # initialize and fill array + spectator_jet_features = np.zeros( + (len(jet_features), len(self.hparams.spectator_jet_features)) + ) + for i, feat in enumerate(self.hparams.spectator_jet_features): + index = get_feat_index(names_jet_features, feat) + spectator_jet_features[:, i] = jet_features[:, index] + else: + spectator_jet_features = np.zeros(len(jet_features)) + + ( + spectator_jet_features_train, + spectator_jet_features_val, + spectator_jet_features_test, + ) = np.split( + spectator_jet_features, + [ + len(spectator_jet_features) - (n_samples_val + n_samples_test), + len(spectator_jet_features) - n_samples_test, + ], + ) + if self.num_cond_features == 0: self.tensor_conditioning_train = torch.zeros(len(dataset_train)) self.tensor_conditioning_val = torch.zeros(len(dataset_val)) @@ -357,6 +385,15 @@ def setup(self, stage: Optional[str] = None): self.names_particle_features = names_particle_features self.names_jet_features = names_jet_features self.names_labels = names_labels + 
self.tensor_spectator_train = torch.tensor( + spectator_jet_features_train, dtype=torch.float32 + ) + self.tensor_spectator_test = torch.tensor( + spectator_jet_features_test, dtype=torch.float32 + ) + self.tensor_spectator_val = torch.tensor( + spectator_jet_features_val, dtype=torch.float32 + ) if self.hparams.normalize: # calculate means and stds only based on the training data From 4b747c52fd8e9045388902cfe36846f5cfb8ff85 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 09:39:50 +0200 Subject: [PATCH 34/50] Remove debugging printout --- configs/experiment/jetclass.yaml | 2 +- src/data/jetclass_datamodule.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/configs/experiment/jetclass.yaml b/configs/experiment/jetclass.yaml index dfc00502..6f5b1535 100644 --- a/configs/experiment/jetclass.yaml +++ b/configs/experiment/jetclass.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "uncond"] -run_note: "Crosscheck that without conditioning everything still works" +run_note: "" seed: 12345 diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index c47dd8c5..367e7c05 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -452,19 +452,14 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_train_dl = self.tensor_conditioning_train self.tensor_conditioning_val_dl = self.tensor_conditioning_val self.tensor_conditioning_test_dl = self.tensor_conditioning_test - print("First 10 values of conditioning data:") - print(self.tensor_conditioning_train_dl[:10]) # check if conditioning data contains nan values - print("Checking conditioning data for NaNs...") if ( torch.isnan(self.tensor_conditioning_train_dl).any() or torch.isnan(self.tensor_conditioning_val_dl).any() or torch.isnan(self.tensor_conditioning_test_dl).any() ): - print("NaNs found in conditioning data!") - else: - print("GOOD NEWS: No NaNs found in conditioning data.") + raise 
ValueError("NaNs found in conditioning data!") self.data_train = TensorDataset( self.tensor_train_dl, From 75931743998ac1beaff722844779a0d480fedc93 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 09:41:29 +0200 Subject: [PATCH 35/50] Revert "Remove unused mnist datamodules tests" This reverts commit c52ef7e9900229b418231ae04e4942ccfaa3636e. --- tests/test_datamodules.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/test_datamodules.py diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py new file mode 100644 index 00000000..407b949a --- /dev/null +++ b/tests/test_datamodules.py @@ -0,0 +1,32 @@ +from pathlib import Path + +import pytest +import torch + +from src.data.mnist_datamodule import MNISTDataModule + + +@pytest.mark.parametrize("batch_size", [32, 128]) +def test_mnist_datamodule(batch_size): + data_dir = "data/" + + dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) + dm.prepare_data() + + assert not dm.data_train and not dm.data_val and not dm.data_test + assert Path(data_dir, "MNIST").exists() + assert Path(data_dir, "MNIST", "raw").exists() + + dm.setup() + assert dm.data_train and dm.data_val and dm.data_test + assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() + + num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) + assert num_datapoints == 70_000 + + batch = next(iter(dm.train_dataloader())) + x, y = batch + assert len(x) == batch_size + assert len(y) == batch_size + assert x.dtype == torch.float32 + assert y.dtype == torch.int64 From a6d185db9f0218bad70df08e093ea5ec472261df Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 09:46:42 +0200 Subject: [PATCH 36/50] Add nan-check for particle data --- src/data/jetclass_datamodule.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 367e7c05..97cbcf06 
100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -352,7 +352,6 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_test = torch.zeros(len(dataset_test)) self.names_conditioning = None else: - print("At least one conditioning feature is used.") conditioning_features, self.names_conditioning = self._handle_conditioning( jet_features, names_jet_features, labels ) @@ -453,6 +452,14 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_val_dl = self.tensor_conditioning_val self.tensor_conditioning_test_dl = self.tensor_conditioning_test + # check if particle data contains nan values + if ( + torch.isnan(self.tensor_train_dl).any() + or torch.isnan(self.tensor_val_dl).any() + or torch.isnan(self.tensor_test_dl).any() + ): + raise ValueError("NaNs found in particle data!") + # check if conditioning data contains nan values if ( torch.isnan(self.tensor_conditioning_train_dl).any() From c81cbf6eda45d9ea0a14cdbc652307b8a039442d Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 11:09:53 +0200 Subject: [PATCH 37/50] Try again qcd and top jets with sigma=5 --- configs/data/jetclass.yaml | 64 +++++++++++++-------------- configs/experiment/jetclass_cond.yaml | 2 +- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 80dd1b66..d952d260 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -27,35 +27,35 @@ jet_types: files: - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz - hbb: - files: - - ${data.data_dir}/jetclass_HToBB_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz - hcc: - files: - - ${data.data_dir}/jetclass_HToCC_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz - hgg: - files: - - 
${data.data_dir}/jetclass_HToGG_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz - h4q: - files: - - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz - hqql: - files: - - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz - zqq: - files: - - ${data.data_dir}/jetclass_ZToQQ_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz - wqq: - files: - - ${data.data_dir}/jetclass_WToQQ_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz - ttbarlep: - files: - - ${data.data_dir}/jetclass_TTBarLep_300_000.npz - # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz + # hbb: + # files: + # - ${data.data_dir}/jetclass_HToBB_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz + # hcc: + # files: + # - ${data.data_dir}/jetclass_HToCC_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz + # hgg: + # files: + # - ${data.data_dir}/jetclass_HToGG_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz + # h4q: + # files: + # - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz + # hqql: + # files: + # - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz + # zqq: + # files: + # - ${data.data_dir}/jetclass_ZToQQ_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz + # wqq: + # files: + # - ${data.data_dir}/jetclass_WToQQ_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz + # ttbarlep: + # files: + # 
- ${data.data_dir}/jetclass_TTBarLep_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index e8659c6c..5059d9c1 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -27,7 +27,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) + global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: From 0511e4924e47230aaa54cdc07aa73324bc7de505 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 11:28:14 +0200 Subject: [PATCH 38/50] Revert "Try again qcd and top jets with sigma=5" This reverts commit c81cbf6eda45d9ea0a14cdbc652307b8a039442d. --- configs/data/jetclass.yaml | 64 +++++++++++++-------------- configs/experiment/jetclass_cond.yaml | 2 +- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index d952d260..80dd1b66 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -27,35 +27,35 @@ jet_types: files: - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz - # hbb: - # files: - # - ${data.data_dir}/jetclass_HToBB_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz - # hcc: - # files: - # - ${data.data_dir}/jetclass_HToCC_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz - # hgg: - # files: - # - ${data.data_dir}/jetclass_HToGG_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz - # h4q: - # files: - # - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz - # # - 
/beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz - # hqql: - # files: - # - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz - # zqq: - # files: - # - ${data.data_dir}/jetclass_ZToQQ_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz - # wqq: - # files: - # - ${data.data_dir}/jetclass_WToQQ_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz - # ttbarlep: - # files: - # - ${data.data_dir}/jetclass_TTBarLep_300_000.npz - # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz + hbb: + files: + - ${data.data_dir}/jetclass_HToBB_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz + hcc: + files: + - ${data.data_dir}/jetclass_HToCC_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz + hgg: + files: + - ${data.data_dir}/jetclass_HToGG_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz + h4q: + files: + - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz + hqql: + files: + - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz + zqq: + files: + - ${data.data_dir}/jetclass_ZToQQ_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz + wqq: + files: + - ${data.data_dir}/jetclass_WToQQ_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz + ttbarlep: + files: + - ${data.data_dir}/jetclass_TTBarLep_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml 
index 5059d9c1..e8659c6c 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -27,7 +27,7 @@ trainer: model: num_particles: 128 - global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) + global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: From 336ed146228767e4f2df954436e7d99f495994ac Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 11:51:39 +0200 Subject: [PATCH 39/50] Change to eval only every 50 epochs --- configs/experiment/jetclass_cond.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/experiment/jetclass_cond.yaml b/configs/experiment/jetclass_cond.yaml index e8659c6c..386a2840 100644 --- a/configs/experiment/jetclass_cond.yaml +++ b/configs/experiment/jetclass_cond.yaml @@ -16,7 +16,7 @@ defaults: tags: ["flow_matching", "JetClass", "cond"] -run_note: "jet-type conditioning, using all 10 jet types, but with standardization sigma=5" +run_note: "jet-type conditioning, using all 10 jet types" seed: 12345 @@ -52,8 +52,8 @@ callbacks: save_ema_weights_in_callback_state: True evaluate_ema_weights_instead: True jetclass_eval: - every_n_epochs: 5 # evaluate every n epochs - additional_eval_epochs: [1, 5, 10, 30, 50, 75] # evaluate at these epochs as well + every_n_epochs: 50 # evaluate every n epochs + additional_eval_epochs: [1, 30, 75] # evaluate at these epochs as well num_jet_samples: 50_000 # jet samples to generate jetclass_eval_test: num_jet_samples: 200_000 # jet samples to generate From 31ab253f18e2302ba7fa03f07ae4915b33e78aa2 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 13:13:59 +0200 Subject: [PATCH 40/50] Update eval notebook + add latex labels for plotting the different jetclass jet types --- configs/plotting/labels.yaml | 11 ++ notebooks/30_jetclass_eval.ipynb | 232 
++++++++++++++++++++++++------- 2 files changed, 191 insertions(+), 52 deletions(-) create mode 100644 configs/plotting/labels.yaml diff --git a/configs/plotting/labels.yaml b/configs/plotting/labels.yaml new file mode 100644 index 00000000..fe535ec9 --- /dev/null +++ b/configs/plotting/labels.yaml @@ -0,0 +1,11 @@ +latex_labels: + label_QCD: QCD + label_Hbb: $H(\rightarrow b\bar{b})$ + label_Hcc: $H(\rightarrow c\bar{c})$ + label_Hgg: $H(\rightarrow gg)$ + label_H4q: $H(\rightarrow 4q)$ + label_Hqql: $H(\rightarrow l \nu qq')$ + label_Zqq: $Z(\rightarrow q\bar{q})$ + label_Wqq: $W(\rightarrow q\bar{q})$ + label_Tbqq: $t(\rightarrow b qq')$ + label_Tbl: $t(\rightarrow b l \nu)$ diff --git a/notebooks/30_jetclass_eval.ipynb b/notebooks/30_jetclass_eval.ipynb index 02e64e7d..9172f7d7 100644 --- a/notebooks/30_jetclass_eval.ipynb +++ b/notebooks/30_jetclass_eval.ipynb @@ -33,7 +33,11 @@ "# plots and metrics\n", "import matplotlib.pyplot as plt\n", "\n", - "from src.data.components import calculate_all_wasserstein_metrics\n", + "from src.data.components import (\n", + " calculate_all_wasserstein_metrics,\n", + " inverse_normalize_tensor,\n", + " normalize_tensor,\n", + ")\n", "from src.utils.data_generation import generate_data\n", "from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets\n", "\n", @@ -52,53 +56,65 @@ "# load everything from experiment config\n", "with hydra.initialize(version_base=None, config_path=\"../configs/\"):\n", " cfg = hydra.compose(config_name=\"train.yaml\", overrides=[f\"experiment={experiment}\"])\n", + " print(type(cfg))\n", " print(OmegaConf.to_yaml(cfg))\n", "\n", "datamodule = hydra.utils.instantiate(cfg.data)\n", + "# datamodule.hparams.number_of_used_jets = 1_000_000\n", "# set remove_etadiff_tails=False when checking the pT_jet distribution calculated from particle pT\n", "# datamodule.hparams.remove_etadiff_tails = False\n", "model = hydra.utils.instantiate(cfg.model)\n", - "datamodule.setup()" - ] - 
}, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "datamodule.setup()\n", + "\n", + "# ------------------------------------------------\n", + "# Some printouts about shape to check if it's what we expect\n", "test_data = np.array(datamodule.tensor_test)\n", "test_mask = np.array(datamodule.mask_test)\n", "test_cond = np.array(datamodule.tensor_conditioning_test)\n", + "test_spectator = np.array(datamodule.tensor_spectator_test)\n", "val_data = np.array(datamodule.tensor_val)\n", "val_mask = np.array(datamodule.mask_val)\n", "val_cond = np.array(datamodule.tensor_conditioning_val)\n", + "val_spectator = np.array(datamodule.tensor_spectator_val)\n", "train_data = np.array(datamodule.tensor_train)\n", "train_mask = np.array(datamodule.mask_train)\n", "train_cond = np.array(datamodule.tensor_conditioning_train)\n", + "train_spectator = np.array(datamodule.tensor_spectator_train)\n", "means = np.array(datamodule.means)\n", "stds = np.array(datamodule.stds)\n", "\n", "print(test_data.shape)\n", "print(test_mask.shape)\n", "print(test_cond.shape)\n", + "print(test_spectator.shape)\n", "print(val_data.shape)\n", "print(val_mask.shape)\n", "print(val_cond.shape)\n", + "print(val_spectator.shape)\n", "print(train_data.shape)\n", "print(train_mask.shape)\n", "print(train_cond.shape)\n", + "print(train_spectator.shape)\n", "print(means)\n", "print(stds)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "ckpt = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_flow_matching/runs/2023-07-27_23-46-46/checkpoints/last-EMA.ckpt\"\n", + "# load the model from the checkpoint\n", + "ckpt_path = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-06_22-30-00/checkpoints\"\n", + "ckpt = f\"{ckpt_path}/last-EMA.ckpt\"\n", "model = 
model.load_from_checkpoint(ckpt)" ] }, @@ -108,11 +124,14 @@ "metadata": {}, "outputs": [], "source": [ - "factor = 1\n", + "# optional: increase the size of the test data for better statistics\n", + "factor = 1 # this is the factor by which the test data is increased/repeated\n", "# chosse between test and val\n", - "mask_real = test_mask\n", - "data_real = test_data\n", - "cond_real = test_cond\n", + "stop = 100_000\n", + "mask_real = test_mask[:stop]\n", + "data_real = test_data[:stop]\n", + "cond_real = test_cond[:stop]\n", + "spectator_real = test_spectator[:stop]\n", "\n", "# increase size for better statistics\n", "big_mask_real = np.repeat(mask_real, factor, axis=0)\n", @@ -133,7 +152,7 @@ " cond=torch.tensor(big_cond_real),\n", " variable_set_sizes=True,\n", " mask=torch.tensor(big_mask_real),\n", - " normalized_data=False,\n", + " normalized_data=True,\n", " means=means,\n", " stds=stds,\n", " ode_solver=\"midpoint\",\n", @@ -147,16 +166,21 @@ "metadata": {}, "outputs": [], "source": [ + "# plot the generated features and compare sim. data to gen. data\n", "fig, ax = plt.subplots(1, 3, figsize=(12, 4))\n", "ax = ax.flatten()\n", - "hist_kwargs = dict(bins=100, alpha=0.5, density=True)\n", + "hist_kwargs = dict(density=True)\n", "for i in range(3):\n", - " ax[i].hist(data_real[:, :, i][mask_real[:, :, 0] != 0].flatten(), **hist_kwargs, label=\"real\")\n", + " values_sim = data_real[:, :, i][mask_real[:, :, 0] != 0].flatten()\n", + " values_gen = data_generated[:, :, i][mask_real[:, :, 0] != 0].flatten()\n", + " _, bin_edges = np.histogram(np.concatenate([values_sim, values_gen]), bins=100)\n", + " hist_kwargs[\"bins\"] = bin_edges\n", + " ax[i].hist(values_sim, **hist_kwargs, label=\"Sim. data\", alpha=0.5)\n", " ax[i].hist(\n", - " data_generated[:, :, i][mask_real[:, :, 0] != 0].flatten(),\n", - " **hist_kwargs,\n", - " label=\"generated\",\n", + " values_gen,\n", + " label=\"Gen. 
data\",\n", " histtype=\"step\",\n", + " **hist_kwargs,\n", " )\n", " ax[i].set_yscale(\"log\")\n", "ax[2].legend(frameon=False)\n", @@ -169,6 +193,7 @@ "metadata": {}, "outputs": [], "source": [ + "# calculate the Wasserstein distance between the simulated and generated data\n", "w_dists_big = calculate_all_wasserstein_metrics(\n", " data_real[..., :3],\n", " data_generated[..., :3],\n", @@ -191,27 +216,17 @@ "metadata": {}, "outputs": [], "source": [ - "w_dists_big_non_averaged = calculate_all_wasserstein_metrics(\n", - " data_real[..., :3],\n", - " data_generated[..., :3],\n", - " None,\n", - " None,\n", - " num_eval_samples=len(data_real),\n", - " num_batches=factor,\n", - " calculate_efps=True,\n", - " use_masks=False,\n", - ")\n", - "w_dists_big_non_averaged" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cond_real_repeat = np.repeat(cond_real[:, np.newaxis, :], mask_real.shape[1], axis=1)\n", - "cond_real_repeat.shape" + "# w_dists_big_non_averaged = calculate_all_wasserstein_metrics(\n", + "# data_real[..., :3],\n", + "# data_generated[..., :3],\n", + "# None,\n", + "# None,\n", + "# num_eval_samples=len(data_real),\n", + "# num_batches=factor,\n", + "# calculate_efps=True,\n", + "# use_masks=False,\n", + "# )\n", + "# w_dists_big_non_averaged" ] }, { @@ -230,8 +245,13 @@ "\n", "from copy import deepcopy\n", "\n", + "import cplt\n", + "\n", "from src.data.components.utils import calculate_jet_features\n", "\n", + "cplt.utils.set_mpl_colours()\n", + "# cplt.utils.reset_mpl_colours()\n", + "\n", "fig, ax = plt.subplots(1, 3, figsize=(15, 4))\n", "hist_kwargs = dict(bins=100, histtype=\"step\")\n", "\n", @@ -239,11 +259,13 @@ "particle_features = deepcopy(data_real)\n", "\n", "# re-scale particle pt with jet pt\n", - "particle_features[:, :, 2] *= cond_real_repeat[:, :, 0]\n", + "particle_features[:, :, 2] *= np.repeat(\n", + " spectator_real[:, np.newaxis, :], mask_real.shape[1], axis=1\n", + 
")[:, :, 0]\n", "\n", "# calculate jet features (both with pT_rel and pT)\n", "jet_features_rel = calculate_jet_features(data_real) # pT_rel\n", - "jet_features = calculate_jet_features(particle_features) # pT\n", + "jet_features_calculated = calculate_jet_features(particle_features) # pT\n", "\n", "# Note: the jet pt which is calculated from the constituent pt does not\n", "# yield exactly the same distribution if the etadiff tails are removed!\n", @@ -254,23 +276,129 @@ "ax[1].set_xlabel(\"$p_T^{particle}$\")\n", "ax[0].set_yscale(\"log\")\n", "ax[1].set_yscale(\"log\")\n", - "ax[2].hist(jet_features[:, 0], **hist_kwargs, label=\"Calculated from $p_T^{particle}$\")\n", + "ax[2].hist(jet_features_calculated[:, 0], **hist_kwargs, label=\"Calculated from $p_T^{particle}$\")\n", "ax[2].hist(cond_real[:, 0], **hist_kwargs, label=\"Original value\", ls=\"--\")\n", "ax[2].legend(frameon=False)\n", "ax[2].set_xlabel(\"$p_T^{jet}$\")\n", "fig.tight_layout()\n", "plt.show()\n", "\n", - "fig, ax = plt.subplots(1, 2, figsize=(15, 4))\n", - "ax[0].hist(jet_features[:, 3], **hist_kwargs, label=\"Calculated from $p_T^{particle}$\")\n", - "ax[0].set_xlabel(\"$m_{jet}$ - using $p_T^{particle}$\")\n", - "ax[1].hist(\n", - " jet_features_rel[:, 3], **hist_kwargs, label=\"Calculated from $p_T^{particle} / p_T^{jet}$\"\n", - ")\n", - "ax[1].set_xlabel(\"$m_{jet}$ - using $p_T^{particle} / p_T^{jet}$\")\n", + "fig, ax = plt.subplots(1, 2, figsize=(15, 6))\n", + "hist_kwargs = dict(histtype=\"step\", density=True, linewidth=2)\n", + "\n", + "import yaml\n", + "\n", + "# load labels from labels.yaml\n", + "with open(\"../configs/plotting/labels.yaml\", \"r\") as f:\n", + " labels = yaml.load(f, Loader=yaml.SafeLoader)\n", + " latex_labels = labels[\"latex_labels\"]\n", + " print(latex_labels)\n", + "\n", + "\n", + "for i, jet_type in enumerate(datamodule.names_labels):\n", + " # print(jet_type)\n", + " mask = cond_real[:, i] == 1\n", + " # hist_kwargs[\"bins\"] = 10\n", + " # 
print(mask.shape)\n", + " hist_kwargs[\"linestyle\"] = (\n", + " \"solid\"\n", + " if i < len(cplt.utils.get_good_colours())\n", + " else cplt.utils.get_good_linestyles(\"densely dotted\")\n", + " )\n", + " ax[0].hist(\n", + " jet_features_calculated[:, 3][mask],\n", + " label=latex_labels[jet_type],\n", + " bins=np.linspace(0, 300, 60),\n", + " **hist_kwargs,\n", + " )\n", + " ax[0].set_xlabel(\"$m_{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{particle}$)\")\n", + " ax[1].hist(\n", + " jet_features_rel[:, 3][mask],\n", + " label=latex_labels[jet_type],\n", + " bins=np.linspace(0, 0.6, 60),\n", + " **hist_kwargs,\n", + " )\n", + " ax[1].set_xlabel(\n", + " \"$m_{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{particle} / p_\\\\mathrm{T}^\\\\mathrm{jet}$)\"\n", + " )\n", + "ax[0].legend(frameon=False)\n", "fig.tight_layout()\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot the (relative) jet mass for each jet type individually and compare between\n", + "# generated and real jets\n", + "\n", + "# calculate jet features\n", + "jet_features_real = calculate_jet_features(data_real)\n", + "jet_features_generated = calculate_jet_features(data_generated)\n", + "\n", + "# plot the jet mass for each jet type\n", + "fig, ax = plt.subplots(10, 5, figsize=(18, 30))\n", + "hist_kwargs = dict(bins=100, density=True)\n", + "# ax= ax.flatten()\n", + "\n", + "for i, jet_type in enumerate(datamodule.names_labels):\n", + " # print(jet_type)\n", + " # if i> 0:\n", + " # break\n", + " mask = cond_real[:, i] == 1\n", + " mask_particle_level = np.repeat(\n", + " mask[:, np.newaxis, np.newaxis], data_real.shape[1], axis=1\n", + " ) & (mask_real != 0)\n", + " # print(mask.shape)\n", + " # print(mask_particle_level.shape)\n", + " # hist_kwargs[\"bins\"] = 10\n", + " ax[i, 0].set_title(jet_type)\n", + " # eta_rel\n", + " for j in range(3):\n", + " _, bin_edges, _ = ax[i, j].hist(\n", + " data_real[:, :, 
j][mask_particle_level[:, :, 0]].flatten(),\n", + " **hist_kwargs,\n", + " label=\"Sim. data\",\n", + " histtype=\"stepfilled\",\n", + " alpha=0.5,\n", + " )\n", + " ax[i, j].hist(\n", + " data_generated[:, :, j][mask_particle_level[:, :, 0]].flatten(),\n", + " bins=bin_edges,\n", + " density=True,\n", + " label=\"Gen. data\",\n", + " histtype=\"step\",\n", + " )\n", + " ax[i, j].set_yscale(\"log\")\n", + " ax[i, 0].legend(frameon=False)\n", + " ax[i, 0].set_xlabel(\"$\\\\eta_\\\\mathrm{rel}$\")\n", + " ax[i, 1].set_xlabel(\"$\\\\phi_\\\\mathrm{rel}$\")\n", + " ax[i, 2].set_xlabel(\"$p_\\\\mathrm{T}^\\\\mathrm{rel}$\")\n", + " # jet mass\n", + " ax[i, 3].hist(\n", + " jet_features_real[:, 3][mask],\n", + " **hist_kwargs,\n", + " label=\"Sim. data\",\n", + " histtype=\"stepfilled\",\n", + " alpha=0.5,\n", + " )\n", + " ax[i, 3].hist(\n", + " jet_features_generated[:, 3][mask], **hist_kwargs, label=\"Gen. data\", histtype=\"step\"\n", + " )\n", + " ax[i, 3].set_xlabel(\"$m_\\\\mathrm{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{rel}$)\")\n", + "fig.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From a8b262e9fdcb678a380ce5c7afeb5f4ec21f964c Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 14:09:21 +0200 Subject: [PATCH 41/50] Saving the training config as yaml to the run directory --> improved reproducibility --- src/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/train.py b/src/train.py index 49a66687..72fcb5e1 100644 --- a/src/train.py +++ b/src/train.py @@ -4,7 +4,7 @@ import pyrootutils import pytorch_lightning as pl import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer from pytorch_lightning.loggers import Logger @@ -50,6 +50,11 @@ def train(cfg: DictConfig) -> Tuple[dict, 
dict]: if cfg.get("seed"): pl.seed_everything(cfg.seed, workers=True) + # save config for reproducibility and debugging + cfg_backup_file = f'{cfg.trainer.get("default_root_dir")}/config.yaml' + with open(cfg_backup_file, "w") as f: + OmegaConf.save(cfg, f) + log.info(f"Instantiating datamodule <{cfg.data._target_}>") datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) From 226941c1fa6f35fbe8a66106fed07d09cc554ceb Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 14:22:14 +0200 Subject: [PATCH 42/50] Adapt notebook to new syntax that loads the run config + some refinements --- notebooks/30_jetclass_eval.ipynb | 80 +++++++++++++++----------------- 1 file changed, 37 insertions(+), 43 deletions(-) diff --git a/notebooks/30_jetclass_eval.ipynb b/notebooks/30_jetclass_eval.ipynb index 9172f7d7..6a9b6388 100644 --- a/notebooks/30_jetclass_eval.ipynb +++ b/notebooks/30_jetclass_eval.ipynb @@ -50,20 +50,39 @@ "metadata": {}, "outputs": [], "source": [ + "# specify here the path to the run directory of the model you want to evaluate\n", + "run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-06_22-30-00\"\n", + "# run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_flow_matching_dev/runs/2023-08-07_13-24-30\"\n", + "cfg_backup_file = f\"{run_dir}/config.yaml\"\n", + "\n", + "# -----------------------------------------------------------\n", + "# for backward compatibility: if the run directory has no config.yaml yet, compose the config from the experiment config and save it to the run directory\n", "experiment = \"jetclass_cond.yaml\"\n", "model_name_for_saving = \"nb_fm_tops_jetclass\"\n", - "\n", "# load everything from experiment config\n", "with hydra.initialize(version_base=None, config_path=\"../configs/\"):\n", - " cfg = hydra.compose(config_name=\"train.yaml\", overrides=[f\"experiment={experiment}\"])\n", - " print(type(cfg))\n", - " print(OmegaConf.to_yaml(cfg))\n", + " if os.path.exists(cfg_backup_file):\n", + " print(\"config file 
already exists --> loading from run directory\")\n", + " else:\n", + " cfg = hydra.compose(config_name=\"train.yaml\", overrides=[f\"experiment={experiment}\"])\n", + " print(f\"saving config file as {cfg_backup_file}\")\n", + " with open(cfg_backup_file, \"w\") as f:\n", + " OmegaConf.save(cfg, f)\n", + "# -----------------------------------------------------------\n", + "\n", + "# load everything from run directory (safer in terms of reproducing results)\n", + "cfg = OmegaConf.load(cfg_backup_file)\n", + "print(type(cfg))\n", + "print(OmegaConf.to_yaml(cfg))\n", "\n", "datamodule = hydra.utils.instantiate(cfg.data)\n", "# datamodule.hparams.number_of_used_jets = 1_000_000\n", "# set remove_etadiff_tails=False when checking the pT_jet distribution calculated from particle pT\n", "# datamodule.hparams.remove_etadiff_tails = False\n", "model = hydra.utils.instantiate(cfg.model)\n", + "# load the model from the checkpoint\n", + "ckpt = f\"{run_dir}/checkpoints/last-EMA.ckpt\"\n", + "model = model.load_from_checkpoint(ckpt)\n", "datamodule.setup()\n", "\n", "# ------------------------------------------------\n", @@ -99,25 +118,6 @@ "print(stds)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# load the model from the checkpoint\n", - "ckpt_path = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-06_22-30-00/checkpoints\"\n", - "ckpt = f\"{ckpt_path}/last-EMA.ckpt\"\n", - "model = model.load_from_checkpoint(ckpt)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -125,29 +125,23 @@ "outputs": [], "source": [ "# optional: increase the size of the test data for better statistics\n", - "factor = 1 # this is the factor by which the test data is increased/repeated\n", - "# chosse between test and val\n", - "stop = 100_000\n", - "mask_real = test_mask[:stop]\n", - 
"data_real = test_data[:stop]\n", - "cond_real = test_cond[:stop]\n", - "spectator_real = test_spectator[:stop]\n", + "FACTOR_REPEAT_MASK_COND = 1 # this is the factor by which the test data is increased/repeated\n", + "NUMER_OF_GENERATED_JETS = 1_000\n", + "\n", + "# choose between test and val\n", + "mask_real = test_mask[:NUMER_OF_GENERATED_JETS]\n", + "data_real = test_data[:NUMER_OF_GENERATED_JETS]\n", + "cond_real = test_cond[:NUMER_OF_GENERATED_JETS]\n", + "spectator_real = test_spectator[:NUMER_OF_GENERATED_JETS]\n", "\n", "# increase size for better statistics\n", - "big_mask_real = np.repeat(mask_real, factor, axis=0)\n", - "big_data_real = np.repeat(data_real, factor, axis=0)\n", - "big_cond_real = np.repeat(cond_real, factor, axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "big_mask_real = np.repeat(mask_real, FACTOR_REPEAT_MASK_COND, axis=0)\n", + "big_data_real = np.repeat(data_real, FACTOR_REPEAT_MASK_COND, axis=0)\n", + "big_cond_real = np.repeat(cond_real, FACTOR_REPEAT_MASK_COND, axis=0)\n", + "\n", "data_generated, generation_time = generate_data(\n", " model,\n", - " num_jet_samples=factor * len(mask_real),\n", + " num_jet_samples=FACTOR_REPEAT_MASK_COND * len(mask_real),\n", " batch_size=1000,\n", " cond=torch.tensor(big_cond_real),\n", " variable_set_sizes=True,\n", @@ -200,7 +194,7 @@ " None,\n", " None,\n", " num_eval_samples=len(data_real),\n", - " num_batches=factor,\n", + " num_batches=FACTOR_REPEAT_MASK_COND,\n", " calculate_efps=True,\n", " use_masks=False,\n", ")\n", From 2c962163f3cd9aa6c948b389f02a0d64a4a3e479 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 14:50:38 +0200 Subject: [PATCH 43/50] Add the actual jet type names to the conditioning names array --- src/data/jetclass_datamodule.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 
97cbcf06..83e932e8 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -353,7 +353,7 @@ def setup(self, stage: Optional[str] = None): self.names_conditioning = None else: conditioning_features, self.names_conditioning = self._handle_conditioning( - jet_features, names_jet_features, labels + jet_features, names_jet_features, labels, names_labels ) (conditioning_train, conditioning_val, conditioning_test) = np.split( conditioning_features, @@ -526,7 +526,13 @@ def load_state_dict(self, state_dict: Dict[str, Any]): """Things to do when loading checkpoint.""" pass - def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, labels: np.array): + def _handle_conditioning( + self, + jet_data: np.array, + names_jet_data: np.array, + labels: np.array, + names_labels: np.array, + ): """Select the conditioning variables. Args: @@ -534,6 +540,9 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, lab names_jet_data: np.array of shape (n_features,) which contains the names of the features labels: np.array of shape (n_jets,) which contains the labels / jet-types + names_labels: np.array of shape (n_jet_types,) which contains the names of + the jet-types (e.g. 
if there are three jet types: ['q', 'g', 't'], then + a label 0 would correspond to 'q', 1 to 'g' and 2 to 't') Returns: conditioning_data: np.array of shape (n_jets, n_conditioning_features) names_conditioning_data: np.array of shape (n_conditioning_features,) which @@ -561,7 +570,7 @@ def _handle_conditioning(self, jet_data: np.array, names_jet_data: np.array, lab if self.hparams.conditioning_jet_type: keep_col += list(np.arange(one_hot_len)) - names_conditioning_data += [f"jet_type_{i:.0f}" for i in categories] + names_conditioning_data += [f"jet_type_{names_labels[int(i)]}" for i in categories] if self.hparams.conditioning_pt: keep_col.append(get_feat_index(names_jet_data, "jet_pt") + one_hot_len - 1) names_conditioning_data.append("jet_pt") From f6bf4fb77de10b7f837f104938e742006f44850d Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 15:01:59 +0200 Subject: [PATCH 44/50] Adapting notebook to new jet type naming array convention --- notebooks/30_jetclass_eval.ipynb | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/notebooks/30_jetclass_eval.ipynb b/notebooks/30_jetclass_eval.ipynb index 6a9b6388..25d04df8 100644 --- a/notebooks/30_jetclass_eval.ipynb +++ b/notebooks/30_jetclass_eval.ipynb @@ -51,7 +51,7 @@ "outputs": [], "source": [ "# specify here the path to the run directory of the model you want to evaluate\n", - "run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-06_22-30-00\"\n", + "run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-07_11-56-01\"\n", "# run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_flow_matching_dev/runs/2023-08-07_13-24-30\"\n", "cfg_backup_file = f\"{run_dir}/config.yaml\"\n", "\n", @@ -223,6 +223,16 @@ "# w_dists_big_non_averaged" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s = \"blah_label_Hbb\"\n", + 
"print(s.split(\"blah_\"))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -289,11 +299,12 @@ " print(latex_labels)\n", "\n", "\n", - "for i, jet_type in enumerate(datamodule.names_labels):\n", + "for i, conditioning_variable in enumerate(datamodule.names_conditioning):\n", " # print(jet_type)\n", + " if \"jet_type\" not in conditioning_variable:\n", + " continue\n", " mask = cond_real[:, i] == 1\n", - " # hist_kwargs[\"bins\"] = 10\n", - " # print(mask.shape)\n", + " jet_type = conditioning_variable.split(\"jet_type_\")[-1]\n", " hist_kwargs[\"linestyle\"] = (\n", " \"solid\"\n", " if i < len(cplt.utils.get_good_colours())\n", @@ -338,7 +349,10 @@ "hist_kwargs = dict(bins=100, density=True)\n", "# ax= ax.flatten()\n", "\n", - "for i, jet_type in enumerate(datamodule.names_labels):\n", + "for i, conditioning_variable in enumerate(datamodule.names_conditioning):\n", + " if \"jet_type\" not in conditioning_variable:\n", + " continue\n", + " jet_type = conditioning_variable.split(\"jet_type_\")[-1]\n", " # print(jet_type)\n", " # if i> 0:\n", " # break\n", From efc2c0bcf57d7c7506de0b94062aa881346ec9f4 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 15:06:46 +0200 Subject: [PATCH 45/50] Add jetclass_dev.yaml for data stuff --- configs/data/jetclass_dev.yaml | 61 ++++++++++++++++++++++++++++ configs/experiment/jetclass_dev.yaml | 8 ++-- 2 files changed, 65 insertions(+), 4 deletions(-) create mode 100644 configs/data/jetclass_dev.yaml diff --git a/configs/data/jetclass_dev.yaml b/configs/data/jetclass_dev.yaml new file mode 100644 index 00000000..28564edc --- /dev/null +++ b/configs/data/jetclass_dev.yaml @@ -0,0 +1,61 @@ +_target_: src.data.jetclass_datamodule.JetClassDataModule + +# overall parameters +batch_size: 1024 +num_workers: 32 +pin_memory: False +val_fraction: 0.10 +test_fraction: 0.30 + +# preprocessing +normalize: True +normalize_sigma: 5 + +spectator_jet_features: + - jet_pt + + +# files and jet types to use +data_dir: 
/beegfs/desy/user/birkjosc/datasets/jetclass_npz + +jet_types: + ttbar: + files: + - ${data.data_dir}/jetclass_TTBar_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz + qcd: + files: + - ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz + # hbb: + # files: + # - ${data.data_dir}/jetclass_HToBB_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz + # hcc: + # files: + # - ${data.data_dir}/jetclass_HToCC_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz + # hgg: + # files: + # - ${data.data_dir}/jetclass_HToGG_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz + # h4q: + # files: + # - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz + # hqql: + # files: + # - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz + # zqq: + # files: + # - ${data.data_dir}/jetclass_ZToQQ_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz + # wqq: + # files: + # - ${data.data_dir}/jetclass_WToQQ_300_000.npz + # # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz + # ttbarlep: + # files: + # - ${data.data_dir}/jetclass_TTBarLep_300_000.npz + # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz diff --git a/configs/experiment/jetclass_dev.yaml b/configs/experiment/jetclass_dev.yaml index 3afd565f..0c1aa1c0 100644 --- a/configs/experiment/jetclass_dev.yaml +++ b/configs/experiment/jetclass_dev.yaml @@ -7,7 +7,7 @@ # python train.py experiment=jetclass defaults: - - override /data: jetclass.yaml + - override /data: jetclass_dev.yaml - override /model: 
flow_matching.yaml - override /callbacks: jetclass.yaml - override /trainer: gpu.yaml @@ -30,12 +30,12 @@ trainer: model: num_particles: 128 - global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) + global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables) local_cond_dim: 0 data: # preprocessing - number_of_used_jets: 300000 + number_of_used_jets: 30_000 use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet remove_etadiff_tails: True # remove tracks with | eta_rel | > 1 normalize: True @@ -59,7 +59,7 @@ callbacks: additional_eval_epochs: [] # evaluate at these epochs as well num_jet_samples: 100 # jet samples to generate jetclass_eval_test: - num_jet_samples: 6000 # jet samples to generate + num_jet_samples: 100 # jet samples to generate #early_stopping: # monitor: "val/loss" From 47fa019e2d72dc0fc68609606261f24e4998ec40 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 15:29:16 +0200 Subject: [PATCH 46/50] Add more logging statements --- src/data/jetclass_datamodule.py | 41 ++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/data/jetclass_datamodule.py b/src/data/jetclass_datamodule.py index 83e932e8..932e3dc2 100644 --- a/src/data/jetclass_datamodule.py +++ b/src/data/jetclass_datamodule.py @@ -225,19 +225,34 @@ def setup(self, stage: Optional[str] = None): pylogger.info("Loaded data.") pylogger.info("Shapes of arrays as available in files:") - pylogger.info(f"particle_features.shape = {particle_features.shape}") + pylogger.info(f"particle_features names = {names_particle_features}") + pylogger.info(f"particle_features shape = {particle_features.shape}") + pylogger.info(f"jet_features names = {names_jet_features}") pylogger.info(f"jet_features.shape = {jet_features.shape}") + pylogger.info(f"labels names = {names_labels}") 
pylogger.info(f"labels.shape = {labels.shape}") + pylogger.info("Now processing data...") if self.hparams.number_of_used_jets is not None: if self.hparams.number_of_used_jets < len(jet_features): + pylogger.info( + f"Using only {self.hparams.number_of_used_jets} jets " + f"out of {len(jet_features)}." + ) particle_features = particle_features[: self.hparams.number_of_used_jets] jet_features = jet_features[: self.hparams.number_of_used_jets] labels = labels[: self.hparams.number_of_used_jets] + else: + pylogger.warning( + f"More jets requested ({self.hparams.number_of_used_jets:_}) than " + f"available ({len(jet_features):_})." + "--> Using all available jets." + ) # NOTE: everything below here assumes that the particle features # array after preprocessing stores the features [eta_rel, phi_rel, pt_rel] + pylogger.info("Using eta_rel, phi_rel, pt_rel as particle features.") # check if the particle features are in the correct order index_part_deta = get_feat_index(names_particle_features, "part_deta") assert index_part_deta == 0, "part_deta is not the first feature" @@ -254,6 +269,7 @@ def setup(self, stage: Optional[str] = None): # instead of using the part_deta variable, use part_eta - jet_eta if self.hparams.use_custom_eta_centering: + pylogger.info("Using custom eta centering -> calculating particle_eta - jet_eta") if "part_eta" not in names_particle_features: raise ValueError( "`use_custom_eta_centering` is True, but `part_eta` is not in " @@ -276,6 +292,7 @@ def setup(self, stage: Optional[str] = None): particle_features[:, :, 0] = particle_eta_minus_jet_eta * mask if self.hparams.remove_etadiff_tails: + pylogger.info("Removing eta tails -> removing particles with |eta_rel| > 1") # remove/zero-padd particles with |eta - jet_eta| > 1 mask_etadiff_larger_1 = np.abs(particle_features[:, :, 0]) > 1 particle_features[:, :, :][mask_etadiff_larger_1] = 0 @@ -298,6 +315,7 @@ def setup(self, stage: Optional[str] = None): particle_mask_zero_entries, 
repeats=particle_features.shape[2], axis=2 ), ) + pylogger.info("Checking that there are no jets without any constituents.") n_jets_without_particles = np.sum(np.sum(~particle_mask_zero_entries, axis=1) == 0) if n_jets_without_particles > 0: raise NotImplementedError( @@ -395,6 +413,7 @@ def setup(self, stage: Optional[str] = None): ) if self.hparams.normalize: + pylogger.info("Standardizing the particle features.") # calculate means and stds only based on the training data self.means = np.ma.mean(dataset_train, axis=(0, 1)) self.stds = np.ma.std(dataset_train, axis=(0, 1)) @@ -452,6 +471,7 @@ def setup(self, stage: Optional[str] = None): self.tensor_conditioning_val_dl = self.tensor_conditioning_val self.tensor_conditioning_test_dl = self.tensor_conditioning_test + pylogger.info("Checking for NaNs in the data.") # check if particle data contains nan values if ( torch.isnan(self.tensor_train_dl).any() @@ -468,6 +488,25 @@ def setup(self, stage: Optional[str] = None): ): raise ValueError("NaNs found in conditioning data!") + pylogger.info("--- Done setting up the dataloader. 
---") + pylogger.info("Particle features: eta_rel, phi_rel, pT_rel") + pylogger.info("Conditioning features: %s", self.names_conditioning) + + pylogger.info("--- Shape of the training data: ---") + pylogger.info("particle features: %s", self.tensor_train_dl.shape) + pylogger.info("mask: %s", self.mask_train.shape) + pylogger.info("conditioning features: %s", self.tensor_conditioning_train_dl.shape) + + pylogger.info("--- Shape of the validation data: ---") + pylogger.info("particle features: %s", self.tensor_val_dl.shape) + pylogger.info("mask: %s", self.mask_val.shape) + pylogger.info("conditioning features: %s", self.tensor_conditioning_val_dl.shape) + + pylogger.info("--- Shape of the test data: ---") + pylogger.info("particle features: %s", self.tensor_test_dl.shape) + pylogger.info("mask: %s", self.mask_test.shape) + pylogger.info("conditioning features: %s", self.tensor_conditioning_test_dl.shape) + self.data_train = TensorDataset( self.tensor_train_dl, self.mask_train, From fe214d1c64003cc5cf0c90a00f0b0731056d69c1 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 15:39:21 +0200 Subject: [PATCH 47/50] Turn on logging in jupyter notebook --- notebooks/30_jetclass_eval.ipynb | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/notebooks/30_jetclass_eval.ipynb b/notebooks/30_jetclass_eval.ipynb index 25d04df8..01499e7f 100644 --- a/notebooks/30_jetclass_eval.ipynb +++ b/notebooks/30_jetclass_eval.ipynb @@ -30,6 +30,8 @@ "load_dotenv()\n", "os.environ[\"DATA_DIR\"] = os.environ.get(\"DATA_DIR\")\n", "\n", + "import logging\n", + "\n", "# plots and metrics\n", "import matplotlib.pyplot as plt\n", "\n", @@ -41,6 +43,11 @@ "from src.utils.data_generation import generate_data\n", "from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets\n", "\n", + "# set up logging for jupyter notebook\n", + "logger = logging.getLogger()\n", + "logger.setLevel(logging.DEBUG)\n", + "logging.debug(\"test\")\n", + "\n", 
"apply_mpl_styles()" ] }, From 9154ded5690122beaf01b1c8cbfe2efe65e04e64 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 16:15:56 +0200 Subject: [PATCH 48/50] Move imports and add notes about relative jet mass --- notebooks/30_jetclass_eval.ipynb | 44 +++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/notebooks/30_jetclass_eval.ipynb b/notebooks/30_jetclass_eval.ipynb index 01499e7f..34744d29 100644 --- a/notebooks/30_jetclass_eval.ipynb +++ b/notebooks/30_jetclass_eval.ipynb @@ -31,6 +31,9 @@ "os.environ[\"DATA_DIR\"] = os.environ.get(\"DATA_DIR\")\n", "\n", "import logging\n", + "from copy import deepcopy\n", + "\n", + "import cplt\n", "\n", "# plots and metrics\n", "import matplotlib.pyplot as plt\n", @@ -40,6 +43,7 @@ " inverse_normalize_tensor,\n", " normalize_tensor,\n", ")\n", + "from src.data.components.utils import calculate_jet_features\n", "from src.utils.data_generation import generate_data\n", "from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets\n", "\n", @@ -254,11 +258,6 @@ "# - jet mass calculated from rescaled pT_particle and eta_rel, phi_rel\n", "# - jet mass calculated from pT_rel, eta_rel, phi_rel\n", "\n", - "from copy import deepcopy\n", - "\n", - "import cplt\n", - "\n", - "from src.data.components.utils import calculate_jet_features\n", "\n", "cplt.utils.set_mpl_colours()\n", "# cplt.utils.reset_mpl_colours()\n", @@ -413,7 +412,40 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# investigate the relative jet mass a bit:\n", + "# one important note: the relative jet mass only depends on the direction and relative\n", + "# momentum of the constituents, not on their absolute momentum.\n", + "# this means that two jets with different momenta, but the same (in terms of direction\n", + "# and relative pT) constituents, will have the same relative jet mass.\n", + "\n", + "# jet constituent coordinates: 
(eta_rel, phi_rel, pt_rel)\n", + "jet_constituents_artificial = np.array(\n", + " [\n", + " [\n", + " [-1, 0, 0.5],\n", + " [1, 0, 0.5],\n", + " ],\n", + " [\n", + " [-0.5, 0, 0.5],\n", + " [0.5, 0, 0.5],\n", + " ],\n", + " ]\n", + ")\n", + "jet_features_artificial = calculate_jet_features(jet_constituents_artificial)\n", + "print(jet_features_artificial)\n", + "\n", + "jet_constituents_artificial60 = deepcopy(jet_constituents_artificial)\n", + "jet_constituents_artificial60[:, :, 2] *= 60\n", + "jet_features_artificial60 = calculate_jet_features(jet_constituents_artificial60)\n", + "print(jet_features_artificial60)\n", + "\n", + "jet_constituents_artificial100 = deepcopy(jet_constituents_artificial)\n", + "jet_constituents_artificial100[:, :, 2] *= 100\n", + "jet_features_artificial100 = calculate_jet_features(jet_constituents_artificial100)\n", + "\n", + "print(jet_features_artificial100)" + ] } ], "metadata": { From 16fde64ec9102a74afeb8352325c1da930036566 Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: Mon, 7 Aug 2023 16:40:57 +0200 Subject: [PATCH 49/50] Fix typo --- configs/data/jetclass.yaml | 2 +- configs/data/jetclass_dev.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/data/jetclass.yaml b/configs/data/jetclass.yaml index 80dd1b66..751320de 100644 --- a/configs/data/jetclass.yaml +++ b/configs/data/jetclass.yaml @@ -1,6 +1,6 @@ _target_: src.data.jetclass_datamodule.JetClassDataModule -# overall parameters +# general parameters batch_size: 1024 num_workers: 32 pin_memory: False diff --git a/configs/data/jetclass_dev.yaml b/configs/data/jetclass_dev.yaml index 28564edc..4e2893fa 100644 --- a/configs/data/jetclass_dev.yaml +++ b/configs/data/jetclass_dev.yaml @@ -1,6 +1,6 @@ _target_: src.data.jetclass_datamodule.JetClassDataModule -# overall parameters +# general parameters batch_size: 1024 num_workers: 32 pin_memory: False From bfdf2ef76f00a3b3a838d81e4839291ce59608dd Mon Sep 17 00:00:00 2001 From: Joschka Birk Date: 
Mon, 7 Aug 2023 17:50:15 +0200 Subject: [PATCH 50/50] Change logging level in notebook --- notebooks/30_jetclass_eval.ipynb | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/notebooks/30_jetclass_eval.ipynb b/notebooks/30_jetclass_eval.ipynb index 34744d29..6ebfcbd3 100644 --- a/notebooks/30_jetclass_eval.ipynb +++ b/notebooks/30_jetclass_eval.ipynb @@ -49,8 +49,8 @@ "\n", "# set up logging for jupyter notebook\n", "logger = logging.getLogger()\n", - "logger.setLevel(logging.DEBUG)\n", - "logging.debug(\"test\")\n", + "logger.setLevel(logging.INFO)\n", + "logging.info(\"test\")\n", "\n", "apply_mpl_styles()" ] @@ -62,8 +62,8 @@ "outputs": [], "source": [ "# specify here the path to the run directory of the model you want to evaluate\n", - "run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-07_11-56-01\"\n", - "# run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_flow_matching_dev/runs/2023-08-07_13-24-30\"\n", + "# run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-07_11-56-01\"\n", + "run_dir = \"/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-06_22-30-00\"\n", "cfg_backup_file = f\"{run_dir}/config.yaml\"\n", "\n", "# -----------------------------------------------------------\n", @@ -137,7 +137,7 @@ "source": [ "# optional: increase the size of the test data for better statistics\n", "FACTOR_REPEAT_MASK_COND = 1 # this is the factor by which the test data is increased/repeated\n", - "NUMER_OF_GENERATED_JETS = 1_000\n", + "NUMER_OF_GENERATED_JETS = 300_000\n", "\n", "# choose between test and val\n", "mask_real = test_mask[:NUMER_OF_GENERATED_JETS]\n", @@ -165,6 +165,17 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# np.save(f\"{run_dir}/data_generated_from_notebook.npy\", data_generated)\n", + "# array_loaded = 
np.load(f\"{run_dir}/data_generated_from_notebook.npy\")\n", + "# print(array_loaded.shape)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -234,16 +245,6 @@ "# w_dists_big_non_averaged" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "s = \"blah_label_Hbb\"\n", - "print(s.split(\"blah_\"))" - ] - }, { "cell_type": "code", "execution_count": null,