Skip to content

Commit

Permalink
observables update
Browse files Browse the repository at this point in the history
  • Loading branch information
luigifvr committed Apr 6, 2023
1 parent 3e8b465 commit e8a7ace
Showing 1 changed file with 58 additions and 28 deletions.
86 changes: 58 additions & 28 deletions src/loaders/calo_inn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def load(params: dict) -> list[DiscriminatorData]:
)

observables = []
if dataset['level']=='low' and dataset['normalize']==False:
observables = compute_energies(test_true, test_fake)
if dataset['level']=='low':
observables = compute_observables(test_true, test_fake)
datasets.append(DiscriminatorData(
label = dataset['label'],
suffix = dataset['suffix'],
Expand Down Expand Up @@ -180,42 +180,72 @@ def create_data_high(data_path, dataset_list, **kwargs):
ret = torch.cat((ret, incident_energy), 1)
return ret.numpy()

def compute_energies(true_data: np.ndarray, fake_data: np.ndarray) -> list[Observable]:
def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Observable]:
observables = []
en_0_true = np.sum(true_data[:, :288], axis=1)
en_0_fake = np.sum(fake_data[:, :288], axis=1)
en_1_true = np.sum(true_data[:,288:432], axis=1)
en_1_fake = np.sum(fake_data[:,288:432], axis=1)
en_2_true = np.sum(true_data[:,432:504], axis=1)
en_2_fake = np.sum(fake_data[:,432:504], axis=1)
phi0_true = center_of_energy(true_data[:,:288], 0, 'phi')
phi1_true = center_of_energy(true_data[:,288:432], 1, 'phi')
phi2_true = center_of_energy(true_data[:,432:504], 2, 'phi')

phi0_fake = center_of_energy(fake_data[:,:288], 0, 'phi')
phi1_fake = center_of_energy(fake_data[:,288:432], 1, 'phi')
phi2_fake = center_of_energy(fake_data[:,432:504], 2, 'phi')

sparsity0_true = layer_sparsity(true_data[:,:288], 0.0)
sparsity1_true = layer_sparsity(true_data[:,288:432], 0.0)
sparsity2_true = layer_sparsity(true_data[:,432:504], 0.0)

sparsity0_fake = layer_sparsity(fake_data[:,:288], 0.0)
sparsity1_fake = layer_sparsity(fake_data[:,288:432], 0.0)
sparsity2_fake = layer_sparsity(fake_data[:,432:504], 0.0)

observables.extend([
Observable(
true_data = en_0_true,
fake_data = en_0_fake,
tex_label = f'E0',
bins = np.logspace(-5, -1, 100),
xscale = 'log',
true_data = phi0_true,
fake_data = phi0_fake,
tex_label = f'phi0',
bins = np.linspace(-125, 125, 50),
xscale = 'linear',
yscale = 'log',
unit = 'GeV'
),
Observable(
true_data = en_1_true,
fake_data = en_1_fake,
tex_label = f'E1',
bins = np.logspace(-4, 0, 100),
xscale = 'log',
Observable(
true_data = phi1_true,
fake_data = phi1_fake,
tex_label = f'phi1',
bins = np.linspace(-125, 125, 50),
xscale = 'linear',
yscale = 'log',
unit = 'GeV'
),
Observable(
true_data = en_2_true,
fake_data = en_2_fake,
tex_label = f'E2',
bins = np.logspace(-5, -1, 100),
xscale = 'log',
true_data = phi2_true,
fake_data = phi2_fake,
tex_label = f'phi2',
bins = np.linspace(-125, 125, 50),
xscale = 'linear',
yscale = 'log',
unit = 'GeV'
),
Observable(
true_data = sparsity0_true,
fake_data = sparsity0_fake,
tex_label = f'sparsity0',
bins = np.linspace(0, 1, 20),
xscale = 'linear',
yscale = 'linear',
),
Observable(
true_data = sparsity1_true,
fake_data = sparsity1_fake,
tex_label = f'sparsity1',
bins = np.linspace(0, 1, 20),
xscale = 'linear',
yscale = 'linear',
),
Observable(
true_data = sparsity2_true,
fake_data = sparsity2_fake,
tex_label = f'sparsity2',
bins = np.linspace(0, 1, 20),
xscale = 'linear',
yscale = 'linear',
),
])

Expand Down

0 comments on commit e8a7ace

Please sign in to comment.