Skip to content

Commit

Permalink
make plots look nice, new plots for disc vs binn
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed May 2, 2023
1 parent b6b0e35 commit 592a624
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def main():
print(f" Classifier score: {clf_score:.7f}")

print(" Creating plots")
lab_def = ["Comb.", "True", "Gen."]
lab_def = ["Comb", "Truth", "Gen"]
labels = params.get('w_labels', lab_def)
add_comb = params.get('add_w_comb', True)
plots = Plots(
Expand Down
16 changes: 9 additions & 7 deletions src/loaders/prec_inn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ..dataset import DiscriminatorData
from ..observable import Observable

PARTICLE_NAMES = [r"\mu_1", r"\mu_2", "j_1", "j_2", "j_3"]

def load(params: dict) -> list[DiscriminatorData]:
"""
Loads the training, test and validation data (truth samples and generated samples)
Expand Down Expand Up @@ -171,7 +173,7 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob
Observable(
true_data = obs_one_true[i].pt,
fake_data = obs_one_fake[i].pt,
tex_label = f"p_{{T,{i}}}",
tex_label = f"p_{{T,{PARTICLE_NAMES[i]}}}",
bins = np.linspace(
np.min(obs_one_true[i].pt),
np.quantile(obs_one_true[i].pt, 0.99),
Expand All @@ -183,15 +185,15 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob
Observable(
true_data = obs_one_true[i].eta,
fake_data = obs_one_fake[i].eta,
tex_label = f"\\eta_{i}",
tex_label = f"\\eta_{{{PARTICLE_NAMES[i]}}}",
bins = np.linspace(-6, 6, 50),
)
])
if i >= 2:
observables.append(Observable(
true_data = obs_one_true[i].m,
fake_data = obs_one_fake[i].m,
tex_label = f"M_{i}",
tex_label = f"M_{{{PARTICLE_NAMES[i]}}}",
bins = np.linspace(
np.quantile(obs_one_true[i].m, 0.005),
np.quantile(obs_one_true[i].m, 0.995),
Expand All @@ -217,19 +219,19 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob
Observable(
true_data = obs_two_true[(i,j)].delta_r,
fake_data = obs_two_fake[(i,j)].delta_r,
tex_label = f"\\Delta R_{{{i},{j}}}",
bins = np.linspace(0, 12, 50),
tex_label = f"\\Delta R_{{{PARTICLE_NAMES[i]},{PARTICLE_NAMES[j]}}}",
bins = np.linspace(0, 8, 50),
),
Observable(
true_data = obs_two_true[(i,j)].delta_eta,
fake_data = obs_two_fake[(i,j)].delta_eta,
tex_label = f"\\Delta \\eta_{{{i},{j}}}",
tex_label = f"\\Delta \\eta_{{{PARTICLE_NAMES[i]},{PARTICLE_NAMES[j]}}}",
bins = np.linspace(-10, 10, 50),
),
Observable(
true_data = obs_two_true[(i,j)].delta_phi,
fake_data = obs_two_fake[(i,j)].delta_phi,
tex_label = f"\\Delta \\phi_{{{i},{j}}}",
tex_label = f"\\Delta \\phi_{{{PARTICLE_NAMES[i]},{PARTICLE_NAMES[j]}}}",
bins = np.linspace(-np.pi, np.pi, 50),
)
])
Expand Down
Loading

0 comments on commit 592a624

Please sign in to comment.