improve preprocessing

theoheimel committed Mar 16, 2023
1 parent 865722d commit bcfa399
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions src/loaders/prec_inn.py
@@ -36,8 +36,8 @@ def load(params: dict) -> list[DiscriminatorData]:
     multiplicity_fake = np.sum(fake_momenta[:,:,0] != 0., axis=1)
 
     subsets = [
-        {"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"},
-        {"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"},
+        #{"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"},
+        #{"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"},
         {"multiplicity": 5, "label": "$Z+3j$", "suffix": "z3j"},
     ]
     datasets = []
@@ -61,12 +61,14 @@ def load(params: dict) -> list[DiscriminatorData]:
             "append_mass": params.get("append_mass", False),
             "append_delta_r": params.get("append_delta_r", False)
         }
-        pp_train_true = compute_preprocessing(train_true, **preproc_kwargs)
+        pp_train_all = compute_preprocessing(
+            np.concatenate((train_true, train_fake), axis=0), **preproc_kwargs
+        )
         datasets.append(DiscriminatorData(
             label = subset["label"],
             suffix = subset["suffix"],
-            dim = pp_train_true.shape[1],
-            train_true = pp_train_true,
+            dim = pp_train_all.shape[1],
+            train_true = compute_preprocessing(train_true, **preproc_kwargs),
             train_fake = compute_preprocessing(train_fake, **preproc_kwargs),
             test_true = compute_preprocessing(test_true, **preproc_kwargs),
             test_fake = compute_preprocessing(test_fake, **preproc_kwargs),
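Note: the hunk above runs the first preprocessing pass over the concatenated true and fake training samples, so the feature dimension, and, assuming compute_preprocessing fills its normalization cache on the first call (as the `if "means" not in norm:` branch further down suggests), the standardization statistics are derived from both classes rather than from the true events alone. A minimal sketch of that pattern; the standardize helper is a hypothetical stand-in for the real compute_preprocessing:

    import numpy as np

    def standardize(x: np.ndarray, norm: dict) -> np.ndarray:
        # Fill the normalization cache on the first call and reuse it afterwards,
        # mirroring the `if "means" not in norm:` check in prec_inn.py.
        if "means" not in norm:
            norm["means"] = x.mean(axis=0)
            norm["stds"] = x.std(axis=0)
        return (x - norm["means"]) / norm["stds"]

    norm = {}
    train_true = np.random.randn(1000, 5)
    train_fake = np.random.randn(1200, 5)

    # The first call on the combined sample fixes the means/stds and the dimension;
    # the per-split calls below then reuse exactly those statistics.
    pp_train_all = standardize(np.concatenate((train_true, train_fake), axis=0), norm)
    pp_train_true = standardize(train_true, norm)
    pp_train_fake = standardize(train_fake, norm)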
@@ -122,11 +124,12 @@ def compute_preprocessing(
         input_obs.append(mass)
 
     if append_delta_r:
+        drinv = lambda x: np.minimum(1/(x+1e-7), 20)
         if mult > 3:
-            input_obs.append(dr(obs.phi[:,2], obs.phi[:,3], obs.eta[:,2], obs.eta[:,3]))
+            input_obs.append(drinv(dr(obs.phi[:,2], obs.phi[:,3], obs.eta[:,2], obs.eta[:,3])))
         if mult > 4:
-            input_obs.append(dr(obs.phi[:,3], obs.phi[:,4], obs.eta[:,3], obs.eta[:,4]))
-            input_obs.append(dr(obs.phi[:,2], obs.phi[:,4], obs.eta[:,2], obs.eta[:,4]))
+            input_obs.append(drinv(dr(obs.phi[:,3], obs.phi[:,4], obs.eta[:,3], obs.eta[:,4])))
+            input_obs.append(drinv(dr(obs.phi[:,2], obs.phi[:,4], obs.eta[:,2], obs.eta[:,4])))
 
     data_preproc = np.stack(input_obs, axis=1)
     if "means" not in norm:
Expand Down Expand Up @@ -181,17 +184,17 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob
),
unit = "GeV"
))
observables.append(Observable(
true_data = obs_two_true[(0,1)].m,
fake_data = obs_two_fake[(0,1)].m,
tex_label = r"M_{\mu\mu}",
bins = np.linspace(
np.quantile(obs_two_true[(0,1)].m, 0.005),
np.quantile(obs_two_true[(0,1)].m, 0.995),
50
),
unit = "GeV"
))
observables.append(Observable(
true_data = obs_two_true[(0,1)].m,
fake_data = obs_two_fake[(0,1)].m,
tex_label = r"M_{\mu\mu}",
bins = np.linspace(
np.quantile(obs_two_true[(0,1)].m, 0.005),
np.quantile(obs_two_true[(0,1)].m, 0.995),
50
),
unit = "GeV"
))
for i, j in [(2,3), (2,4), (3,4)]:
if i > mult-1 or j > mult-1:
continue
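Note: the Observable blocks above bin each distribution between its 0.5% and 99.5% quantiles, so a few extreme events do not stretch the histogram range. A minimal standalone sketch of that binning choice:

    import numpy as np

    values = np.random.exponential(scale=50.0, size=10_000)  # toy spectrum in GeV
    # 50 bin edges between the 0.5% and 99.5% quantiles, as in the diff above
    bins = np.linspace(np.quantile(values, 0.005), np.quantile(values, 0.995), 50)
    counts, edges = np.histogram(values, bins=bins)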
