generated from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
jetclass_classifier_hl.yaml
82 lines (70 loc) · 1.94 KB
/
jetclass_classifier_hl.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# @package _global_
# to execute this experiment run:
# python train.py experiment=jetclass_classifier_hl
# Hydra defaults list: each `override` entry replaces the named config group
# with the given file. The commented-out entries are alternative model configs.
defaults:
  - override /data: classifier_data_jetclass.yaml
  - override /model: mlp_classifier.yaml
  # - override /model: particlenet_classifier.yaml
  # - override /model: particlenet_lite_classifier.yaml
  # - override /model: ParT_classifier.yaml
  - override /callbacks: jetclass_classifier.yaml
  - override /trainer: gpu.yaml
# all parameters below are merged with the default configurations selected above,
# so only the parameters specified here are overridden
# to continue training from a checkpoint, set ckpt_path here
# ckpt_path: XXX/checkpoints/last-EMA.ckpt
# Run labels; also passed to the W&B logger through the `${tags}`
# interpolation in the logger section of this file.
tags:
  - "fm-classifier_test"
  - "JetClass"
  - "ClassifierTest"
# Free-text note attached to the run (empty here).
run_note: ""
# NOTE(review): presumably the global random seed - confirm against train.py.
seed: 122
# Shared values referenced elsewhere in this file through `${vars.*}`
# interpolations, so each one only has to be changed in a single place.
# The children must be nested under `vars` for those interpolations to resolve.
vars:
  epochs: 200  # read by trainer max_epochs and scheduler max_iters
  warmup: 5  # read by scheduler warmup
  val_check_interval: null  # read by trainer val_check_interval
# Overrides for the data config selected in the defaults list
# (classifier_data_jetclass.yaml). Every setting is nested under `data` so it
# is merged into that config group instead of landing at the top level.
data:
  # batch_size: 2048 # ParticleNet-Lite
  # NOTE(review): the inline label says ParticleNet, but the active model
  # override is mlp_classifier.yaml - confirm this batch size is intended.
  batch_size: 256  # ParticleNet
  # train_val_test_split: [0.8, 0.1, 0.1]
  # kin_only: true
  # set_energy_equal_to_p: true
  # Three features are enabled here and the model section sets input_dim to 3;
  # presumably the two must stay in sync - TODO confirm.
  hl_features_list:
    - tau1
    - tau2
    - tau3
    # - tau21
    # - tau32
    # - mass
  # Placeholder path - replace with the real input file before running.
  data_file: XXX.h5
  # number_of_jets: 1000
  # debug_sim_only: true
  # debug_sim_gen_fraction: 0.8
  used_flavor: Tbqq
# setting load_weights_from will load the weights from the given checkpoint, but start training from scratch
# load_weights_from: XXX.ckpt
# Overrides for the model config selected in the defaults list
# (mlp_classifier.yaml). Nesting is required so that these keys are merged
# into `model` and do not collide with same-named keys elsewhere (e.g. `warmup`).
model:
  use_hl_features: true
  net_config:
    # NOTE(review): each entry looks like a [layer width, dropout rate] pair -
    # confirm against the MLP implementation.
    fc_params:
      - [64, 0.1]
      - [128, 0.1]
      - [128, 0.1]
      - [64, 0.1]
    # NOTE(review): assumed to belong to net_config (it follows fc_params in
    # the source) - TODO confirm it is not a direct child of `model`.
    # The value matches the three entries enabled in data.hl_features_list.
    input_dim: 3
  optimizer:
    lr: 0.005
  scheduler:
    # Both values are taken from the shared `vars` section.
    warmup: ${vars.warmup}
    max_iters: ${vars.epochs}
# Overrides for the trainer config selected in the defaults list (gpu.yaml).
# Keys are nested under `trainer` so they are merged into that config group.
trainer:
  min_epochs: 1
  # Epoch budget and validation interval come from the shared `vars` section.
  max_epochs: ${vars.epochs}
  val_check_interval: ${vars.val_check_interval}
  # gradient_clip_val: 0.5 # ParticleNet 0.02, ParticleNet-Lite 0.1
# Task identifier; reused as the W&B run name through `${task_name}`.
task_name: 'jetclass_classifier'
# Logger overrides, one sub-section per logging backend. Settings are nested
# under their backend so that, for example, `tags: ${tags}` sets the W&B tags
# and does not redefine the top-level `tags` key.
logger:
  wandb:
    tags: ${tags}
    group: "flow_matching_jetclass"
    name: ${task_name}
  comet:
    experiment_name: null
    project_name: "flow-matching-classifierTest"