Skip to content

Commit

Permalink
Merge pull request #19 from joschkabirk/jetclass-eval
Browse files Browse the repository at this point in the history
Evaluation enhancements
  • Loading branch information
ewencedr authored Aug 8, 2023
2 parents a385dc8 + ad4e74c commit 2c37571
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 12 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,10 @@ During training and evaluation, metrics and plots can be logged via comet and wa
python src/eval.py experiment=experiment_name.yaml ckpt_path=checkpoint_path
```

You can also specify the config file that was saved at the beginning of the training run:

```bash
python src/eval.py cfg_path=<cfg_file_path> ckpt_path=<checkpoint_path>
```

Notebooks are available to quickly train, evaluate models and create plots.
4 changes: 2 additions & 2 deletions configs/callbacks/jetclass_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ jetclass_eval:
every_n_epochs: 100 # evaluate every n epochs
additional_eval_epochs: [1, 20, 50, 80] # evaluate at these epochs as well
num_jet_samples: 50_000 # jet samples to generate for evaluation
image_path: ${paths.log_dir}callback_logs/
model_name: "model-test"
# image_path: ${paths.log_dir}callback_logs/ # if not set, will default to trainer.default_root_dir/plots
model_name: "epic_fm_jetclass"
use_ema: True
log_times: True
log_epoch_zero: False
Expand Down
3 changes: 3 additions & 0 deletions configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ tags: ["dev"]

# passing checkpoint path is necessary for evaluation
ckpt_path: ???

# cfg_path: allows loading the whole config file of a run - if None/null, the config is composed from the experiment
cfg_path: null
9 changes: 8 additions & 1 deletion src/callbacks/jetclass_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Callback for evaluating the model on the JetClass dataset."""
import os
import warnings
from typing import Callable, Mapping, Optional

Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
every_n_epochs: int | Callable = 10,
additional_eval_epochs: list[int] = None,
num_jet_samples: int = -1,
image_path: str = "./logs/callback_images/",
image_path: str = None,
model_name: str = "model",
log_times: bool = True,
log_epoch_zero: bool = False,
Expand Down Expand Up @@ -117,6 +118,12 @@ def on_train_start(self, trainer, pl_module) -> None:
self.log("w1m_mean", 0.005)
self.log("w1p_mean", 0.005)

if self.image_path is None:
self.image_path = f"{trainer.default_root_dir}/plots/"
os.makedirs(self.image_path, exist_ok=True)

pylogger.info("Logging plots during training to %s", self.image_path)

# set number of jet samples if negative
if self.num_jet_samples < 0:
self.datasets_multiplier = abs(self.num_jet_samples)
Expand Down
16 changes: 12 additions & 4 deletions src/data/components/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,24 @@ def get_pt_of_selected_multiplicities(
"""Return pt of jets with selected particle multiplicities.
Args:
particle_data (_type_): _description_
selected_multiplicities (list, optional): _description_. Defaults to [20, 30, 40].
num_jets (int, optional): _description_. Defaults to 150.
particle_data (np.ndarray): Particle data of shape (num_jets, num_particles, num_features)
selected_multiplicities (list, optional): List of selected particle multiplicities. Defaults to [20, 30, 40].
num_jets (int, optional): Number of jets to consider. Defaults to 150.
Returns:
dict: _description_
dict: Dict containing {selected_multiplicity: pt_selected_multiplicity} pairs
where pt_selected_multiplicity is a masked array of shape (num_jets, num_particles).
"""
data = {}
for count, selected_multiplicity in enumerate(selected_multiplicities):
# TODO: the line below might be wrong?
# with that we select particles that have the selected multiplicity or more
# --> is this what we want?
particle_data_temp = particle_data[:, :selected_multiplicity, :]
# TODO: the line below might be critical:
# we have to test for pt_rel non-zero to check if a particle is masked
# particles with eta_rel = 0 can actually have pt_rel != 0, so those would
# be masked even though they are valid particles
mask = np.ma.masked_where(
np.count_nonzero(particle_data_temp[:, :, 0], axis=1) == selected_multiplicity,
np.count_nonzero(particle_data_temp[:, :, 0], axis=1),
Expand Down
16 changes: 14 additions & 2 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import hydra
import pyrootutils
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import Logger

Expand Down Expand Up @@ -44,7 +44,19 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
"""

assert cfg.ckpt_path
# NOTE: some config parameters that are evaluated at runtime will be displayed
# incorrectly (e.g. trainer.default_root_dir will contain the time of evaluation)
# --> this has to be considered in the evaluation scripts (e.g. specify the
# output directory relative to the checkpoint path)

assert cfg.ckpt_path is not None, "`ckpt_path` must be provided for evaluation!"

# load config from cfg_path if provided
if cfg.cfg_path is not None:
log.info(f"Loading config from cfg_path: {cfg.cfg_path}")
ckpt_path = cfg.ckpt_path
cfg = OmegaConf.load(cfg.cfg_path)
cfg.ckpt_path = ckpt_path

log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
Expand Down
12 changes: 10 additions & 2 deletions src/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ def plot_data(
plt.tight_layout()
if save_fig:
plt.savefig(f"{save_folder}{save_name}.png", bbox_inches="tight")
plt.savefig(f"{save_folder}{save_name}.pdf", bbox_inches="tight")
if close_fig:
plt.close(fig)
return fig
Expand Down Expand Up @@ -1135,6 +1136,7 @@ def plot_loss_curves(
# ax.set_yscale("log")
ax.legend(loc="best")
plt.savefig(f"{save_path}/plots/loss_plots_{name}.png")
plt.savefig(f"{save_path}/plots/loss_plots_{name}.pdf")
plt.show()
plt.clf()

Expand Down Expand Up @@ -1191,6 +1193,7 @@ def do_timing_plots(
ax.legend(loc="best")
# plt.title(f"Time to generate {particles_to_generate} jets")
plt.savefig(f"{save_path}/plots/{name}.png")
plt.savefig(f"{save_path}/plots/{name}.pdf")
return np.array(times)


Expand All @@ -1204,8 +1207,10 @@ def prepare_data_for_plotting(
particles and the pt of selected multiplicities.
Args:
data (np.ndarray): data in the shape (n_jets, n_particles, n_features) with
features (pt, eta, phi)
data (list of np.ndarray): list of arrays, each of shape
(n_jets, n_particles, n_features) with features (pt, eta, phi)
--> this allows processing the data in batches. The batches will be
concatenated in the output
calculate_efps (bool, optional): If efps should be calculated. Defaults to False.
selected_particles (list[int], optional): Selected particles. Defaults to [1,3,10].
selected_multiplicities (list[int], optional): Selected multiplicities.
Expand All @@ -1228,6 +1233,7 @@ def prepare_data_for_plotting(
if calculate_efps:
efps_temp = efps(data_temp)
pt_selected_particles_temp = get_pt_of_selected_particles(data_temp, selected_particles)
# TODO: should probably set the number of jets in the function call below?
pt_selected_multiplicities_temp = get_pt_of_selected_multiplicities(
data_temp, selected_multiplicities
)
Expand Down Expand Up @@ -1425,6 +1431,7 @@ def plot_substructure(
plt.tight_layout()
if save_fig:
plt.savefig(f"{save_folder}{save_name}.png", bbox_inches="tight")
plt.savefig(f"{save_folder}{save_name}.pdf", bbox_inches="tight")
if close_fig:
plt.close(fig)
return fig
Expand Down Expand Up @@ -1461,6 +1468,7 @@ def plot_full_substructure(
plt.tight_layout()
if save_fig:
plt.savefig(f"{save_folder}{save_name}.png", bbox_inches="tight")
plt.savefig(f"{save_folder}{save_name}.pdf", bbox_inches="tight")
if close_fig:
plt.close(fig)
return fig
6 changes: 5 additions & 1 deletion src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def instantiate_callbacks(callbacks_cfg: DictConfig, ckpt_path: str = None) -> L
for _, cb_conf in callbacks_cfg.items():
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>")
if cb_conf._target_ == "src.callbacks.jetnet_final_eval.JetNetFinalEvaluationCallback":
if (
cb_conf._target_ == "src.callbacks.jetnet_final_eval.JetNetFinalEvaluationCallback"
or cb_conf._target_
== "src.callbacks.jetclass_eval_test.JetClassTestEvaluationCallback"
):
cb_conf.ckpt_path = ckpt_path
callbacks.append(hydra.utils.instantiate(cb_conf))

Expand Down

0 comments on commit 2c37571

Please sign in to comment.