diff --git a/src/__main__.py b/src/__main__.py
index c6a906b..2f5b921 100644
--- a/src/__main__.py
+++ b/src/__main__.py
@@ -85,7 +85,8 @@ def main():
             weights_true,
             weights_fake,
             training.losses,
-            data.label
+            data.label,
+            data.test_logw
         )
         print("    Plotting losses")
         plots.plot_losses(doc.add_file(f"losses_{data.suffix}.pdf"))
@@ -93,9 +94,9 @@ def main():
         plots.plot_roc(doc.add_file(f"roc_{data.suffix}.pdf"))
         print("    Plotting weights")
         plots.plot_weight_hist(doc.add_file(f"weights_{data.suffix}.pdf"))
-        if plots.bayesian:
-            print("    Plotting pulls")
-            plots.plot_weight_pulls(doc.add_file(f"pulls_{data.suffix}.pdf"))
+        if data.test_logw is not None:
+            print("    Plotting generator errors")
+            plots.plot_bgen_weights(doc.add_file(f"gen_errors_{data.suffix}.pdf"))
         print("    Plotting observables")
         plots.plot_observables(doc.add_file(f"observables_{data.suffix}.pdf"))
         lower_thresholds = params.get("lower_cluster_thresholds", [])
diff --git a/src/dataset.py b/src/dataset.py
index 1ee0bb0..3cf9eb7 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Optional
 import numpy as np
 
 from .observable import Observable
@@ -16,6 +17,7 @@ class DiscriminatorData:
         train_fake: Training data (generated samples)
         test_true: Test data (truth samples)
         test_fake: Test data (generated samples)
+        test_logw: For Bayesian generative models: sampled log weights for the test data
         val_true: Validation data (truth samples)
         val_fake: Validation data (generated samples)
         observables: List observables for plotting
@@ -30,3 +32,4 @@ class DiscriminatorData:
     val_true: np.ndarray
     val_fake: np.ndarray
     observables: list[Observable]
+    test_logw: Optional[np.ndarray] = None
diff --git a/src/loaders/prec_inn.py b/src/loaders/prec_inn.py
index c09935a..505835b 100644
--- a/src/loaders/prec_inn.py
+++ b/src/loaders/prec_inn.py
@@ -29,7 +29,13 @@ def load(params: dict) -> list[DiscriminatorData]:
         List of three DiscriminatorData objects, one for each jet multiplicity
     """
     true_momenta = pd.read_hdf(params["truth_file"]).to_numpy().reshape(-1, 5, 4)
-    fake_momenta = pd.read_hdf(params["generated_file"]).to_numpy().reshape(-1, 5, 4)
+    generated = pd.read_hdf(params["generated_file"]).to_numpy()
+    fake_momenta = generated[:,:20].reshape(-1, 5, 4)
+    fake_log_weights = generated[:,20:] if generated.shape[1] > 20 else None
     true_momenta = true_momenta[np.all(np.isfinite(true_momenta), axis=(1,2))]
-    fake_momenta = fake_momenta[np.all(np.isfinite(fake_momenta), axis=(1,2))]
+    # Apply the same finiteness mask to momenta and log weights to keep them aligned
+    fake_mask = np.all(np.isfinite(fake_momenta), axis=(1,2))
+    fake_momenta = fake_momenta[fake_mask]
+    if fake_log_weights is not None:
+        fake_log_weights = fake_log_weights[fake_mask]
     multiplicity_true = np.sum(true_momenta[:,:,0] != 0., axis=1)
@@ -45,6 +51,10 @@ def load(params: dict) -> list[DiscriminatorData]:
         mult = subset["multiplicity"]
         subset_true = true_momenta[multiplicity_true == mult][:,:mult]
         subset_fake = fake_momenta[multiplicity_fake == mult][:,:mult]
+        subset_logw = (
+            fake_log_weights[multiplicity_fake == mult]
+            if fake_log_weights is not None else None
+        )
         train_true, test_true, val_true = split_data(
             subset_true,
             params["train_split"],
@@ -55,6 +65,14 @@ def load(params: dict) -> list[DiscriminatorData]:
             params["train_split"],
             params["test_split"]
         )
+        if subset_logw is not None:
+            train_logw, test_logw, val_logw = split_data(
+                subset_logw,
+                params["train_split"],
+                params["test_split"]
+            )
+        else:
+            test_logw = None
         preproc_kwargs = {
             "norm": {},
             "include_momenta": params.get("include_momenta", True),
@@ -75,6 +93,7 @@ def load(params: dict) -> list[DiscriminatorData]:
             val_true = compute_preprocessing(val_true, **preproc_kwargs),
             val_fake = compute_preprocessing(val_fake, **preproc_kwargs),
             observables = compute_observables(test_true, test_fake),
+            test_logw = test_logw,
         ))
     return datasets
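Note on the assumed layout of the generated file: each row is expected to carry 20 momentum columns (5 particles x 4 components), optionally followed by one column per log-weight sample from a Bayesian generator. A minimal sketch of the loader's parsing and masking logic, using synthetic arrays in place of the real HDF5 file (the event and sample counts are made up for illustration):

    import numpy as np

    # Synthetic stand-in for pd.read_hdf(params["generated_file"]).to_numpy():
    # 20 momentum columns (5 particles x 4 components) plus, optionally,
    # one column per sampled log weight from the Bayesian generator.
    n_events, n_logw_samples = 1000, 30
    generated = np.concatenate([
        np.random.randn(n_events, 20),
        np.random.randn(n_events, n_logw_samples),
    ], axis=1)

    fake_momenta = generated[:, :20].reshape(-1, 5, 4)
    fake_log_weights = generated[:, 20:] if generated.shape[1] > 20 else None

    # Non-finite events are dropped; the same mask has to be applied to the
    # log weights so both arrays stay index-aligned.
    fake_mask = np.all(np.isfinite(fake_momenta), axis=(1, 2))
    fake_momenta = fake_momenta[fake_mask]
    if fake_log_weights is not None:
        fake_log_weights = fake_log_weights[fake_mask]
    print(fake_momenta.shape)  # (1000, 5, 4)

Files without the extra columns simply yield fake_log_weights = None, and all downstream log-weight handling is skipped.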
diff --git a/src/plots.py b/src/plots.py
index 37c2242..8e27a50 100644
--- a/src/plots.py
+++ b/src/plots.py
@@ -1,5 +1,6 @@
 import warnings
 from collections import namedtuple
+from typing import Optional
 import numpy as np
 import matplotlib.pyplot as plt
 import matplotlib as mpl
@@ -24,7 +25,8 @@ def __init__(
         weights_true: np.ndarray,
         weights_fake: np.ndarray,
         losses: dict,
-        title: str
+        title: str,
+        log_gen_weights: Optional[np.ndarray] = None
     ):
         """
         Initializes the plotting pipeline with the data to be plotted.
@@ -35,6 +37,7 @@
             weights_fake: Discriminator weights for generated samples
             losses: Dictionary with loss terms and learning rate as a function of the epoch
             title: Title added in all the plots
+            log_gen_weights: For Bayesian generators: sampled log weights
         """
         self.observables = observables
         self.bayesian = len(weights_true.shape) == 2
@@ -48,6 +51,10 @@
         self.weights_fake = weights_fake[self.fake_mask]
         self.losses = losses
         self.title = title
+        # Apply the same mask as for weights_fake so both stay index-aligned
+        if log_gen_weights is not None:
+            log_gen_weights = log_gen_weights[self.fake_mask]
+        self.log_gen_weights = log_gen_weights
 
         plt.rc("font", family="serif", size=16)
         plt.rc("axes", titlesize="medium")
@@ -313,67 +320,39 @@ def plot_single_weight_hist(
         plt.close()
 
-    def plot_weight_pulls(self, file: str):
+    def plot_bgen_weights(self, file: str):
         """
-        Plots histograms of the weight pulls extracted from the Bayesian network,
-        defined as (mu - 1) / sigma, where mu and sigma are mean and standard
-        deviation from sampling over the trainable weight posterior.
+        Plots a 2d histogram of the spread of the log weights sampled from a
+        Bayesian generator network against the weights assigned by the discriminator.
 
         Args:
             file: Output file name
         """
-        assert self.bayesian
-        bins = np.linspace(-5, 5, 50)
+        assert self.log_gen_weights is not None
 
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", RuntimeWarning)
-
-            w_normed_true = self.weights_true / np.mean(self.weights_true)
-            mu_true = np.mean(w_normed_true, axis=1)
-            sigma_true = np.std(w_normed_true, axis=1)
-            pull_true = (mu_true - 1.) / sigma_true
-
-            w_normed_fake = self.weights_fake / np.mean(self.weights_fake)
-            mu_fake = np.mean(w_normed_fake, axis=1)
-            sigma_fake = np.std(w_normed_fake, axis=1)
-            pull_fake = (mu_fake - 1.) / sigma_fake
-
-        pull_combined = np.concatenate((pull_true, pull_fake))
-        y_combined, _ = np.histogram(pull_combined, bins=bins, density=True)
-        y_true, _ = np.histogram(pull_true, bins=bins, density=True)
-        y_fake, _ = np.histogram(pull_fake, bins=bins, density=True)
+        # Per-event spread of the log weights sampled from the generator
+        w_bgen = np.std(self.log_gen_weights, axis=1)
+        if self.bayesian:
+            # Repeat the spread once per Bayesian discriminator draw
+            w_bgen = np.repeat(w_bgen[:,None], self.weights_fake.shape[1], axis=1).flatten()
+            w_disc = self.weights_fake.flatten()
+        else:
+            w_disc = self.weights_fake
+        x_bins = np.linspace(0, 4, 30)
+        y_bins = np.linspace(0, 3, 30)
 
-        fig, ax = plt.subplots(figsize=(4, 3.5))
-        self.hist_line(
-            ax,
-            bins,
-            y_combined,
-            y_err = None,
-            label = "Comb",
-            color = self.colors[0]
-        )
-        self.hist_line(
-            ax,
-            bins,
-            y_true,
-            y_err = None,
-            label = "Truth",
-            color = self.colors[1]
-        )
-        self.hist_line(
-            ax,
-            bins,
-            y_fake,
-            y_err = None,
-            label = "Gen",
-            color = self.colors[2]
-        )
-        self.corner_text(ax, self.title, "right", "top")
-
-        ax.set_xlabel(r"$(\mu - 1) / \sigma$")
-        ax.set_ylabel("normalized")
-        ax.set_yscale("log")
-        ax.set_xlim(bins[0], bins[-1])
+        fig, ax = plt.subplots(figsize=(4, 3.5))
+        ax.hist2d(
+            w_bgen,
+            w_disc,
+            bins=(x_bins, y_bins),
+            rasterized=True,
+            norm=mpl.colors.LogNorm(),
+            density=True,
+            cmap="jet"
+        )
+        ax.set_xlabel(r"$\sigma(\log w_\text{gen})$")
+        ax.set_ylabel(r"$w_\text{disc}$")
         plt.savefig(file, bbox_inches="tight")
         plt.close()
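For reference, a self-contained sketch of what the new plot_bgen_weights method computes, with toy arrays standing in for the real discriminator output; the array shapes, the random toy data, and the output file name are invented for illustration, and \mathrm replaces \text so the labels render with plain matplotlib mathtext:

    import numpy as np
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    rng = np.random.default_rng(0)
    n_events, n_logw_samples, n_disc_draws = 5000, 30, 10

    # Toy stand-ins: log weights sampled from a Bayesian generator and
    # weights from a Bayesian discriminator (one column per draw)
    log_gen_weights = rng.normal(0.0, 0.5, size=(n_events, n_logw_samples))
    weights_fake = rng.lognormal(0.0, 0.3, size=(n_events, n_disc_draws))

    # Per-event spread of the generator log weights ...
    w_bgen = np.std(log_gen_weights, axis=1)
    # ... repeated once per discriminator draw so both arrays align
    w_bgen = np.repeat(w_bgen[:, None], weights_fake.shape[1], axis=1).flatten()
    w_disc = weights_fake.flatten()

    fig, ax = plt.subplots(figsize=(4, 3.5))
    ax.hist2d(
        w_bgen,
        w_disc,
        bins=(np.linspace(0, 4, 30), np.linspace(0, 3, 30)),
        norm=mpl.colors.LogNorm(),
        density=True,
        rasterized=True,
        cmap="jet",
    )
    ax.set_xlabel(r"$\sigma(\log w_\mathrm{gen})$")
    ax.set_ylabel(r"$w_\mathrm{disc}$")
    plt.savefig("gen_errors_toy.pdf", bbox_inches="tight")
    plt.close()

Broadcasting the per-event spread across the discriminator draws keeps the two flattened arrays the same length, which is what ax.hist2d requires.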