Skip to content

Commit

Permalink
Merge pull request #14 from joschkabirk/jetclass-eval-update
Browse files Browse the repository at this point in the history
Jetclass eval update
  • Loading branch information
ewencedr authored Aug 2, 2023
2 parents f0e7218 + e2abbcc commit 790e13e
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 67 deletions.
2 changes: 1 addition & 1 deletion configs/callbacks/jetclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ model_checkpoint:
metric_map:
"val/loss": "loss"


# TODO: also add checkpoints for minimal Wasserstein distance

#early_stopping:
# monitor: "val/loss"
Expand Down
5 changes: 3 additions & 2 deletions configs/callbacks/jetclass_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
jetclass_eval:
_target_: src.callbacks.jetclass_eval.JetClassEvaluationCallback
every_n_epochs: 100 # evaluate every n epochs
num_jet_samples: 50000 # jet samples to generate
additional_eval_epochs: [1, 20, 50, 80] # evaluate at these epochs as well
num_jet_samples: 50_000 # jet samples to generate for evaluation
image_path: ${paths.log_dir}callback_logs/
model_name: "model-test"
use_ema: True
Expand All @@ -14,7 +15,7 @@ jetclass_eval:
num_batches: 40
calculate_efps: False
generation_config:
batch_size: 1000
batch_size: 1_000
ode_solver: "midpoint"
ode_steps: 200
verbose: False
Expand Down
2 changes: 1 addition & 1 deletion configs/callbacks/jetclass_eval_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jetclass_eval_test:
fix_seed: true
evaluate_substructure: true
suffix: ""
cond_path: ${paths.root_dir}/data/conditioning.h5
# cond_path: ${paths.root_dir}/data/conditioning.h5
w_dist_config:
num_eval_samples: 50_000
num_batches: 40
Expand Down
2 changes: 1 addition & 1 deletion configs/data/jetclass.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_target_: src.data.jetclass_datamodule.JetClassDataModule
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz
data_filename: jetclass_TTBar_200_000.npz
data_filename: jetclass_TTBar_2_000_000.npz
batch_size: 1024
num_workers: 32
pin_memory: False
Expand Down
17 changes: 12 additions & 5 deletions configs/experiment/jetclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ tags: ["flow_matching", "JetClass", "uncond"]
seed: 12345

trainer:
min_epochs: 10
max_epochs: 10000
min_epochs: 1
max_epochs: 2_000
gradient_clip_val: 0.5

model:
Expand All @@ -29,7 +29,7 @@ model:
local_cond_dim: 0

data:
number_of_used_jets: 200000
number_of_used_jets: 800_000
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
conditioning_pt: False
Expand All @@ -44,18 +44,25 @@ callbacks:
start_step: 0
save_ema_weights_in_callback_state: True
evaluate_ema_weights_instead: True
jetclass_eval:
every_n_epochs: 100 # evaluate every n epochs
additional_eval_epochs: [10, 30, 50, 75] # evaluate at these epochs as well
num_jet_samples: 50_000 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 60_000 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
# patience: 2000
# mode: "min"

task_name: "jetclass_flow_matching"
task_name: "jetclass"

logger:
wandb:
tags: ${tags}
group: "jetclass_flow_matching"
group: "flow_matching_jetclass"
name: ${task_name}
comet:
experiment_name: ${task_name}
project_name: "flow-matching"
8 changes: 6 additions & 2 deletions configs/experiment/jetclass_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ model:
local_cond_dim: 0

data:
number_of_used_jets: 2000
number_of_used_jets: 200
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
conditioning_pt: False
Expand All @@ -48,8 +48,11 @@ callbacks:
save_ema_weights_in_callback_state: True
evaluate_ema_weights_instead: True
jetclass_eval:
every_n_epochs: 1 # evaluate every n epochs
every_n_epochs: 5 # evaluate every n epochs
additional_eval_epochs: [1] # evaluate at these epochs as well
num_jet_samples: 100 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 6000 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
Expand All @@ -65,3 +68,4 @@ logger:
name: ${task_name}
comet:
experiment_name: ${task_name}
project_name: "flow-matching"
26 changes: 18 additions & 8 deletions src/callbacks/jetclass_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .ema import EMA

log = get_pylogger("JetClassEvaluationCallback")
pylogger = get_pylogger("JetClassEvaluationCallback")

# TODO wandb logging min and max values
# TODO wandb logging video of jets, histograms, and point clouds
Expand All @@ -35,6 +35,7 @@ class JetClassEvaluationCallback(pl.Callback):
Args:
every_n_epochs (int, optional): Log every n epochs. Defaults to 10.
additional_eval_epochs (list, optional): Additional epochs at which to evaluate and log, on top of every_n_epochs. Defaults to None.
num_jet_samples (int, optional): How many jet samples to generate.
Negative values define the number of times the whole dataset is taken,
e.g. -2 would use 2*len(dataset) samples. Defaults to -1.
Expand All @@ -59,6 +60,7 @@ class JetClassEvaluationCallback(pl.Callback):
def __init__(
self,
every_n_epochs: int | Callable = 10,
additional_eval_epochs: list[int] = None,
num_jet_samples: int = -1,
image_path: str = "./logs/callback_images/",
model_name: str = "model",
Expand All @@ -80,6 +82,7 @@ def __init__(
):
super().__init__()
self.every_n_epochs = every_n_epochs
self.additional_eval_epochs = additional_eval_epochs
self.num_jet_samples = num_jet_samples
self.log_times = log_times
self.log_epoch_zero = log_epoch_zero
Expand Down Expand Up @@ -114,10 +117,6 @@ def on_train_start(self, trainer, pl_module) -> None:
self.log("w1m_mean", 0.005)
self.log("w1p_mean", 0.005)

self.log("training_dataset_size", float(len(trainer.datamodule.tensor_train)))
self.log("validation_dataset_size", float(len(trainer.datamodule.tensor_val)))
self.log("test_dataset_size", float(len(trainer.datamodule.tensor_test)))

# set number of jet samples if negative
if self.num_jet_samples < 0:
self.datasets_multiplier = abs(self.num_jet_samples)
Expand All @@ -131,10 +130,16 @@ def on_train_start(self, trainer, pl_module) -> None:
)
else:
self.datasets_multiplier = -1
self.log("number_of_generated_val_jets", float(self.num_jet_samples))

hparams_to_log = {
"training_dataset_size": float(len(trainer.datamodule.tensor_train)),
"validation_dataset_size": float(len(trainer.datamodule.tensor_val)),
"test_dataset_size": float(len(trainer.datamodule.tensor_test)),
"number_of_generated_val_jets": float(self.num_jet_samples),
}
# get loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams_to_log)
if isinstance(logger, pl.loggers.CometLogger):
self.comet_logger = logger.experiment
elif isinstance(logger, pl.loggers.WandbLogger):
Expand All @@ -148,7 +153,7 @@ def on_train_start(self, trainer, pl_module) -> None:
" not found. Using normal weights."
)
elif self.ema_callback is not None and self.use_ema:
log.info("Using EMA weights for logging.")
pylogger.info("Using EMA weights for evaluation.")

def on_train_epoch_end(self, trainer, pl_module):
if self.fix_seed:
Expand All @@ -162,7 +167,6 @@ def on_train_epoch_end(self, trainer, pl_module):

# determine if logging should happen
log = False
print(f"Finishing epoch {trainer.current_epoch}")
if type(self.every_n_epochs) is int:
if trainer.current_epoch % self.every_n_epochs == 0 and log_epoch:
log = True
Expand All @@ -174,8 +178,13 @@ def on_train_epoch_end(self, trainer, pl_module):
log = custom_logging_schedule(trainer.current_epoch)
except KeyError:
raise KeyError("Custom logging schedule not available.")
# log at additional epochs
if self.additional_eval_epochs is not None:
if trainer.current_epoch in self.additional_eval_epochs and log_epoch:
log = True

if log:
pylogger.info(f"Evaluating model after epoch {trainer.current_epoch}.")
# Get background data for plotting and calculating Wasserstein distances
if self.data_type == "test":
background_data = np.array(trainer.datamodule.tensor_test)[: self.num_jet_samples]
Expand Down Expand Up @@ -222,6 +231,7 @@ def on_train_epoch_end(self, trainer, pl_module):
stds=trainer.datamodule.stds,
**self.generation_config,
)
pylogger.info(f"Generated {len(data)} samples in {generation_time:.0f} seconds.")

# Get normal weights back after sampling
if (
Expand Down
46 changes: 29 additions & 17 deletions src/callbacks/jetclass_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,37 @@

from .ema import EMA, EMAModelCheckpoint

log = get_pylogger("JetClassTestEvaluationCallback")
pylogger = get_pylogger("JetClassTestEvaluationCallback")


# TODO: cond_path currently only works for mass and pt
class JetClassTestEvaluationCallback(pl.Callback):
"""Callback to do final evaluation of the model after training. Specific to JetClass dataset.
Args:
use_ema (bool, optional): Use exponential moving average weights for logging. Defaults to False.
use_ema (bool, optional): Use exponential moving average weights for logging.
Defaults to False.
dataset (str, optional): Dataset to evaluate on. Defaults to "test".
nr_checkpoint_callbacks (int, optional): Number of checkpoint callback that is used to select best epoch. Will only be used when ckpt_path is None. Defaults to 0.
use_last_checkpoint (bool, optional): Use last checkpoint instead of best checkpoint. Defaults to True.
ckpt_path (Optional[str], optional): Path to checkpoint. If given, this ckpt will be used for evaluation. Defaults to None.
num_jet_samples (int, optional): How many jet samples to generate. Negative values define the amount of times the whole dataset is taken, e.g. -2 would use 2*len(dataset) samples. Defaults to -1.
fix_seed (bool, optional): Fix seed for data generation to have better reproducibility and comparability between epochs. Defaults to True.
evaluate_substructure (bool, optional): Evaluate substructure metrics. Takes very long. Defaults to True.
nr_checkpoint_callbacks (int, optional): Number of the checkpoint callback that is used
to select the best epoch. Only used when ckpt_path is None. Defaults to 0.
use_last_checkpoint (bool, optional): Use last checkpoint instead of best checkpoint.
Defaults to True.
ckpt_path (Optional[str], optional): Path to checkpoint. If given, this ckpt will be
used for evaluation. Defaults to None.
num_jet_samples (int, optional): How many jet samples to generate. Negative values define
the number of times the whole dataset is taken, e.g. -2 would use 2*len(dataset)
samples. Defaults to -1.
fix_seed (bool, optional): Fix seed for data generation to have better reproducibility
and comparability between epochs. Defaults to True.
evaluate_substructure (bool, optional): Evaluate substructure metrics. Takes very long.
Defaults to True.
suffix (str, optional): Suffix for logging. Defaults to "".
cond_path (Optional[str], optional): Path for conditioning that is used during generation. If not provided, the selected dataset will be used for conditioning. Defaults to None.
w_dist_config (Mapping, optional): Configuration for Wasserstein distance calculation. Defaults to {'num_jet_samples': 10_000, 'num_batches': 40}.
generation_config (Mapping, optional): Configuration for data generation. Defaults to {"batch_size": 256, "ode_solver": "midpoint", "ode_steps": 100}.
cond_path (Optional[str], optional): Path for conditioning that is used during generation.
If not provided, the selected dataset will be used for conditioning. Defaults to None.
w_dist_config (Mapping, optional): Configuration for Wasserstein distance calculation.
Defaults include {'num_eval_samples': 50_000, 'num_batches': 40}.
generation_config (Mapping, optional): Configuration for data generation.
Defaults to {"batch_size": 256, "ode_solver": "midpoint", "ode_steps": 100}.
plot_config (Mapping, optional): Configuration for plotting. Defaults to {}.
"""

Expand All @@ -61,7 +72,7 @@ def __init__(
fix_seed: bool = True,
evaluate_substructure: bool = True,
suffix: str = "",
cond_path: Optional[str] = None,
cond_path: Optional[str] = None, # TODO: figure out when to use this
w_dist_config: Mapping = {
"num_eval_samples": 50_000,
"num_batches": 40,
Expand Down Expand Up @@ -97,7 +108,7 @@ def __init__(
self.plot_config = plot_config

def on_test_start(self, trainer, pl_module) -> None:
log.info(
pylogger.info(
"JetClassFinalEvaluationCallback will be used for evaluating the model after training."
)

Expand Down Expand Up @@ -132,11 +143,11 @@ def _get_ema_callback(self, trainer: "pl.Trainer") -> Optional[EMA]:
return ema_callback

def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
log.info(f"Evaluating model on {self.dataset} dataset.")
pylogger.info(f"Evaluating model on {self.dataset} dataset.")

ckpt = self._get_checkpoint(trainer, use_last_checkpoint=self.use_last_checkpoint)

log.info(f"Loading checkpoint from {ckpt}")
pylogger.info(f"Loading checkpoint from {ckpt}")
model = pl_module.load_from_checkpoint(ckpt)

if self.fix_seed:
Expand Down Expand Up @@ -210,6 +221,7 @@ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Non
stds=trainer.datamodule.stds,
**self.generation_config,
)
pylogger.info(f"Generated {len(data)} samples in {generation_time:.0f} seconds.")

# save generated data
path = "/".join(ckpt.split("/")[:-2]) + "/"
Expand Down Expand Up @@ -253,7 +265,7 @@ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Non
# Plotting
plot_name = f"final_plot{self.suffix}"
img_path = "/".join(ckpt.split("/")[:-2]) + "/"
fig = plot_data(
plot_data(
particle_data=np.array([data_plotting]),
sim_data=background_data,
jet_data_sim=jet_data_sim,
Expand Down Expand Up @@ -379,7 +391,7 @@ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> Non
)

yaml_path = "/".join(ckpt.split("/")[:-2]) + f"/final_eval_metrics{self.suffix}.yml"
log.info(f"Writing final evaluation metrics to {yaml_path}")
pylogger.info(f"Writing final evaluation metrics to {yaml_path}")

# transform numpy.float64 for better readability in yaml file
metrics = {k: float(v) for k, v in metrics.items()}
Expand Down
Loading

0 comments on commit 790e13e

Please sign in to comment.