Commit 882d5f4: fix obs plotting
luigifvr committed Apr 4, 2023
Parent: 1e7ae70
Showing 3 changed files with 73 additions and 20 deletions.
15 changes: 8 additions & 7 deletions params/calo_inn.yaml
@@ -1,18 +1,19 @@
-run_name: calo_inn_p2
+run_name: calo_inn_e2
 dtype: float64
 
-p_type: piplus
+
 #Dataset
 loader_module: calo_inn
 loader_params:
+  p_type: pions # gammas, eplus, pions
   geant_file: /remote/gpu06/favaro/discriminator-metric/data/calo_cls_geant/full_cls_piplus.hdf5
   generated_file: /remote/gpu06/favaro/discriminator-metric/data/calo_bay_samples/samples_piplus.hdf5
   add_log_energy: True
   add_log_layer_ens: True
   add_logit_step: False
   add_cut: 0.0
-  train_split: 0.6
-  test_split: 0.2
+  train_split: 0.5
+  test_split: 0.3
 
 # Model
 activation: leaky_relu
@@ -38,10 +39,10 @@ checkpoint_interval: 5
 
 # Evaluation
 bayesian_samples: 5
-#lower_cluster_thresholds: [0.01, 0.1]
-#upper_cluster_thresholds: [0.9, 0.99]
+lower_cluster_thresholds: [0.1]
+upper_cluster_thresholds: [2.0]
 
 #Plotting
-w_labels: [ placeh, Truth, Gen.]
+w_labels: [ placeh, Geant, Gen.]
 add_w_comb: False

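For orientation, a minimal sketch of how a config like this is presumably consumed; the entrypoint shape and the exact nesting under loader_params are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage sketch: read the YAML and forward loader_params to the
# loader named by loader_module (src/loaders/calo_inn.py here).
import yaml

with open("params/calo_inn.yaml") as f:
    params = yaml.safe_load(f)

loader_params = params["loader_params"]
# After this commit p_type lives under loader_params, so load() can pick
# the particle label without a separate top-level key.
print(loader_params["p_type"])  # -> "pions"
```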
63 changes: 57 additions & 6 deletions src/loaders/calo_inn.py
@@ -19,16 +19,23 @@ def load(params: dict) -> list[DiscriminatorData]:
     """
     datasets = []
     p_type = params.get('p_type', None)
+    if p_type == 'pions':
+        p_lab = '$\pi^{+}$'
+    elif p_type == 'gammas':
+        p_lab = '$\gamma$'
+    elif p_type == 'eplus':
+        p_lab = '$e^{+}$'
+
     preproc_kwargs = {
         "add_log_energy": params.get("add_log_energy", False),
         "add_log_layer_ens": params.get("add_log_layer_ens", False),
         "add_logit_step": params.get("add_logit_step", False),
         "add_cut": params.get("add_cut", 0.0),
     }
     datasets_list = [
-        {'level': 'low', 'normalize': True, 'label': 'Norm.', 'suffix': 'norm'},
-        {'level': 'low', 'normalize': False, 'label': 'Unnorm.', 'suffix': 'unnorm'},
-        {'level': 'high', 'normalize': False, 'label': 'High', 'suffix': 'high'},
+        {'level': 'low', 'normalize': False, 'label': p_lab+' Unnorm.', 'suffix': 'unnorm'},
+        {'level': 'low', 'normalize': True, 'label': p_lab+' Norm.', 'suffix': 'norm'},
+        {'level': 'high', 'normalize': False, 'label': p_lab+' High', 'suffix': 'high'},
     ]
 
     for dataset in datasets_list:
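One caveat with the new if/elif chain: p_lab stays unbound when p_type is missing or misspelled, so the datasets_list lines above would raise a NameError. A dictionary lookup with a fallback is a more defensive sketch of the same mapping (not what the commit does):

```python
# Hedged alternative: map the particle type to its TeX label with a default.
P_LABELS = {
    'pions': r'$\pi^{+}$',
    'gammas': r'$\gamma$',
    'eplus': r'$e^{+}$',
}
p_lab = P_LABELS.get(p_type, '')  # empty label instead of a NameError
```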
@@ -51,7 +58,10 @@ def load(params: dict) -> list[DiscriminatorData]:
             params["train_split"],
             params["test_split"]
         )
-
+
+        observables = []
+        if dataset['level']=='low' and dataset['normalize']==False:
+            observables = compute_energies(test_true, test_fake)
         datasets.append(DiscriminatorData(
             label = dataset['label'],
             suffix = dataset['suffix'],
@@ -62,7 +72,7 @@ def load(params: dict) -> list[DiscriminatorData]:
             test_fake = test_fake,
             val_true = val_true,
             val_fake = val_fake,
-            observables = [],
+            observables = observables,
             )
         )
     return datasets
@@ -103,7 +113,7 @@ def create_data(data_path, dataset_list, **kwargs):
         data = torch.cat((data, en0_t, en1_t, en2_t), axis=1)
     if kwargs['add_logit_step']:
         raise ValueError('Not implemented yet')
-    return data
+    return data.numpy()
 
 def create_data_high(data_path, dataset_list, **kwargs):
     cut = kwargs['add_cut']
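The return data.numpy() change brings create_data in line with create_data_high, which already returned ret.numpy(). A standalone reminder of the semantics (not repository code):

```python
# Tensor.numpy() returns a zero-copy NumPy view of a CPU tensor and keeps
# the dtype (float64 per the config). It raises on CUDA tensors or tensors
# that require grad, which is fine for a CPU-side data loader.
import torch

t = torch.rand(4, 504, dtype=torch.float64)
a = t.numpy()
assert a.shape == (4, 504) and a.dtype.name == "float64"
```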
@@ -170,6 +180,47 @@ def create_data_high(data_path, dataset_list, **kwargs):
     ret = torch.cat((ret, incident_energy), 1)
     return ret.numpy()
 
+def compute_energies(true_data: np.ndarray, fake_data: np.ndarray) -> list[Observable]:
+    observables = []
+    en_0_true = np.sum(true_data[:, :288], axis=1)
+    en_0_fake = np.sum(fake_data[:, :288], axis=1)
+    en_1_true = np.sum(true_data[:,288:432], axis=1)
+    en_1_fake = np.sum(fake_data[:,288:432], axis=1)
+    en_2_true = np.sum(true_data[:,432:504], axis=1)
+    en_2_fake = np.sum(fake_data[:,432:504], axis=1)
+
+    observables.extend([
+        Observable(
+            true_data = en_0_true,
+            fake_data = en_0_fake,
+            tex_label = f'E0',
+            bins = np.logspace(-5, -1, 100),
+            xscale = 'log',
+            yscale = 'log',
+            unit = 'GeV'
+        ),
+        Observable(
+            true_data = en_1_true,
+            fake_data = en_1_fake,
+            tex_label = f'E1',
+            bins = np.logspace(-4, 0, 100),
+            xscale = 'log',
+            yscale = 'log',
+            unit = 'GeV'
+        ),
+        Observable(
+            true_data = en_2_true,
+            fake_data = en_2_fake,
+            tex_label = f'E2',
+            bins = np.logspace(-5, -1, 100),
+            xscale = 'log',
+            yscale = 'log',
+            unit = 'GeV'
+        ),
+    ])
+
+    return observables
+
 def split_data(
     data: np.ndarray,
     train_split: float,
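The slice boundaries in compute_energies read like the three layers of a CaloGAN-style voxel layout (3×96 = 288, 12×12 = 144, 12×6 = 72 cells, 504 in total); that geometry is inferred from the indices, not stated in the commit. A self-contained shape check under that assumption:

```python
# Sanity check of the layer slicing on random data (shapes only).
import numpy as np

N_CELLS = (288, 144, 72)             # assumed layer sizes; sums to 504
edges = np.cumsum((0,) + N_CELLS)    # -> [0, 288, 432, 504]
showers = np.random.rand(10, edges[-1])

layer_energies = [
    showers[:, lo:hi].sum(axis=1) for lo, hi in zip(edges[:-1], edges[1:])
]
assert all(e.shape == (10,) for e in layer_energies)
```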
15 changes: 8 additions & 7 deletions src/plots.py
@@ -61,8 +61,8 @@ def __init__(
 
     def process_weights(self, weights_true, weights_fake):
         w_comb = np.concatenate((weights_true, weights_fake), axis=0)
-        self.p_low = np.percentile(w_comb[w_comb!=0], 0.5)
-        self.p_high = np.percentile(w_comb[w_comb!=np.inf], 99.5)
+        self.p_low = np.percentile(w_comb[w_comb!=0], 0.1)
+        self.p_high = np.percentile(w_comb[w_comb!=np.inf], 99.9)
 
         weights_true[weights_true >= self.p_high] = self.p_high
         weights_fake[weights_fake <= self.p_low] = self.p_low
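Tightening the percentiles from (0.5, 99.5) to (0.1, 99.9) saturates less of each tail: only the most extreme 0.1% of weights on either side is clipped. A minimal sketch of the same idea, assuming 1-D weight arrays and symmetric clipping via np.clip:

```python
# Percentile clipping as in process_weights, on synthetic weights.
import numpy as np

rng = np.random.default_rng(0)
w_comb = rng.lognormal(size=100_000)
p_low = np.percentile(w_comb[w_comb != 0], 0.1)          # ignore exact zeros
p_high = np.percentile(w_comb[w_comb != np.inf], 99.9)   # ignore infinities
w_clipped = np.clip(w_comb, p_low, p_high)               # saturate both tails
```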
@@ -126,7 +126,7 @@ def plot_single_loss(
             labels: Labels of the loss curves
             yscale: Y axis scale, "linear" or "log"
         """
-        fig, ax = plt.subplots(figsize=(5,5))
+        fig, ax = plt.subplots(figsize=(4,3.5))
         for i, (curve, label) in enumerate(zip(curves, labels)):
             epochs = np.arange(1, len(curve)+1)
             ax.plot(epochs, curve, label=label)
@@ -283,7 +283,7 @@ def plot_single_weight_hist(
             y_combined = np.histogram(weights_combined, bins=bins)[0]
             y_combined_err = None
 
-        fig, ax = plt.subplots(figsize=(4, 3.5))
+        fig, ax = plt.subplots(figsize=(4.5, 4.5))
         if self.add_comb:
             self.hist_line(
                 ax,
@@ -400,7 +400,7 @@ def plot_single_observable(self, pdf: PdfPages, observable: Observable):
         if self.bayesian:
             rw_hists = np.stack([
                 np.histogram(
-                    observable.fake_data[self.fake_mask],
+                    observable.fake_data,
                     bins = bins,
                     weights = self.weights_fake[:,i],
                     density = True
@@ -414,8 +414,9 @@ def plot_single_observable(self, pdf: PdfPages, observable: Observable):
                 np.quantile(rw_hists, 0.841, axis=1)
             ), axis=0)
         else:
+            print(observable.fake_data.shape, self.weights_fake.shape)
             rw_mean = np.histogram(
-                observable.fake_data[self.fake_mask],
+                observable.fake_data,
                 bins=bins,
                 weights=self.weights_fake
             )[0]
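Dropping self.fake_mask here looks like the heart of "fix obs plotting": the observables are now built from the test split in the loader, so observable.fake_data already matches the length of self.weights_fake and needs no further masking. For the Bayesian branch, the pattern is one weighted histogram per posterior sample followed by quantiles across samples; a standalone sketch under assumed shapes:

```python
# One reweighted histogram per Bayesian sample, then a 68% band.
import numpy as np

rng = np.random.default_rng(1)
fake_data = rng.exponential(size=5000)
weights_fake = rng.lognormal(size=(5000, 5))   # 5 posterior samples
bins = np.linspace(0, 5, 41)

rw_hists = np.stack([
    np.histogram(fake_data, bins=bins, weights=weights_fake[:, i],
                 density=True)[0]
    for i in range(weights_fake.shape[1])
], axis=1)                                     # shape (n_bins, n_samples)
band = np.quantile(rw_hists, [0.159, 0.841], axis=1)  # roughly +-1 sigma
```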
@@ -505,7 +506,7 @@ def plot_single_clustering(
         true_hist, _ = np.histogram(observable.true_data, bins=bins, density=True)
         hists = [
             np.histogram(
-                observable.fake_data[self.fake_mask][mask],
+                observable.fake_data[mask],
                 bins=bins,
                 density=True
             )[0]
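The same mask fix applies to the clustering plots, where only the per-cluster mask is now applied. The clusters presumably come from the weight thresholds configured in params/calo_inn.yaml above (lower 0.1, upper 2.0); a hedged sketch of how such masks could be built and histogrammed:

```python
# Illustrative only: select events by learned weight and histogram them.
import numpy as np

rng = np.random.default_rng(2)
weights_fake = rng.lognormal(size=10_000)
fake_data = rng.exponential(size=10_000)

masks = [weights_fake < 0.1, weights_fake > 2.0]  # lower/upper thresholds
bins = np.linspace(0, 5, 51)
hists = [
    np.histogram(fake_data[mask], bins=bins, density=True)[0]
    for mask in masks
]
```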