diff --git a/src/loaders/prec_inn.py b/src/loaders/prec_inn.py index 39ef27c..e5886e6 100644 --- a/src/loaders/prec_inn.py +++ b/src/loaders/prec_inn.py @@ -36,8 +36,8 @@ def load(params: dict) -> list[DiscriminatorData]: multiplicity_fake = np.sum(fake_momenta[:,:,0] != 0., axis=1) subsets = [ - {"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"}, - {"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"}, + #{"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"}, + #{"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"}, {"multiplicity": 5, "label": "$Z+3j$", "suffix": "z3j"}, ] datasets = [] @@ -61,12 +61,14 @@ def load(params: dict) -> list[DiscriminatorData]: "append_mass": params.get("append_mass", False), "append_delta_r": params.get("append_delta_r", False) } - pp_train_true = compute_preprocessing(train_true, **preproc_kwargs) + pp_train_all = compute_preprocessing( + np.concatenate((train_true, train_fake), axis=0), **preproc_kwargs + ) datasets.append(DiscriminatorData( label = subset["label"], suffix = subset["suffix"], - dim = pp_train_true.shape[1], - train_true = pp_train_true, + dim = pp_train_all.shape[1], + train_true = compute_preprocessing(train_true, **preproc_kwargs), train_fake = compute_preprocessing(train_fake, **preproc_kwargs), test_true = compute_preprocessing(test_true, **preproc_kwargs), test_fake = compute_preprocessing(test_fake, **preproc_kwargs), @@ -122,11 +124,12 @@ def compute_preprocessing( input_obs.append(mass) if append_delta_r: + drinv = lambda x: np.minimum(1/(x+1e-7), 20) if mult > 3: - input_obs.append(dr(obs.phi[:,2], obs.phi[:,3], obs.eta[:,2], obs.eta[:,3])) + input_obs.append(drinv(dr(obs.phi[:,2], obs.phi[:,3], obs.eta[:,2], obs.eta[:,3]))) if mult > 4: - input_obs.append(dr(obs.phi[:,3], obs.phi[:,4], obs.eta[:,3], obs.eta[:,4])) - input_obs.append(dr(obs.phi[:,2], obs.phi[:,4], obs.eta[:,2], obs.eta[:,4])) + input_obs.append(drinv(dr(obs.phi[:,3], obs.phi[:,4], obs.eta[:,3], obs.eta[:,4]))) + input_obs.append(drinv(dr(obs.phi[:,2], obs.phi[:,4], obs.eta[:,2], obs.eta[:,4]))) data_preproc = np.stack(input_obs, axis=1) if "means" not in norm: @@ -181,17 +184,17 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob ), unit = "GeV" )) - observables.append(Observable( - true_data = obs_two_true[(0,1)].m, - fake_data = obs_two_fake[(0,1)].m, - tex_label = r"M_{\mu\mu}", - bins = np.linspace( - np.quantile(obs_two_true[(0,1)].m, 0.005), - np.quantile(obs_two_true[(0,1)].m, 0.995), - 50 - ), - unit = "GeV" - )) + observables.append(Observable( + true_data = obs_two_true[(0,1)].m, + fake_data = obs_two_fake[(0,1)].m, + tex_label = r"M_{\mu\mu}", + bins = np.linspace( + np.quantile(obs_two_true[(0,1)].m, 0.005), + np.quantile(obs_two_true[(0,1)].m, 0.995), + 50 + ), + unit = "GeV" + )) for i, j in [(2,3), (2,4), (3,4)]: if i > mult-1 or j > mult-1: continue