Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ewencedr/DeepLearning into …
Browse files Browse the repository at this point in the history
…lhco_unprocessed_data
  • Loading branch information
ewencedr committed Aug 2, 2023
2 parents 40bba60 + 9bb80fa commit 464e2f2
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 12 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ repos:
- id: bandit
args: ["-s", "B101"]

# python ruff linter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.254"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

# yaml formatting
#- repo: https://github.com/pre-commit/mirrors-prettier
# rev: v2.1.0
Expand Down
4 changes: 2 additions & 2 deletions configs/callbacks/jetclass_eval.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generate data, calculate plots and metrics and log them to the logger
jetclass_eval:
_target_: src.callbacks.jetclass_eval.JetClassEvaluationCallback
every_n_epochs: 10 # evaluate every n epochs
num_jet_samples: 10000 # jet samples to generate
every_n_epochs: 100 # evaluate every n epochs
num_jet_samples: 50000 # jet samples to generate
image_path: ${paths.log_dir}callback_logs/
model_name: "model-test"
use_ema: True
Expand Down
69 changes: 67 additions & 2 deletions notebooks/30_jetclass_eval.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format='retina'\n",
"\n",
"import hydra\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
Expand Down Expand Up @@ -52,7 +55,8 @@
" print(OmegaConf.to_yaml(cfg))\n",
"\n",
"datamodule = hydra.utils.instantiate(cfg.data)\n",
"# datamodule.hparams.number_of_used_jets = 1000\n",
"# set remove_etadiff_tails=False when checking the pT_jet distribution calculated from particle pT\n",
"# datamodule.hparams.remove_etadiff_tails = False\n",
"model = hydra.utils.instantiate(cfg.model)\n",
"datamodule.setup()"
]
Expand Down Expand Up @@ -205,7 +209,68 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"cond_real_repeat = np.repeat(cond_real[:, np.newaxis, :], mask_real.shape[1], axis=1)\n",
"cond_real_repeat.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Crosscheck plots:\n",
"# - pT_particle / pT_jet (as in dataset)\n",
"# - pT_particle when rescaled with jet pT\n",
"# - pT_jet when calculated from constituents\n",
"#\n",
"# - jet mass calculated from rescaled pT_particle and eta_rel, phi_rel\n",
"# - jet mass calculated from pT_rel, eta_rel, phi_rel\n",
"\n",
"from copy import deepcopy\n",
"\n",
"from src.data.components.utils import calculate_jet_features\n",
"\n",
"fig, ax = plt.subplots(1, 3, figsize=(15, 4))\n",
"hist_kwargs = dict(bins=100, histtype=\"step\")\n",
"\n",
"# make copy of particle features\n",
"particle_features = deepcopy(data_real)\n",
"\n",
"# re-scale particle pt with jet pt\n",
"particle_features[:, :, 2] *= cond_real_repeat[:, :, 0]\n",
"\n",
"# calculate jet features (both with pT_rel and pT)\n",
"jet_features_rel = calculate_jet_features(data_real) # pT_rel\n",
"jet_features = calculate_jet_features(particle_features) # pT\n",
"\n",
"# Note: the jet pt which is calculated from the constituent pt does not\n",
"# yield exactly the same distribution if the etadiff tails are removed!\n",
"# the distributions should match though when using all constituents.\n",
"ax[0].hist(data_real[:, :, 2][mask_real[:, :, 0] != 0].flatten(), **hist_kwargs)\n",
"ax[0].set_xlabel(\"$p_T^{particle} / p_T^{jet}$\")\n",
"ax[1].hist(particle_features[:, :, 2][mask_real[:, :, 0] != 0].flatten(), **hist_kwargs)\n",
"ax[1].set_xlabel(\"$p_T^{particle}$\")\n",
"ax[0].set_yscale(\"log\")\n",
"ax[1].set_yscale(\"log\")\n",
"ax[2].hist(jet_features[:, 0], **hist_kwargs, label=\"Calculated from $p_T^{particle}$\")\n",
"ax[2].hist(cond_real[:, 0], **hist_kwargs, label=\"Original value\", ls=\"--\")\n",
"ax[2].legend(frameon=False)\n",
"ax[2].set_xlabel(\"$p_T^{jet}$\")\n",
"fig.tight_layout()\n",
"plt.show()\n",
"\n",
"fig, ax = plt.subplots(1, 2, figsize=(15, 4))\n",
"ax[0].hist(jet_features[:, 3], **hist_kwargs, label=\"Calculated from $p_T^{particle}$\")\n",
"ax[0].set_xlabel(\"$m_{jet}$ - using $p_T^{particle}$\")\n",
"ax[1].hist(\n",
" jet_features_rel[:, 3], **hist_kwargs, label=\"Calculated from $p_T^{particle} / p_T^{jet}$\"\n",
")\n",
"ax[1].set_xlabel(\"$m_{jet}$ - using $p_T^{particle} / p_T^{jet}$\")\n",
"fig.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ exclude_lines = [
[tool.black]
line-length = 99
preview = "True"

[tool.ruff]
line-length = 99
22 changes: 14 additions & 8 deletions src/callbacks/jetclass_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Callback for evaluating the model on the JetClass dataset."""
import warnings
from typing import Any, Callable, Dict, Mapping, Optional
from typing import Callable, Mapping, Optional

import numpy as np
import pytorch_lightning as pl
Expand Down Expand Up @@ -38,15 +38,21 @@ class JetClassEvaluationCallback(pl.Callback):
num_jet_samples (int, optional): How many jet samples to generate.
Negative values define the number of times the whole dataset is taken,
e.g. -2 would use 2*len(dataset) samples. Defaults to -1.
image_path (str, optional): Folder where the images are saved. Defaults to "./logs/callback_images/".
image_path (str, optional): Folder where the images are saved. Defaults
to "./logs/callback_images/".
model_name (str, optional): Name for saving the model. Defaults to "model-test".
log_times (bool, optional): Log generation times of data. Defaults to True.
log_epoch_zero (bool, optional): Log in first epoch. Defaults to False.
data_type (str, optional): Type of data to plot. Options are 'test' and 'val'. Defaults to "test".
use_ema (bool, optional): Use exponential moving average weights for logging. Defaults to False.
fix_seed (bool, optional): Fix seed for data generation to have better reproducibility and comparability between epochs. Defaults to True.
w_dist_config (Mapping, optional): Configuration for Wasserstein distance calculation. Defaults to {'num_jet_samples': 10_000, 'num_batches': 40}.
generation_config (Mapping, optional): Configuration for data generation. Defaults to {"batch_size": 256, "ode_solver": "midpoint", "ode_steps": 100}.
data_type (str, optional): Type of data to plot. Options are 'test' and 'val'.
Defaults to "test".
use_ema (bool, optional): Use exponential moving average weights for logging.
Defaults to False.
fix_seed (bool, optional): Fix seed for data generation to have better
reproducibility and comparability between epochs. Defaults to True.
w_dist_config (Mapping, optional): Configuration for Wasserstein distance
calculation. Defaults to {'num_jet_samples': 10_000, 'num_batches': 40}.
generation_config (Mapping, optional): Configuration for data generation.
Defaults to {"batch_size": 256, "ode_solver": "midpoint", "ode_steps": 100}.
plot_config (Mapping, optional): Configuration for plotting. Defaults to {}.
"""

Expand Down Expand Up @@ -275,7 +281,7 @@ def on_train_epoch_end(self, trainer, pl_module):

# Plotting
plot_name = f"{self.model_name}--epoch{trainer.current_epoch}"
fig = plot_data(
_ = plot_data(
particle_data=np.array([data]),
sim_data=background_data,
jet_data_sim=jet_data_sim,
Expand Down

0 comments on commit 464e2f2

Please sign in to comment.