Skip to content

Commit

Permalink
Merge pull request #22 from joschkabirk/jetclass-dataloader-update
Browse files Browse the repository at this point in the history
Jetclass dataloader update
  • Loading branch information
joschkabirk authored Aug 10, 2023
2 parents 2c37571 + 8ca8c3a commit 76da861
Show file tree
Hide file tree
Showing 9 changed files with 466 additions and 472 deletions.
55 changes: 10 additions & 45 deletions configs/data/jetclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,16 @@ test_fraction: 0.30
normalize: True
normalize_sigma: 5

spectator_jet_features:
- jet_pt
# spectator_jet_features:
# - jet_pt

# select jet types to use
# list of the following: QCD, Hbb, Hcc, Hgg, H4q, Hqql, Zqq, Wqq, Tbqq, Tbl
used_jet_types: null # null means all jet types

# files and jet types to use
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz

jet_types:
ttbar:
files:
- ${data.data_dir}/jetclass_TTBar_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz
qcd:
files:
- ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz
hbb:
files:
- ${data.data_dir}/jetclass_HToBB_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz
hcc:
files:
- ${data.data_dir}/jetclass_HToCC_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz
hgg:
files:
- ${data.data_dir}/jetclass_HToGG_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz
h4q:
files:
- ${data.data_dir}/jetclass_HToWW4Q_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz
hqql:
files:
- ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz
zqq:
files:
- ${data.data_dir}/jetclass_ZToQQ_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz
wqq:
files:
- ${data.data_dir}/jetclass_WToQQ_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz
ttbarlep:
files:
- ${data.data_dir}/jetclass_TTBarLep_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_h5
filename_dict:
train: ${data.data_dir}/train_100M/merged_standardized.h5
val: ${data.data_dir}/val_5M/merged_standardized.h5
test: ${data.data_dir}/test_20M/merged_standardized.h5
55 changes: 10 additions & 45 deletions configs/data/jetclass_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,51 +11,16 @@ test_fraction: 0.30
normalize: True
normalize_sigma: 5

spectator_jet_features:
- jet_pt
# spectator_jet_features:
# - jet_pt

# select jet types to use
# list of the following: QCD, Hbb, Hcc, Hgg, H4q, Hqql, Zqq, Wqq, Tbqq, Tbl
used_jet_types: null # null means all jet types

# files and jet types to use
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz

jet_types:
ttbar:
files:
- ${data.data_dir}/jetclass_TTBar_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz
qcd:
files:
- ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz
# hbb:
# files:
# - ${data.data_dir}/jetclass_HToBB_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz
# hcc:
# files:
# - ${data.data_dir}/jetclass_HToCC_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz
# hgg:
# files:
# - ${data.data_dir}/jetclass_HToGG_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz
# h4q:
# files:
# - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz
# hqql:
# files:
# - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz
# zqq:
# files:
# - ${data.data_dir}/jetclass_ZToQQ_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz
# wqq:
# files:
# - ${data.data_dir}/jetclass_WToQQ_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz
# ttbarlep:
# files:
# - ${data.data_dir}/jetclass_TTBarLep_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_h5
filename_dict:
train: ${data.data_dir}/train_100M/merged_standardized.h5
val: ${data.data_dir}/val_5M/merged_standardized.h5
test: ${data.data_dir}/test_20M/merged_standardized.h5
2 changes: 0 additions & 2 deletions configs/experiment/jetclass_cond.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ model:
data:
# preprocessing
number_of_used_jets: 3_000_000
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
normalize: True
normalize_sigma: 5
# conditioning
Expand Down
4 changes: 1 addition & 3 deletions configs/experiment/jetclass_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ model:
data:
# preprocessing
number_of_used_jets: 30_000
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
normalize: True
normalize_sigma: 1
# conditioning
Expand All @@ -59,7 +57,7 @@ callbacks:
additional_eval_epochs: [] # evaluate at these epochs as well
num_jet_samples: 100 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 100 # jet samples to generate
num_jet_samples: 1_000 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
Expand Down
153 changes: 135 additions & 18 deletions notebooks/30_jetclass_eval.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
"source": [
"# optional: increase the size of the test data for better statistics\n",
"FACTOR_REPEAT_MASK_COND = 1 # this is the factor by which the test data is increased/repeated\n",
"NUMER_OF_GENERATED_JETS = 300_000\n",
"NUMER_OF_GENERATED_JETS = 200_000\n",
"\n",
"# choose between test and val\n",
"mask_real = test_mask[:NUMER_OF_GENERATED_JETS]\n",
Expand All @@ -148,8 +148,15 @@
"# increase size for better statistics\n",
"big_mask_real = np.repeat(mask_real, FACTOR_REPEAT_MASK_COND, axis=0)\n",
"big_data_real = np.repeat(data_real, FACTOR_REPEAT_MASK_COND, axis=0)\n",
"big_cond_real = np.repeat(cond_real, FACTOR_REPEAT_MASK_COND, axis=0)\n",
"\n",
"big_cond_real = np.repeat(cond_real, FACTOR_REPEAT_MASK_COND, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data_generated, generation_time = generate_data(\n",
" model,\n",
" num_jet_samples=FACTOR_REPEAT_MASK_COND * len(mask_real),\n",
Expand All @@ -171,7 +178,18 @@
"metadata": {},
"outputs": [],
"source": [
"# np.save(f\"{run_dir}/data_generated_from_notebook.npy\", data_generated)\n",
"!ls -l \"{run_dir}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from datetime import datetime\n",
"# now = datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
"# np.save(f\"{run_dir}/data_generated_from_notebook_{len(data_generated)}_{now}.npy\", data_generated)\n",
"# array_loaded = np.load(f\"{run_dir}/data_generated_from_notebook.npy\")\n",
"# print(array_loaded.shape)"
]
Expand Down Expand Up @@ -256,8 +274,6 @@
"# - pT_particle when rescaled with jet pT\n",
"# - pT_jet when calculated from constituents\n",
"#\n",
"# - jet mass calculated from rescaled pT_particle and eta_rel, phi_rel\n",
"# - jet mass calculated from pT_rel, eta_rel, phi_rel\n",
"\n",
"\n",
"cplt.utils.set_mpl_colours()\n",
Expand Down Expand Up @@ -294,7 +310,10 @@
"fig.tight_layout()\n",
"plt.show()\n",
"\n",
"fig, ax = plt.subplots(1, 2, figsize=(15, 6))\n",
"# - jet mass calculated from rescaled pT_particle and eta_rel, phi_rel\n",
"# - jet mass calculated from pT_rel, eta_rel, phi_rel\n",
"\n",
"fig, ax = plt.subplots(1, 2, figsize=(13, 5))\n",
"hist_kwargs = dict(histtype=\"step\", density=True, linewidth=2)\n",
"\n",
"import yaml\n",
Expand Down Expand Up @@ -323,18 +342,79 @@
" bins=np.linspace(0, 300, 60),\n",
" **hist_kwargs,\n",
" )\n",
" ax[0].set_xlabel(\"$m_{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{particle}$)\")\n",
" ax[0].set_xlabel(\"$m_\\\\mathrm{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{particle}$)\")\n",
" ax[0].set_ylabel(\"Normalized\")\n",
" ax[1].hist(\n",
" jet_features_rel[:, 3][mask],\n",
" label=latex_labels[jet_type],\n",
" bins=np.linspace(0, 0.6, 60),\n",
" **hist_kwargs,\n",
" )\n",
" ax[1].set_xlabel(\n",
" \"$m_{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{particle} / p_\\\\mathrm{T}^\\\\mathrm{jet}$)\"\n",
" \"$m_\\\\mathrm{jet}$ (using $p_\\\\mathrm{T}^\\\\mathrm{particle} /\"\n",
" \" p_\\\\mathrm{T}^\\\\mathrm{jet}$)\"\n",
" )\n",
" ax[1].set_ylabel(\"Normalized\")\n",
"ax[0].legend(frameon=False)\n",
"fig.tight_layout()\n",
"fig.savefig(\"jet_mass_comparison.pdf\", bbox_inches=\"tight\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(13, 5))\n",
"hist_kwargs = dict(histtype=\"step\", density=True, linewidth=2)\n",
"\n",
"import yaml\n",
"\n",
"# load labels from labels.yaml\n",
"with open(\"../configs/plotting/labels.yaml\", \"r\") as f:\n",
" labels = yaml.load(f, Loader=yaml.SafeLoader)\n",
" latex_labels = labels[\"latex_labels\"]\n",
" print(latex_labels)\n",
"\n",
"n_particles_real = np.sum(data_real[:, :, 2] != 0, axis=1)\n",
"print(n_particles_real)\n",
"\n",
"for i, conditioning_variable in enumerate(datamodule.names_conditioning):\n",
" # print(jet_type)\n",
" if \"jet_type\" not in conditioning_variable:\n",
" continue\n",
" mask_this_jet_type = cond_real[:, i] == 1\n",
" jet_type = conditioning_variable.split(\"jet_type_\")[-1]\n",
" hist_kwargs[\"linestyle\"] = (\n",
" \"solid\"\n",
" if i < len(cplt.utils.get_good_colours())\n",
" else cplt.utils.get_good_linestyles(\"densely dotted\")\n",
" )\n",
" ax[0].hist(\n",
" n_particles_real[mask_this_jet_type],\n",
" label=latex_labels[jet_type],\n",
" bins=np.linspace(-5.5, 120.5, 127),\n",
" **hist_kwargs,\n",
" )\n",
" ax[0].set_xlabel(\"Number of jet constituents\")\n",
" ax[0].set_ylabel(\"Normalized\")\n",
" mask_this_jet_type_and_isvalid = np.logical_and(\n",
" mask_real[:, :, 0] != 0, np.repeat(mask_this_jet_type[:, np.newaxis], 128, axis=1)\n",
" )\n",
" ax[1].hist(\n",
" data_real[:, :, 0][mask_this_jet_type_and_isvalid].flatten(),\n",
" label=latex_labels[jet_type],\n",
" bins=np.linspace(-1.1, 1.1, 100),\n",
" **hist_kwargs,\n",
" )\n",
" ax[1].set_xlabel(\"$\\\\eta^\\\\mathrm{rel}$\")\n",
" ax[1].set_ylabel(\"Normalized\")\n",
" ax[1].set_yscale(\"log\")\n",
"ax[0].legend(frameon=False)\n",
"fig.tight_layout()\n",
"fig.savefig(\"num_constituents_and_etarel.pdf\", bbox_inches=\"tight\")\n",
"plt.show()"
]
},
Expand All @@ -352,21 +432,27 @@
"jet_features_generated = calculate_jet_features(data_generated)\n",
"\n",
"# plot the jet mass for each jet type\n",
"fig, ax = plt.subplots(10, 5, figsize=(18, 30))\n",
"fig, ax = plt.subplots(11, 5, figsize=(18, 30))\n",
"hist_kwargs = dict(bins=100, density=True)\n",
"# ax= ax.flatten()\n",
"\n",
"for i, conditioning_variable in enumerate(datamodule.names_conditioning):\n",
" if \"jet_type\" not in conditioning_variable:\n",
"for i, conditioning_variable in enumerate([\"all\"] + list(datamodule.names_conditioning)):\n",
" if not (\"jet_type\" in conditioning_variable or \"all\" in conditioning_variable):\n",
" continue\n",
" jet_type = conditioning_variable.split(\"jet_type_\")[-1]\n",
" if \"all\" in conditioning_variable:\n",
" jet_type = \"All jets types\"\n",
" mask_particle_level = mask_real != 0\n",
" mask = np.ones(len(cond_real)) > 0\n",
"\n",
" else:\n",
" jet_type = conditioning_variable.split(\"jet_type_\")[-1]\n",
" mask = cond_real[:, i - 1] == 1\n",
" mask_particle_level = np.repeat(\n",
" mask[:, np.newaxis, np.newaxis], data_real.shape[1], axis=1\n",
" ) & (mask_real != 0)\n",
" # print(jet_type)\n",
" # if i> 0:\n",
" # if i> 1:\n",
" # break\n",
" mask = cond_real[:, i] == 1\n",
" mask_particle_level = np.repeat(\n",
" mask[:, np.newaxis, np.newaxis], data_real.shape[1], axis=1\n",
" ) & (mask_real != 0)\n",
" # print(mask.shape)\n",
" # print(mask_particle_level.shape)\n",
" # hist_kwargs[\"bins\"] = 10\n",
Expand Down Expand Up @@ -447,6 +533,37 @@
"\n",
"print(jet_features_artificial100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"arr = np.array(\n",
" [\n",
" [\n",
" [-1, 0, 0.5],\n",
" [1, 0, 0.5],\n",
" ],\n",
" [\n",
" [-0.5, 0, 0.5],\n",
" [0.5, 0, 0.5],\n",
" ],\n",
" ]\n",
")\n",
"print(arr)\n",
"mask = np.array([True, False])\n",
"print()\n",
"print(arr[mask])\n",
"\n",
"print((arr[:, :, 0] > 0).shape)\n",
"\n",
"arr = np.array([[1, 2, 3], [4, 5, 6], [1, 2, 3]])\n",
"mask = arr[:, 0] > 2\n",
"print(mask.shape)\n",
"arr[mask]"
]
}
],
"metadata": {
Expand Down
2 changes: 1 addition & 1 deletion src/callbacks/jetclass_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def on_train_epoch_end(self, trainer, pl_module):
)

# Plotting
plot_name = f"{self.model_name}--epoch{trainer.current_epoch}"
plot_name = f"{self.model_name}_epoch{trainer.current_epoch}"
_ = plot_data(
particle_data=np.array([data]),
sim_data=background_data,
Expand Down
Loading

0 comments on commit 76da861

Please sign in to comment.