From 9d33c2e58d8e5d05cea10bd0de6c18f2b57773e0 Mon Sep 17 00:00:00 2001 From: Theo Heimel Date: Fri, 17 Mar 2023 17:01:53 +0100 Subject: [PATCH] better bayesian plotting --- src/plots.py | 137 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 83 insertions(+), 54 deletions(-) diff --git a/src/plots.py b/src/plots.py index 2317014..ec80674 100644 --- a/src/plots.py +++ b/src/plots.py @@ -117,14 +117,7 @@ def plot_single_loss( ax.plot(epochs, curve, label=label) ax.set_xlabel("epoch") ax.set_ylabel(ylabel) - ax.text( - x = 0.95, - y = 0.95, - s = self.title, - horizontalalignment = "right", - verticalalignment = "top", - transform = ax.transAxes - ) + self.corner_text(ax, self.title, "right", "top") ax.set_yscale(yscale) if any(label is not None for label in labels): ax.legend(loc="center right", frameon=False) @@ -167,23 +160,14 @@ def plot_roc(self, file: str): ax.set_xlabel(r"$\epsilon_S$") ax.set_ylabel(r"$\epsilon_B$") - ax.text( - x = 0.05, - y = 0.95, - s = f"AUC = ${np.mean(auc):.3f} \\pm {np.std(auc):.3f}$" + self.corner_text( + ax, + f"AUC = ${np.mean(auc):.3f} \\pm {np.std(auc):.3f}$" if self.bayesian else f"AUC = {auc:.3f}", - horizontalalignment = "left", - verticalalignment = "top", - transform = ax.transAxes - ) - ax.text( - x = 0.95, - y = 0.05, - s = self.title, - horizontalalignment = "right", - verticalalignment = "bottom", - transform = ax.transAxes + "left", + "top" ) + self.corner_text(ax, self.title, "right", "bottom") plt.savefig(file, bbox_inches="tight") plt.close() @@ -260,12 +244,28 @@ def plot_single_weight_hist( )[0] for i in range(weights_combined.shape[1]) ], axis=1) - y_true = np.mean(true_hists, axis=1) - y_true_err = np.std(true_hists, axis=1) - y_fake = np.mean(fake_hists, axis=1) - y_fake_err = np.std(fake_hists, axis=1) - y_combined = np.mean(combined_hists, axis=1) - y_combined_err = np.std(combined_hists, axis=1) + y_true = np.median(true_hists, axis=1) + y_true_err = np.stack(( + np.quantile(true_hists, 0.159, axis=1), + np.quantile(true_hists, 0.841, axis=1) + ), axis=0) + y_fake = np.median(fake_hists, axis=1) + y_fake_err = np.stack(( + np.quantile(fake_hists, 0.159, axis=1), + np.quantile(fake_hists, 0.841, axis=1) + ), axis=0) + y_combined = np.median(combined_hists, axis=1) + y_combined_err = np.stack(( + np.quantile(combined_hists, 0.159, axis=1), + np.quantile(combined_hists, 0.841, axis=1) + ), axis=0) + + #y_true = np.mean(true_hists, axis=1) + #y_true_err = np.std(true_hists, axis=1) + #y_fake = np.mean(fake_hists, axis=1) + #y_fake_err = np.std(fake_hists, axis=1) + #y_combined = np.mean(combined_hists, axis=1) + #y_combined_err = np.std(combined_hists, axis=1) else: y_true = np.histogram(self.weights_true, bins=bins)[0] @@ -300,19 +300,14 @@ def plot_single_weight_hist( label = "Gen", color = self.colors[2] ) - ax.text( - x = 0.95, - y = 0.95, - s = self.title, - horizontalalignment = "right", - verticalalignment = "top", - transform = ax.transAxes - ) + self.corner_text(ax, self.title, "right", "top") ax.set_xlabel("weight") ax.set_ylabel("normalized") ax.set_xscale(xscale) ax.set_yscale(yscale) ax.set_xlim(bins[0], bins[-1]) + if yscale == "linear": + ax.set_ylim(bottom=0) ax.legend(frameon=False) plt.savefig(pdf, format="pdf", bbox_inches="tight") plt.close() @@ -373,14 +368,7 @@ def plot_weight_pulls(self, file: str): label = "Gen", color = self.colors[2] ) - ax.text( - x = 0.95, - y = 0.95, - s = self.title, - horizontalalignment = "right", - verticalalignment = "top", - transform = ax.transAxes - ) + self.corner_text(ax, self.title, "right", "top") ax.set_xlabel(r"$(\mu - 1) / \sigma$") ax.set_ylabel("normalized") @@ -423,8 +411,13 @@ def plot_single_observable(self, pdf: PdfPages, observable: Observable): density = True )[0] for i in range(self.weights_fake.shape[1]) ], axis=1) - rw_mean = np.mean(rw_hists, axis=1) - rw_std = np.std(rw_hists, axis=1) + #rw_mean = np.mean(rw_hists, axis=1) + #rw_std = np.std(rw_hists, axis=1) + rw_mean = np.median(rw_hists, axis=1) + rw_std = np.stack(( + np.quantile(rw_hists, 0.159, axis=1), + np.quantile(rw_hists, 0.841, axis=1) + ), axis=0) else: rw_mean = np.histogram( observable.fake_data[self.fake_mask], @@ -503,7 +496,7 @@ def plot_single_clustering( """ bins = observable.bins if self.bayesian: - weights_fake = np.mean(self.weights_fake, axis=1) + weights_fake = np.median(self.weights_fake, axis=1) else: weights_fake = self.weights_fake @@ -538,6 +531,30 @@ def plot_single_clustering( self.hist_plot(pdf, lines, bins, observable, show_ratios=False, show_weights=False) + def corner_text( + self, + ax: mpl.axes.Axes, + text: str, + horizontal_pos: str, + vertical_pos: str + ): + ax.text( + x = 0.95 if horizontal_pos == "right" else 0.05, + y = 0.95 if vertical_pos == "top" else 0.05, + s = text, + horizontalalignment = horizontal_pos, + verticalalignment = vertical_pos, + transform = ax.transAxes + ) + # Dummy line for automatic legend placement + plt.plot( + 0.8 if horizontal_pos == "right" else 0.2, + 0.8 if vertical_pos == "top" else 0.2, + transform=ax.transAxes, + color="none" + ) + + def hist_plot( self, pdf: PdfPages, @@ -598,8 +615,12 @@ def hist_plot( ratio = (line.y * scale) / (y_ref * ref_scale) ratio_isnan = np.isnan(ratio) if line.y_err is not None: - ratio_err = np.sqrt((line.y_err / line.y)**2) - ratio_err[ratio_isnan] = 0. + if len(line.y_err.shape) == 2: + ratio_err = (line.y_err * scale) / (y_ref * ref_scale) + ratio_err[:,ratio_isnan] = 0. + else: + ratio_err = np.sqrt((line.y_err / line.y)**2) + ratio_err[ratio_isnan] = 0. else: ratio_err = None ratio[ratio_isnan] = 1. @@ -608,6 +629,7 @@ def hist_plot( axs[0].legend(frameon=False) axs[0].set_ylabel("normalized") axs[0].set_yscale(observable.yscale) + self.corner_text(axs[0], self.title, "right", "top") if show_ratios: axs[1].set_ylabel(r"$\frac{\mathrm{Model}}{\mathrm{Truth}}$") @@ -673,9 +695,16 @@ def hist_line( where = "post", ) if y_err is not None: + if len(y_err.shape) == 2: + y_low = y_err[0] + y_high = y_err[1] + else: + y_low = y - y_err + y_high = y + y_err + ax.step( bins, - dup_last(y + y_err), + dup_last(y_high), color = color, alpha = 0.5, linewidth = 0.5, @@ -683,7 +712,7 @@ def hist_line( ) ax.step( bins, - dup_last(y - y_err), + dup_last(y_low), color = color, alpha = 0.5, linewidth = 0.5, @@ -691,8 +720,8 @@ def hist_line( ) ax.fill_between( bins, - dup_last(y - y_err), - dup_last(y + y_err), + dup_last(y_low), + dup_last(y_high), facecolor = color, alpha = 0.3, step = "post"