diff --git a/src/loaders/calo_inn.py b/src/loaders/calo_inn.py index 8fe8c5b..a25cd8e 100644 --- a/src/loaders/calo_inn.py +++ b/src/loaders/calo_inn.py @@ -60,8 +60,8 @@ def load(params: dict) -> list[DiscriminatorData]: ) observables = [] - if dataset['level']=='low' and dataset['normalize']==False: - observables = compute_energies(test_true, test_fake) + if dataset['level']=='low': + observables = compute_observables(test_true, test_fake) datasets.append(DiscriminatorData( label = dataset['label'], suffix = dataset['suffix'], @@ -180,42 +180,72 @@ def create_data_high(data_path, dataset_list, **kwargs): ret = torch.cat((ret, incident_energy), 1) return ret.numpy() -def compute_energies(true_data: np.ndarray, fake_data: np.ndarray) -> list[Observable]: +def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Observable]: observables = [] - en_0_true = np.sum(true_data[:, :288], axis=1) - en_0_fake = np.sum(fake_data[:, :288], axis=1) - en_1_true = np.sum(true_data[:,288:432], axis=1) - en_1_fake = np.sum(fake_data[:,288:432], axis=1) - en_2_true = np.sum(true_data[:,432:504], axis=1) - en_2_fake = np.sum(fake_data[:,432:504], axis=1) + phi0_true = center_of_energy(true_data[:,:288], 0, 'phi') + phi1_true = center_of_energy(true_data[:,288:432], 1, 'phi') + phi2_true = center_of_energy(true_data[:,432:504], 2, 'phi') + + phi0_fake = center_of_energy(fake_data[:,:288], 0, 'phi') + phi1_fake = center_of_energy(fake_data[:,288:432], 1, 'phi') + phi2_fake = center_of_energy(fake_data[:,432:504], 2, 'phi') + + sparsity0_true = layer_sparsity(true_data[:,:288], 0.0) + sparsity1_true = layer_sparsity(true_data[:,288:432], 0.0) + sparsity2_true = layer_sparsity(true_data[:,432:504], 0.0) + + sparsity0_fake = layer_sparsity(fake_data[:,:288], 0.0) + sparsity1_fake = layer_sparsity(fake_data[:,288:432], 0.0) + sparsity2_fake = layer_sparsity(fake_data[:,432:504], 0.0) observables.extend([ Observable( - true_data = en_0_true, - fake_data = en_0_fake, - tex_label = f'E0', - bins = np.logspace(-5, -1, 100), - xscale = 'log', + true_data = phi0_true, + fake_data = phi0_fake, + tex_label = f'phi0', + bins = np.linspace(-125, 125, 50), + xscale = 'linear', yscale = 'log', - unit = 'GeV' ), - Observable( - true_data = en_1_true, - fake_data = en_1_fake, - tex_label = f'E1', - bins = np.logspace(-4, 0, 100), - xscale = 'log', + Observable( + true_data = phi1_true, + fake_data = phi1_fake, + tex_label = f'phi1', + bins = np.linspace(-125, 125, 50), + xscale = 'linear', yscale = 'log', - unit = 'GeV' ), Observable( - true_data = en_2_true, - fake_data = en_2_fake, - tex_label = f'E2', - bins = np.logspace(-5, -1, 100), - xscale = 'log', + true_data = phi2_true, + fake_data = phi2_fake, + tex_label = f'phi2', + bins = np.linspace(-125, 125, 50), + xscale = 'linear', yscale = 'log', - unit = 'GeV' + ), + Observable( + true_data = sparsity0_true, + fake_data = sparsity0_fake, + tex_label = f'sparsity0', + bins = np.linspace(0, 1, 20), + xscale = 'linear', + yscale = 'linear', + ), + Observable( + true_data = sparsity1_true, + fake_data = sparsity1_fake, + tex_label = f'sparsity1', + bins = np.linspace(0, 1, 20), + xscale = 'linear', + yscale = 'linear', + ), + Observable( + true_data = sparsity2_true, + fake_data = sparsity2_fake, + tex_label = f'sparsity2', + bins = np.linspace(0, 1, 20), + xscale = 'linear', + yscale = 'linear', ), ])