Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ewencedr/DeepLearning into …
Browse files Browse the repository at this point in the history
…jet_feature_model
  • Loading branch information
ewencedr committed Aug 11, 2023
2 parents ba172b5 + 76da861 commit cd817dc
Show file tree
Hide file tree
Showing 20 changed files with 917 additions and 291 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
lightning_logs/
logs/
.cometml-runs

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,10 @@ During training and evaluation, metrics and plots can be logged via comet and wa
python src/eval.py experiment=experiment_name.yaml ckpt_path=checkpoint_path
```

You can also specify the config file that was saved at the beginning of training:

```bash
python src/eval.py cfg_path=<cfg_file_path> ckpt_path=<checkpoint_path>
```

Notebooks are available to quickly train and evaluate models and to create plots.
4 changes: 2 additions & 2 deletions configs/callbacks/jetclass_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ jetclass_eval:
every_n_epochs: 100 # evaluate every n epochs
additional_eval_epochs: [1, 20, 50, 80] # evaluate at these epochs as well
num_jet_samples: 50_000 # jet samples to generate for evaluation
image_path: ${paths.log_dir}callback_logs/
model_name: "model-test"
# image_path: ${paths.log_dir}callback_logs/ # if not set, will default to trainer.default_root_dir/plots
model_name: "epic_fm_jetclass"
use_ema: True
log_times: True
log_epoch_zero: False
Expand Down
20 changes: 18 additions & 2 deletions configs/data/jetclass.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
_target_: src.data.jetclass_datamodule.JetClassDataModule
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz
data_filename: jetclass_TTBar_2_000_000.npz

# general parameters
batch_size: 1024
num_workers: 32
pin_memory: False
val_fraction: 0.10
test_fraction: 0.30

# preprocessing
normalize: True
normalize_sigma: 5

# spectator_jet_features:
# - jet_pt

# select jet types to use
# list of the following: QCD, Hbb, Hcc, Hgg, H4q, Hqql, Zqq, Wqq, Tbqq, Tbl
used_jet_types: null # null means all jet types

# files and jet types to use
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_h5
filename_dict:
train: ${data.data_dir}/train_100M/merged_standardized.h5
val: ${data.data_dir}/val_5M/merged_standardized.h5
test: ${data.data_dir}/test_20M/merged_standardized.h5
26 changes: 26 additions & 0 deletions configs/data/jetclass_dev.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Development variant of the JetClass data config (configs/data/jetclass_dev.yaml).
# `_target_` names the datamodule class that is built from the keys below.
_target_: src.data.jetclass_datamodule.JetClassDataModule

# general parameters
batch_size: 1024
num_workers: 32
pin_memory: False
val_fraction: 0.10
test_fraction: 0.30

# preprocessing
normalize: True
normalize_sigma: 5

# spectator_jet_features:
#   - jet_pt

# select jet types to use
# list of the following: QCD, Hbb, Hcc, Hgg, H4q, Hqql, Zqq, Wqq, Tbqq, Tbl
used_jet_types: null  # null means all jet types

# files to use
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_h5
# One file per split. The split entries must be nested under `filename_dict`;
# at column 0 they would become unrelated top-level keys and leave
# `filename_dict` null.
# `${data.data_dir}` refers to the `data_dir` value above (assumes this file is
# mounted as the `data` config group, as the experiment configs do).
filename_dict:
  train: ${data.data_dir}/train_100M/merged_standardized.h5
  val: ${data.data_dir}/val_5M/merged_standardized.h5
  test: ${data.data_dir}/test_20M/merged_standardized.h5
3 changes: 3 additions & 0 deletions configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ tags: ["dev"]

# passing checkpoint path is necessary for evaluation
ckpt_path: ???

# cfg_path: path to the saved config file of a run; if set, the whole config is loaded from it - if None/null, the config is composed from the experiment
cfg_path: null
3 changes: 2 additions & 1 deletion configs/experiment/jetclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ defaults:

tags: ["flow_matching", "JetClass", "uncond"]

run_note: "This is a note"
run_note: ""

seed: 12345

Expand All @@ -42,6 +42,7 @@ data:
conditioning_eta: False
conditioning_mass: False
conditioning_num_particles: False
conditioning_jet_type: False

callbacks:
ema:
Expand Down
32 changes: 22 additions & 10 deletions configs/experiment/jetclass_cond.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@ defaults:

tags: ["flow_matching", "JetClass", "cond"]

run_note: "jet-type conditioning, using all 10 jet types"

seed: 12345

trainer:
min_epochs: 10
max_epochs: 10000
min_epochs: 1
max_epochs: 2_000
gradient_clip_val: 0.5

model:
num_particles: 128
global_cond_dim: 2 # needs to be calculated when using conditioning
global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables)
local_cond_dim: 0

data:
number_of_used_jets: 200000
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
conditioning_pt: True
# preprocessing
number_of_used_jets: 3_000_000
normalize: True
normalize_sigma: 5
# conditioning
conditioning_pt: False
conditioning_eta: False
conditioning_mass: True
conditioning_mass: False
conditioning_num_particles: False
conditioning_jet_type: True

callbacks:
ema:
Expand All @@ -44,18 +49,25 @@ callbacks:
start_step: 0
save_ema_weights_in_callback_state: True
evaluate_ema_weights_instead: True
jetclass_eval:
every_n_epochs: 50 # evaluate every n epochs
additional_eval_epochs: [1, 30, 75] # evaluate at these epochs as well
num_jet_samples: 50_000 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 200_000 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
# patience: 2000
# mode: "min"

task_name: "jetclass_flow_matching"
task_name: "jetclass_cond_jettype"

logger:
wandb:
tags: ${tags}
group: "jetclass_flow_matching"
group: "flow_matching_jetclass"
name: ${task_name}
comet:
experiment_name: ${task_name}
project_name: "flow-matching"
20 changes: 12 additions & 8 deletions configs/experiment/jetclass_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# python train.py experiment=jetclass

defaults:
- override /data: jetclass.yaml
- override /data: jetclass_dev.yaml
- override /model: flow_matching.yaml
- override /callbacks: jetclass.yaml
- override /trainer: gpu.yaml
Expand All @@ -17,8 +17,9 @@ defaults:



tags: ["flow_matching", "JetClass", "uncond", "dev", "debug"]
run_note: "This is a note"
tags: ["flow_matching", "JetClass", "dev", "debug"]

run_note: "Test run with conditioning on jet type"

seed: 12345

Expand All @@ -29,17 +30,20 @@ trainer:

model:
num_particles: 128
global_cond_dim: 0 # needs to be calculated when using conditioning
global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables)
local_cond_dim: 0

data:
number_of_used_jets: 200
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
# preprocessing
number_of_used_jets: 30_000
normalize: True
normalize_sigma: 1
# conditioning
conditioning_pt: False
conditioning_eta: False
conditioning_mass: False
conditioning_num_particles: False
conditioning_jet_type: True

callbacks:
ema:
Expand All @@ -53,7 +57,7 @@ callbacks:
additional_eval_epochs: [] # evaluate at these epochs as well
num_jet_samples: 100 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 6000 # jet samples to generate
num_jet_samples: 1_000 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
Expand Down
11 changes: 11 additions & 0 deletions configs/plotting/labels.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# LaTeX display labels, one `label_<jet type>` entry per jet type.
# Values are single-quoted so backslashes, `$` and braces are taken literally;
# inside single quotes a literal apostrophe is written as ''.
latex_labels:
  label_QCD: 'QCD'
  label_Hbb: '$H(\rightarrow b\bar{b})$'
  label_Hcc: '$H(\rightarrow c\bar{c})$'
  label_Hgg: '$H(\rightarrow gg)$'
  label_H4q: '$H(\rightarrow 4q)$'
  label_Hqql: '$H(\rightarrow l \nu qq'')$'
  label_Zqq: '$Z(\rightarrow q\bar{q})$'
  label_Wqq: '$W(\rightarrow q\bar{q})$'
  label_Tbqq: '$t(\rightarrow b qq'')$'
  label_Tbl: '$t(\rightarrow b l \nu)$'
Loading

0 comments on commit cd817dc

Please sign in to comment.