Skip to content

Commit

Permalink
Merge pull request #18 from joschkabirk/add-jettypes-jetclass
Browse files Browse the repository at this point in the history
Add support for multiple jet types on JetClass dataset
  • Loading branch information
ewencedr authored Aug 8, 2023
2 parents 53b565e + bfdf2ef commit a385dc8
Show file tree
Hide file tree
Showing 11 changed files with 717 additions and 133 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
lightning_logs/
logs/
.cometml-runs

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
55 changes: 53 additions & 2 deletions configs/data/jetclass.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,61 @@
_target_: src.data.jetclass_datamodule.JetClassDataModule
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz
data_filename: jetclass_TTBar_2_000_000.npz

# general parameters
batch_size: 1024
num_workers: 32
pin_memory: False
val_fraction: 0.10
test_fraction: 0.30

# preprocessing
normalize: True
normalize_sigma: 5

spectator_jet_features:
- jet_pt


# files and jet types to use
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz

jet_types:
ttbar:
files:
- ${data.data_dir}/jetclass_TTBar_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz
qcd:
files:
- ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz
hbb:
files:
- ${data.data_dir}/jetclass_HToBB_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz
hcc:
files:
- ${data.data_dir}/jetclass_HToCC_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz
hgg:
files:
- ${data.data_dir}/jetclass_HToGG_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz
h4q:
files:
- ${data.data_dir}/jetclass_HToWW4Q_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz
hqql:
files:
- ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz
zqq:
files:
- ${data.data_dir}/jetclass_ZToQQ_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz
wqq:
files:
- ${data.data_dir}/jetclass_WToQQ_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz
ttbarlep:
files:
- ${data.data_dir}/jetclass_TTBarLep_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz
61 changes: 61 additions & 0 deletions configs/data/jetclass_dev.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
_target_: src.data.jetclass_datamodule.JetClassDataModule

# general parameters
batch_size: 1024
num_workers: 32
pin_memory: False
val_fraction: 0.10
test_fraction: 0.30

# preprocessing
normalize: True
normalize_sigma: 5

spectator_jet_features:
- jet_pt


# files and jet types to use
data_dir: /beegfs/desy/user/birkjosc/datasets/jetclass_npz

jet_types:
ttbar:
files:
- ${data.data_dir}/jetclass_TTBar_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBar_300_000.npz
qcd:
files:
- ${data.data_dir}/jetclass_ZJetsToNuNu_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZJetsToNuNu_300_000.npz
# hbb:
# files:
# - ${data.data_dir}/jetclass_HToBB_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToBB_300_000.npz
# hcc:
# files:
# - ${data.data_dir}/jetclass_HToCC_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToCC_300_000.npz
# hgg:
# files:
# - ${data.data_dir}/jetclass_HToGG_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToGG_300_000.npz
# h4q:
# files:
# - ${data.data_dir}/jetclass_HToWW4Q_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW4Q_300_000.npz
# hqql:
# files:
# - ${data.data_dir}/jetclass_HToWW2Q1L_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_HToWW2Q1L_300_000.npz
# zqq:
# files:
# - ${data.data_dir}/jetclass_ZToQQ_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_ZToQQ_300_000.npz
# wqq:
# files:
# - ${data.data_dir}/jetclass_WToQQ_300_000.npz
# # - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_WToQQ_300_000.npz
# ttbarlep:
# files:
# - ${data.data_dir}/jetclass_TTBarLep_300_000.npz
# - /beegfs/desy/user/birkjosc/datasets/jetclass_npz/jetclass_TTBarLep_300_000.npz
3 changes: 2 additions & 1 deletion configs/experiment/jetclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ defaults:

tags: ["flow_matching", "JetClass", "uncond"]

run_note: "This is a note"
run_note: ""

seed: 12345

Expand All @@ -42,6 +42,7 @@ data:
conditioning_eta: False
conditioning_mass: False
conditioning_num_particles: False
conditioning_jet_type: False

callbacks:
ema:
Expand Down
30 changes: 22 additions & 8 deletions configs/experiment/jetclass_cond.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,33 @@ defaults:

tags: ["flow_matching", "JetClass", "cond"]

run_note: "jet-type conditioning, using all 10 jet types"

seed: 12345

trainer:
min_epochs: 10
max_epochs: 10000
min_epochs: 1
max_epochs: 2_000
gradient_clip_val: 0.5

model:
num_particles: 128
global_cond_dim: 2 # needs to be calculated when using conditioning
global_cond_dim: 10 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables)
local_cond_dim: 0

data:
number_of_used_jets: 200000
# preprocessing
number_of_used_jets: 3_000_000
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
conditioning_pt: True
normalize: True
normalize_sigma: 5
# conditioning
conditioning_pt: False
conditioning_eta: False
conditioning_mass: True
conditioning_mass: False
conditioning_num_particles: False
conditioning_jet_type: True

callbacks:
ema:
Expand All @@ -44,18 +51,25 @@ callbacks:
start_step: 0
save_ema_weights_in_callback_state: True
evaluate_ema_weights_instead: True
jetclass_eval:
every_n_epochs: 50 # evaluate every n epochs
additional_eval_epochs: [1, 30, 75] # evaluate at these epochs as well
num_jet_samples: 50_000 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 200_000 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
# patience: 2000
# mode: "min"

task_name: "jetclass_flow_matching"
task_name: "jetclass_cond_jettype"

logger:
wandb:
tags: ${tags}
group: "jetclass_flow_matching"
group: "flow_matching_jetclass"
name: ${task_name}
comet:
experiment_name: ${task_name}
project_name: "flow-matching"
18 changes: 12 additions & 6 deletions configs/experiment/jetclass_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# python train.py experiment=jetclass

defaults:
- override /data: jetclass.yaml
- override /data: jetclass_dev.yaml
- override /model: flow_matching.yaml
- override /callbacks: jetclass.yaml
- override /trainer: gpu.yaml
Expand All @@ -17,8 +17,9 @@ defaults:



tags: ["flow_matching", "JetClass", "uncond", "dev", "debug"]
run_note: "This is a note"
tags: ["flow_matching", "JetClass", "dev", "debug"]

run_note: "Test run with conditioning on jet type"

seed: 12345

Expand All @@ -29,17 +30,22 @@ trainer:

model:
num_particles: 128
global_cond_dim: 0 # needs to be calculated when using conditioning
global_cond_dim: 2 # needs to be calculated when using conditioning (= number of jet types + additional conditioning variables)
local_cond_dim: 0

data:
number_of_used_jets: 200
# preprocessing
number_of_used_jets: 30_000
use_custom_eta_centering: True # this means we are using eta_rel = eta_particle - eta_jet
remove_etadiff_tails: True # remove tracks with | eta_rel | > 1
normalize: True
normalize_sigma: 1
# conditioning
conditioning_pt: False
conditioning_eta: False
conditioning_mass: False
conditioning_num_particles: False
conditioning_jet_type: True

callbacks:
ema:
Expand All @@ -53,7 +59,7 @@ callbacks:
additional_eval_epochs: [] # evaluate at these epochs as well
num_jet_samples: 100 # jet samples to generate
jetclass_eval_test:
num_jet_samples: 6000 # jet samples to generate
num_jet_samples: 100 # jet samples to generate

#early_stopping:
# monitor: "val/loss"
Expand Down
11 changes: 11 additions & 0 deletions configs/plotting/labels.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
latex_labels:
label_QCD: QCD
label_Hbb: $H(\rightarrow b\bar{b})$
label_Hcc: $H(\rightarrow c\bar{c})$
label_Hgg: $H(\rightarrow gg)$
label_H4q: $H(\rightarrow 4q)$
label_Hqql: $H(\rightarrow l \nu qq')$
label_Zqq: $Z(\rightarrow q\bar{q})$
label_Wqq: $W(\rightarrow q\bar{q})$
label_Tbqq: $t(\rightarrow b qq')$
label_Tbl: $t(\rightarrow b l \nu)$
Loading

0 comments on commit a385dc8

Please sign in to comment.