Commit: proposal consistent plotting

luigifvr committed Apr 4, 2023
1 parent 826cbe2 commit 1e7ae70

Showing 4 changed files with 102 additions and 52 deletions.
15 changes: 10 additions & 5 deletions params/calo_inn.yaml
@@ -1,12 +1,12 @@
-run_name: calo_inn
+run_name: calo_inn_p2
 dtype: float64

 #Dataset
 loader_module: calo_inn
 loader_params:
-  geant_file: /remote/gpu06/favaro/discriminator-metric/data/calo_cls_geant/full_cls_eplus.hdf5
-  generated_file: /remote/gpu06/favaro/discriminator-metric/data/calo_bay_samples/samples_eplus.hdf5
+  geant_file: /remote/gpu06/favaro/discriminator-metric/data/calo_cls_geant/full_cls_piplus.hdf5
+  generated_file: /remote/gpu06/favaro/discriminator-metric/data/calo_bay_samples/samples_piplus.hdf5
   add_log_energy: True
   add_log_layer_ens: True
   add_logit_step: False
@@ -16,9 +16,9 @@ loader_params:

 # Model
 activation: leaky_relu
-negative_slope: 0.2
+negative_slope: 0.01
 dropout: 0.0
-layers: 2
+layers: 3
 hidden_size: 512

 # Training
@@ -40,3 +40,8 @@ checkpoint_interval: 5
 bayesian_samples: 5
 #lower_cluster_thresholds: [0.01, 0.1]
 #upper_cluster_thresholds: [0.9, 0.99]
+
+#Plotting
+w_labels: [ placeh, Truth, Gen.]
+add_w_comb: False
+
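Note: judging from src/__main__.py and src/plots.py below, w_labels is ordered as [combined, truth, generated], and the first entry is only used when add_w_comb is true, so the "placeh" placeholder is harmless with add_w_comb: False. A minimal sketch of how the new keys would be read, assuming the YAML is parsed with something like yaml.safe_load (the loading code is not part of this diff):

    import yaml

    # Sketch: read the new plotting options with the same backward-compatible
    # defaults that src/__main__.py uses below.
    with open("params/calo_inn.yaml") as f:
        params = yaml.safe_load(f)

    labels = params.get("w_labels", ["Comb.", "True", "Gen."])  # [combined, truth, generated]
    add_comb = params.get("add_w_comb", True)
    print(labels, add_comb)  # ['placeh', 'Truth', 'Gen.'] False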

7 changes: 6 additions & 1 deletion src/__main__.py
@@ -89,13 +89,18 @@ def main():
     print(f" Classifier score: {clf_score:.7f}")

     print(" Creating plots")
+    lab_def = ["Comb.", "True", "Gen."]
+    labels = params.get('w_labels', lab_def)
+    add_comb = params.get('add_w_comb', True)
     plots = Plots(
         data.observables,
         weights_true,
         weights_fake,
         training.losses,
         data.label,
-        data.test_logw
+        labels,
+        add_comb,
+        data.test_logw,
     )
     print(" Plotting losses")
     plots.plot_losses(doc.add_file(f"losses_{data.suffix}.pdf"))
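Note: the two new arguments are passed positionally and sit in front of log_gen_weights, so every Plots call site must match the new signature exactly. A keyword-based call, sketched below with the same objects as above, would be more robust against further signature changes:

    plots = Plots(
        data.observables,
        weights_true,
        weights_fake,
        training.losses,
        data.label,
        labels_w_hist=labels,
        add_comb=add_comb,
        log_gen_weights=data.test_logw,
    )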
36 changes: 22 additions & 14 deletions src/loaders/calo_inn.py
@@ -73,27 +73,34 @@ def create_data(data_path, dataset_list, **kwargs):
     lay_0 = f.get('layer_0')[:] / 1e5
     lay_1 = f.get('layer_1')[:] / 1e5
     lay_2 = f.get('layer_2')[:] / 1e5
-    data = np.concatenate((lay_0.reshape(-1, 288), lay_1.reshape(-1, 144), lay_2.reshape(-1, 72)), axis=1)
+
+    torch_dtype = torch.get_default_dtype()
+    lay_0 = torch.tensor(lay_0).to(torch_dtype)
+    lay_1 = torch.tensor(lay_1).to(torch_dtype)
+    lay_2 = torch.tensor(lay_2).to(torch_dtype)
+    en_test = torch.tensor(en_test).to(torch_dtype)
+
+    data = torch.cat((lay_0.reshape(-1, 288), lay_1.reshape(-1, 144), lay_2.reshape(-1, 72)), axis=1)

-    en0_t = np.sum(data[:, :288], axis=1, keepdims=True)
-    en1_t = np.sum(data[:, 288:432], axis=1, keepdims=True)
-    en2_t = np.sum(data[:, 432:], axis=1, keepdims=True)
+    en0_t = torch.sum(data[:, :288], axis=1, keepdims=True)
+    en1_t = torch.sum(data[:, 288:432], axis=1, keepdims=True)
+    en2_t = torch.sum(data[:, 432:], axis=1, keepdims=True)

     if dataset_list['normalize']:
         data[:, :288] /= en0_t + 1e-16
         data[:, 288:432] /= en1_t + 1e-16
         data[:, 432:] /= en2_t + 1e-16

     if kwargs['add_log_energy']:
-        data = np.concatenate((data, np.log10(en_test*10).reshape(-1, 1)), axis=1)
+        data = torch.cat((data, torch.log10(en_test*10).reshape(-1, 1)), axis=1)
     #data = np.nan_to_num(data, posinf=0, neginf=0)

-    en0_t = np.log10(en0_t + 1e-8) + 2.
-    en1_t = np.log10(en1_t + 1e-8) + 2.
-    en2_t = np.log10(en2_t + 1e-8) + 2.
+    en0_t = torch.log10(en0_t + 1e-8) + 2.
+    en1_t = torch.log10(en1_t + 1e-8) + 2.
+    en2_t = torch.log10(en2_t + 1e-8) + 2.

     if kwargs['add_log_layer_ens']:
-        data = np.concatenate((data, en0_t, en1_t, en2_t), axis=1)
+        data = torch.cat((data, en0_t, en1_t, en2_t), axis=1)
     if kwargs['add_logit_step']:
         raise ValueError('Not implemented yet')
     return data
@@ -105,12 +112,13 @@ def create_data_high(data_path, dataset_list, **kwargs):
     lay_0 = f.get('layer_0')[:] / 1e5
     lay_1 = f.get('layer_1')[:] / 1e5
     lay_2 = f.get('layer_2')[:] / 1e5
-
-    incident_energy = torch.log10(torch.tensor(en_test)*10.)
+    torch_dtype = torch.get_default_dtype()
+
+    incident_energy = torch.log10(torch.tensor(en_test).to(torch_dtype)*10.)
     # scale them back to MeV
-    layer0 = torch.tensor(lay_0) * 1e5
-    layer1 = torch.tensor(lay_1) * 1e5
-    layer2 = torch.tensor(lay_2) * 1e5
+    layer0 = torch.tensor(lay_0).to(torch_dtype) * 1e5
+    layer1 = torch.tensor(lay_1).to(torch_dtype) * 1e5
+    layer2 = torch.tensor(lay_2).to(torch_dtype) * 1e5
     layer0 = to_np_thres(layer0.view(layer0.shape[0], -1), cut)
     layer1 = to_np_thres(layer1.view(layer1.shape[0], -1), cut)
     layer2 = to_np_thres(layer2.view(layer2.shape[0], -1), cut)
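Note: the common thread in both loaders is routing every numpy array through torch.get_default_dtype(), so the resulting tensors follow the dtype: float64 setting of the config instead of keeping the float32 dtype typical of HDF5 files. A minimal standalone illustration (shapes and values are made up):

    import numpy as np
    import torch

    # Assuming the framework sets the default dtype from the config
    # (`dtype: float64`), float32 arrays are upcast consistently.
    torch.set_default_dtype(torch.float64)
    torch_dtype = torch.get_default_dtype()

    lay = np.random.rand(8, 288).astype(np.float32)  # HDF5 layers are often float32
    t = torch.tensor(lay).to(torch_dtype)
    print(t.dtype)  # torch.float64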
96 changes: 64 additions & 32 deletions src/plots.py
@@ -26,7 +26,9 @@ def __init__(
         weights_fake: np.ndarray,
         losses: dict,
         title: str,
-        log_gen_weights: Optional[np.ndarray] = None
+        labels_w_hist: list[str],
+        add_comb: bool,
+        log_gen_weights: Optional[np.ndarray] = None,
     ):
         """
         Initializes the plotting pipeline with the data to be plotted.
@@ -38,27 +40,36 @@ def __init__(
             losses: Dictionary with loss terms and learning rate as a function of the epoch
             title: Title added in all the plots
             log_gen_weights: For Bayesian generators: sampled log weights
+            labels_w_hist: Labels of the weight histograms
+            add_comb: whether to draw the combined-weights histogram line
         """
         self.observables = observables
         self.bayesian = len(weights_true.shape) == 2
-        self.true_mask = np.all(np.isfinite(
-            weights_true if self.bayesian else weights_true[:,None]
-        ), axis=1)
-        self.fake_mask = np.all(np.isfinite(
-            weights_fake if self.bayesian else weights_fake[:,None]
-        ), axis=1)
-        self.weights_true = weights_true[self.true_mask]
-        self.weights_fake = weights_fake[self.fake_mask]
+        self.weights_true, self.weights_fake = self.process_weights(weights_true, weights_fake)
         self.losses = losses
         self.title = title
         self.log_gen_weights = log_gen_weights
+        self.labels_w_hist = labels_w_hist
+        self.add_comb = add_comb
+        self.eps = 1.0e-10

         plt.rc("font", family="serif", size=16)
         plt.rc("axes", titlesize="medium")
         plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
         plt.rc("text", usetex=True)
         self.colors = [f"C{i}" for i in range(10)]

+    def process_weights(self, weights_true, weights_fake):
+        w_comb = np.concatenate((weights_true, weights_fake), axis=0)
+        self.p_low = np.percentile(w_comb[w_comb!=0], 0.5)
+        self.p_high = np.percentile(w_comb[w_comb!=np.inf], 99.5)
+
+        weights_true[weights_true >= self.p_high] = self.p_high
+        weights_fake[weights_fake <= self.p_low] = self.p_low
+
+        weights_true[weights_true <= self.p_low] = self.p_low
+        weights_fake[weights_fake >= self.p_high] = self.p_high
+        return weights_true, weights_fake
+
     def plot_losses(self, file: str):
         """
@@ -115,7 +126,7 @@ def plot_single_loss(
             labels: Labels of the loss curves
             yscale: Y axis scale, "linear" or "log"
         """
-        fig, ax = plt.subplots(figsize=(4,3.5))
+        fig, ax = plt.subplots(figsize=(5,5))
         for i, (curve, label) in enumerate(zip(curves, labels)):
             epochs = np.arange(1, len(curve)+1)
             ax.plot(epochs, curve, label=label)
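Note on the process_weights method added above: the four in-place assignments clip both weight arrays to the joint [0.5, 99.5] percentile window; zeros and infinities are excluded only when computing the percentiles, not removed from the arrays. An equivalent, more compact sketch with np.clip (which returns clipped copies instead of mutating its inputs in place):

    import numpy as np

    def process_weights(self, weights_true, weights_fake):
        # Joint percentile window over both weight samples.
        w_comb = np.concatenate((weights_true, weights_fake), axis=0)
        self.p_low = np.percentile(w_comb[w_comb != 0], 0.5)
        self.p_high = np.percentile(w_comb[w_comb != np.inf], 99.5)
        # Clip both arrays to [p_low, p_high] in one step each.
        return (
            np.clip(weights_true, self.p_low, self.p_high),
            np.clip(weights_fake, self.p_low, self.p_high),
        )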
@@ -185,29 +196,31 @@ def plot_weight_hist(self, file: str):
             file: Output file name
         """
         with PdfPages(file) as pdf:
-            clean_array = lambda a: a[np.isfinite(a)]
             wmin = min(
-                np.min(self.weights_true[self.weights_true != 0]),
-                np.min(self.weights_fake[self.weights_fake != 0])
+                np.min(self.weights_true),
+                np.min(self.weights_fake)
             )
             wmax = max(np.max(self.weights_true), np.max(self.weights_fake))
             self.plot_single_weight_hist(
                 pdf,
                 bins=np.linspace(0, 3, 50),
                 xscale="linear",
-                yscale="linear"
+                yscale="linear",
+                secax=True,
             )
             self.plot_single_weight_hist(
                 pdf,
-                bins=np.logspace(np.log10(wmin), np.log10(wmax), 50),
-                xscale="log",
-                yscale="log"
+                bins=np.logspace(np.log10(self.p_low-self.eps), np.log10(self.p_high+self.eps), 50),
+                xscale="symlog",
+                yscale="log",
+                secax=False,
             )
             self.plot_single_weight_hist(
                 pdf,
                 bins=np.logspace(-2, 1, 50),
                 xscale="log",
-                yscale="log"
+                yscale="log",
+                secax=False,
            )
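Note: the middle histogram now ties its bin range to the percentile window and switches to a symlog x scale, so the clipped weights piling up at p_low and p_high stay inside the plotted range (the self.eps padding keeps the edge bins from swallowing them). A standalone demo of the idea, assuming matplotlib >= 3.3 for the linthresh keyword; the numbers are made up:

    import matplotlib.pyplot as plt
    import numpy as np

    p_low, p_high, eps = 1e-2, 1e2, 1e-10
    w = np.clip(np.random.lognormal(0.0, 2.0, 10_000), p_low, p_high)

    fig, ax = plt.subplots()
    bins = np.logspace(np.log10(p_low - eps), np.log10(p_high + eps), 50)
    ax.hist(w, bins=bins)
    ax.set_xscale("symlog", linthresh=p_low)  # linear below p_low, logarithmic above
    ax.set_yscale("log")
    fig.savefig("weights_symlog.pdf")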


@@ -216,7 +229,8 @@ def plot_single_weight_hist(
         pdf: PdfPages,
         bins: np.ndarray,
         xscale: str,
-        yscale: str
+        yscale: str,
+        secax: bool
     ):
         """
         Plots a single weight histogram.
@@ -226,6 +240,7 @@ def plot_single_weight_hist(
             bins: Numpy array with the bin boundaries
             xscale: X axis scale, "linear" or "log"
             yscale: Y axis scale, "linear" or "log"
+            secax: whether to add a secondary Delta(x) axis (used for the linear plot)
         """
         weights_combined = np.concatenate((self.weights_true, self.weights_fake), axis=0)
         if self.bayesian:
@@ -269,36 +284,53 @@ def plot_single_weight_hist(
             y_combined_err = None

         fig, ax = plt.subplots(figsize=(4, 3.5))
-        self.hist_line(
-            ax,
-            bins,
-            y_combined / np.sum(y_combined),
-            y_combined_err / np.sum(y_combined) if y_combined_err is not None else None,
-            label = "Comb",
-            color = self.colors[0]
-        )
+        if self.add_comb:
+            self.hist_line(
+                ax,
+                bins,
+                y_combined / np.sum(y_combined),
+                y_combined_err / np.sum(y_combined) if y_combined_err is not None else None,
+                label = self.labels_w_hist[0],
+                color = self.colors[0]
+            )
         self.hist_line(
             ax,
             bins,
             y_true / np.sum(y_true),
             y_true_err / np.sum(y_true) if y_true_err is not None else None,
-            label = "Truth",
+            label = self.labels_w_hist[1],
             color = self.colors[1]
         )
         self.hist_line(
             ax,
             bins,
             y_fake / np.sum(y_fake),
             y_fake_err / np.sum(y_fake) if y_fake_err is not None else None,
-            label = "Gen",
+            label = self.labels_w_hist[2],
             color = self.colors[2]
         )
         self.corner_text(ax, self.title, "right", "top")
-        ax.set_xlabel("weight")
-        ax.set_ylabel("normalized")
-        ax.set_xscale(xscale)
+        ax.set_xlabel("$w(x)$")
+        ax.set_ylabel("a.u.")
+        if xscale == 'symlog':
+            ax.set_xscale(xscale, linthresh=self.p_low)
+        else:
+            ax.set_xscale(xscale)
         ax.set_yscale(yscale)
         ax.set_xlim(bins[0], bins[-1])
+
+        # add a secondary Delta(x) = w(x) - 1 axis on top (linear plot only)
+        if secax:
+            def wtoD(x):
+                return x - 1
+
+            def Dtow(x):
+                return x + 1
+
+            sec_ax = ax.secondary_xaxis('top', functions=(wtoD, Dtow))
+            sec_ax.set_xlabel(r'$\Delta(x)$')
+            sec_ax.tick_params()
+
         if yscale == "linear":
             ax.set_ylim(bottom=0)
         ax.legend(frameon=False)
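Note: the secondary axis uses matplotlib's paired forward/inverse transform API (Axes.secondary_xaxis, available since matplotlib 3.1). A minimal standalone version of the Delta(x) = w(x) - 1 axis; the histogram data is made up:

    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots()
    ax.hist(np.random.normal(1.0, 0.2, 10_000), bins=50)
    ax.set_xlabel(r"$w(x)$")

    # The top axis shows Delta(x) = w(x) - 1; the two functions passed to
    # secondary_xaxis must be inverses of each other.
    sec = ax.secondary_xaxis("top", functions=(lambda w: w - 1, lambda d: d + 1))
    sec.set_xlabel(r"$\Delta(x)$")
    fig.savefig("weights_delta_axis.pdf")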