Skip to content

Commit

Permalink
better bayesian plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 17, 2023
1 parent 1ae7b08 commit 9d33c2e
Showing 1 changed file with 83 additions and 54 deletions.
137 changes: 83 additions & 54 deletions src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,7 @@ def plot_single_loss(
ax.plot(epochs, curve, label=label)
ax.set_xlabel("epoch")
ax.set_ylabel(ylabel)
ax.text(
x = 0.95,
y = 0.95,
s = self.title,
horizontalalignment = "right",
verticalalignment = "top",
transform = ax.transAxes
)
self.corner_text(ax, self.title, "right", "top")
ax.set_yscale(yscale)
if any(label is not None for label in labels):
ax.legend(loc="center right", frameon=False)
Expand Down Expand Up @@ -167,23 +160,14 @@ def plot_roc(self, file: str):

ax.set_xlabel(r"$\epsilon_S$")
ax.set_ylabel(r"$\epsilon_B$")
ax.text(
x = 0.05,
y = 0.95,
s = f"AUC = ${np.mean(auc):.3f} \\pm {np.std(auc):.3f}$"
self.corner_text(
ax,
f"AUC = ${np.mean(auc):.3f} \\pm {np.std(auc):.3f}$"
if self.bayesian else f"AUC = {auc:.3f}",
horizontalalignment = "left",
verticalalignment = "top",
transform = ax.transAxes
)
ax.text(
x = 0.95,
y = 0.05,
s = self.title,
horizontalalignment = "right",
verticalalignment = "bottom",
transform = ax.transAxes
"left",
"top"
)
self.corner_text(ax, self.title, "right", "bottom")
plt.savefig(file, bbox_inches="tight")
plt.close()

Expand Down Expand Up @@ -260,12 +244,28 @@ def plot_single_weight_hist(
)[0] for i in range(weights_combined.shape[1])
], axis=1)

y_true = np.mean(true_hists, axis=1)
y_true_err = np.std(true_hists, axis=1)
y_fake = np.mean(fake_hists, axis=1)
y_fake_err = np.std(fake_hists, axis=1)
y_combined = np.mean(combined_hists, axis=1)
y_combined_err = np.std(combined_hists, axis=1)
y_true = np.median(true_hists, axis=1)
y_true_err = np.stack((
np.quantile(true_hists, 0.159, axis=1),
np.quantile(true_hists, 0.841, axis=1)
), axis=0)
y_fake = np.median(fake_hists, axis=1)
y_fake_err = np.stack((
np.quantile(fake_hists, 0.159, axis=1),
np.quantile(fake_hists, 0.841, axis=1)
), axis=0)
y_combined = np.median(combined_hists, axis=1)
y_combined_err = np.stack((
np.quantile(combined_hists, 0.159, axis=1),
np.quantile(combined_hists, 0.841, axis=1)
), axis=0)

#y_true = np.mean(true_hists, axis=1)
#y_true_err = np.std(true_hists, axis=1)
#y_fake = np.mean(fake_hists, axis=1)
#y_fake_err = np.std(fake_hists, axis=1)
#y_combined = np.mean(combined_hists, axis=1)
#y_combined_err = np.std(combined_hists, axis=1)

else:
y_true = np.histogram(self.weights_true, bins=bins)[0]
Expand Down Expand Up @@ -300,19 +300,14 @@ def plot_single_weight_hist(
label = "Gen",
color = self.colors[2]
)
ax.text(
x = 0.95,
y = 0.95,
s = self.title,
horizontalalignment = "right",
verticalalignment = "top",
transform = ax.transAxes
)
self.corner_text(ax, self.title, "right", "top")
ax.set_xlabel("weight")
ax.set_ylabel("normalized")
ax.set_xscale(xscale)
ax.set_yscale(yscale)
ax.set_xlim(bins[0], bins[-1])
if yscale == "linear":
ax.set_ylim(bottom=0)
ax.legend(frameon=False)
plt.savefig(pdf, format="pdf", bbox_inches="tight")
plt.close()
Expand Down Expand Up @@ -373,14 +368,7 @@ def plot_weight_pulls(self, file: str):
label = "Gen",
color = self.colors[2]
)
ax.text(
x = 0.95,
y = 0.95,
s = self.title,
horizontalalignment = "right",
verticalalignment = "top",
transform = ax.transAxes
)
self.corner_text(ax, self.title, "right", "top")

ax.set_xlabel(r"$(\mu - 1) / \sigma$")
ax.set_ylabel("normalized")
Expand Down Expand Up @@ -423,8 +411,13 @@ def plot_single_observable(self, pdf: PdfPages, observable: Observable):
density = True
)[0] for i in range(self.weights_fake.shape[1])
], axis=1)
rw_mean = np.mean(rw_hists, axis=1)
rw_std = np.std(rw_hists, axis=1)
#rw_mean = np.mean(rw_hists, axis=1)
#rw_std = np.std(rw_hists, axis=1)
rw_mean = np.median(rw_hists, axis=1)
rw_std = np.stack((
np.quantile(rw_hists, 0.159, axis=1),
np.quantile(rw_hists, 0.841, axis=1)
), axis=0)
else:
rw_mean = np.histogram(
observable.fake_data[self.fake_mask],
Expand Down Expand Up @@ -503,7 +496,7 @@ def plot_single_clustering(
"""
bins = observable.bins
if self.bayesian:
weights_fake = np.mean(self.weights_fake, axis=1)
weights_fake = np.median(self.weights_fake, axis=1)
else:
weights_fake = self.weights_fake

Expand Down Expand Up @@ -538,6 +531,30 @@ def plot_single_clustering(
self.hist_plot(pdf, lines, bins, observable, show_ratios=False, show_weights=False)


def corner_text(
self,
ax: mpl.axes.Axes,
text: str,
horizontal_pos: str,
vertical_pos: str
):
ax.text(
x = 0.95 if horizontal_pos == "right" else 0.05,
y = 0.95 if vertical_pos == "top" else 0.05,
s = text,
horizontalalignment = horizontal_pos,
verticalalignment = vertical_pos,
transform = ax.transAxes
)
# Dummy line for automatic legend placement
plt.plot(
0.8 if horizontal_pos == "right" else 0.2,
0.8 if vertical_pos == "top" else 0.2,
transform=ax.transAxes,
color="none"
)


def hist_plot(
self,
pdf: PdfPages,
Expand Down Expand Up @@ -598,8 +615,12 @@ def hist_plot(
ratio = (line.y * scale) / (y_ref * ref_scale)
ratio_isnan = np.isnan(ratio)
if line.y_err is not None:
ratio_err = np.sqrt((line.y_err / line.y)**2)
ratio_err[ratio_isnan] = 0.
if len(line.y_err.shape) == 2:
ratio_err = (line.y_err * scale) / (y_ref * ref_scale)
ratio_err[:,ratio_isnan] = 0.
else:
ratio_err = np.sqrt((line.y_err / line.y)**2)
ratio_err[ratio_isnan] = 0.
else:
ratio_err = None
ratio[ratio_isnan] = 1.
Expand All @@ -608,6 +629,7 @@ def hist_plot(
axs[0].legend(frameon=False)
axs[0].set_ylabel("normalized")
axs[0].set_yscale(observable.yscale)
self.corner_text(axs[0], self.title, "right", "top")

if show_ratios:
axs[1].set_ylabel(r"$\frac{\mathrm{Model}}{\mathrm{Truth}}$")
Expand Down Expand Up @@ -673,26 +695,33 @@ def hist_line(
where = "post",
)
if y_err is not None:
if len(y_err.shape) == 2:
y_low = y_err[0]
y_high = y_err[1]
else:
y_low = y - y_err
y_high = y + y_err

ax.step(
bins,
dup_last(y + y_err),
dup_last(y_high),
color = color,
alpha = 0.5,
linewidth = 0.5,
where = "post"
)
ax.step(
bins,
dup_last(y - y_err),
dup_last(y_low),
color = color,
alpha = 0.5,
linewidth = 0.5,
where = "post"
)
ax.fill_between(
bins,
dup_last(y - y_err),
dup_last(y + y_err),
dup_last(y_low),
dup_last(y_high),
facecolor = color,
alpha = 0.3,
step = "post"
Expand Down

0 comments on commit 9d33c2e

Please sign in to comment.