Skip to content

Commit

Permalink
bayesian generators
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 21, 2023
1 parent d215749 commit eb173ab
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 59 deletions.
9 changes: 5 additions & 4 deletions src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,18 @@ def main():
weights_true,
weights_fake,
training.losses,
data.label
data.label,
data.test_logw
)
print(" Plotting losses")
plots.plot_losses(doc.add_file(f"losses_{data.suffix}.pdf"))
print(" Plotting ROC")
plots.plot_roc(doc.add_file(f"roc_{data.suffix}.pdf"))
print(" Plotting weights")
plots.plot_weight_hist(doc.add_file(f"weights_{data.suffix}.pdf"))
if plots.bayesian:
print(" Plotting pulls")
plots.plot_weight_pulls(doc.add_file(f"pulls_{data.suffix}.pdf"))
if data.test_logw is not None:
print(" Plotting generator errors")
plots.plot_bgen_weights(doc.add_file(f"gen_errors_{data.suffix}.pdf"))
print(" Plotting observables")
plots.plot_observables(doc.add_file(f"observables_{data.suffix}.pdf"))
lower_thresholds = params.get("lower_cluster_thresholds", [])
Expand Down
3 changes: 3 additions & 0 deletions src/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
import numpy as np

from .observable import Observable
Expand All @@ -16,6 +17,7 @@ class DiscriminatorData:
train_fake: Training data (generated samples)
test_true: Test data (truth samples)
test_fake: Test data (generated samples)
test_logw: For Bayesian generative models: log weight distribution for test data
val_true: Validation data (truth samples)
val_fake: Validation data (generated samples)
observables: List observables for plotting
Expand All @@ -30,3 +32,4 @@ class DiscriminatorData:
val_true: np.ndarray
val_fake: np.ndarray
observables: list[Observable]
test_logw: Optional[np.ndarray] = None
11 changes: 10 additions & 1 deletion src/loaders/prec_inn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def load(params: dict) -> list[DiscriminatorData]:
List of three DiscriminatorData objects, one for each jet multiplicity
"""
true_momenta = pd.read_hdf(params["truth_file"]).to_numpy().reshape(-1, 5, 4)
fake_momenta = pd.read_hdf(params["generated_file"]).to_numpy().reshape(-1, 5, 4)
generated = pd.read_hdf(params["generated_file"]).to_numpy()
fake_momenta = generated[:,:20].reshape(-1, 5, 4)
fake_log_weights = generated[:,20:] if generated.shape[1] > 20 else None
true_momenta = true_momenta[np.all(np.isfinite(true_momenta), axis=(1,2))]
fake_momenta = fake_momenta[np.all(np.isfinite(fake_momenta), axis=(1,2))]
multiplicity_true = np.sum(true_momenta[:,:,0] != 0., axis=1)
Expand All @@ -45,6 +47,7 @@ def load(params: dict) -> list[DiscriminatorData]:
mult = subset["multiplicity"]
subset_true = true_momenta[multiplicity_true == mult][:,:mult]
subset_fake = fake_momenta[multiplicity_fake == mult][:,:mult]
subset_logw = fake_log_weights[multiplicity_fake == mult]
train_true, test_true, val_true = split_data(
subset_true,
params["train_split"],
Expand All @@ -55,6 +58,11 @@ def load(params: dict) -> list[DiscriminatorData]:
params["train_split"],
params["test_split"]
)
train_logw, test_logw, val_logw = split_data(
subset_logw,
params["train_split"],
params["test_split"]
)
preproc_kwargs = {
"norm": {},
"include_momenta": params.get("include_momenta", True),
Expand All @@ -75,6 +83,7 @@ def load(params: dict) -> list[DiscriminatorData]:
val_true = compute_preprocessing(val_true, **preproc_kwargs),
val_fake = compute_preprocessing(val_fake, **preproc_kwargs),
observables = compute_observables(test_true, test_fake),
test_logw = test_logw,
))
return datasets

Expand Down
82 changes: 28 additions & 54 deletions src/plots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from collections import namedtuple
from typing import Optional
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
Expand All @@ -24,7 +25,8 @@ def __init__(
weights_true: np.ndarray,
weights_fake: np.ndarray,
losses: dict,
title: str
title: str,
log_gen_weights: Optional[np.ndarray] = None
):
"""
Initializes the plotting pipeline with the data to be plotted.
Expand All @@ -35,6 +37,7 @@ def __init__(
weights_fake: Discriminator weights for generated samples
losses: Dictionary with loss terms and learning rate as a function of the epoch
title: Title added in all the plots
log_gen_weights: For Bayesian generators: sampled log weights
"""
self.observables = observables
self.bayesian = len(weights_true.shape) == 2
Expand All @@ -48,6 +51,7 @@ def __init__(
self.weights_fake = weights_fake[self.fake_mask]
self.losses = losses
self.title = title
self.log_gen_weights = log_gen_weights

plt.rc("font", family="serif", size=16)
plt.rc("axes", titlesize="medium")
Expand Down Expand Up @@ -313,67 +317,37 @@ def plot_single_weight_hist(
plt.close()


def plot_weight_pulls(self, file: str):
def plot_bgen_weights(self, file: str):
"""
Plots histograms of the weight pulls extracted from the Bayesian network,
defined as (mu - 1) / sigma, where mu and sigma are mean and standard
deviation from sampling over the trainable weight posterior.
Plots 2d histogram of the error on the weights from a Bayesian generator network
against the weights found by the discriminator.
Args:
file: Output file name
"""
assert self.bayesian
bins = np.linspace(-5, 5, 50)
assert self.log_gen_weights is not None

with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)

w_normed_true = self.weights_true / np.mean(self.weights_true)
mu_true = np.mean(w_normed_true, axis=1)
sigma_true = np.std(w_normed_true, axis=1)
pull_true = (mu_true - 1.) / sigma_true

w_normed_fake = self.weights_fake / np.mean(self.weights_fake)
mu_fake = np.mean(w_normed_fake, axis=1)
sigma_fake = np.std(w_normed_fake, axis=1)
pull_fake = (mu_fake - 1.) / sigma_fake

pull_combined = np.concatenate((pull_true, pull_fake))
y_combined, _ = np.histogram(pull_combined, bins=bins, density=True)
y_true, _ = np.histogram(pull_true, bins=bins, density=True)
y_fake, _ = np.histogram(pull_fake, bins=bins, density=True)
w_bgen = np.std(self.log_gen_weights, axis=1)
if self.bayesian:
w_bgen = np.repeat(w_bgen[:,None], self.weights_fake.shape[1], axis=1).flatten()
w_disc = self.weights_fake.flatten()
else:
w_disc = self.weights_fake
x_bins = np.linspace(0,4,30)
y_bins = np.linspace(0,3,30)

fig, ax = plt.subplots(figsize=(4, 3.5))
self.hist_line(
ax,
bins,
y_combined,
y_err = None,
label = "Comb",
color = self.colors[0]
)
self.hist_line(
ax,
bins,
y_true,
y_err = None,
label = "Truth",
color = self.colors[1]
)
self.hist_line(
ax,
bins,
y_fake,
y_err = None,
label = "Gen",
color = self.colors[2]
fig, ax = plt.subplots(figsize=(4,3.5))
ax.hist2d(
w_bgen,
w_disc,
bins=(x_bins, y_bins),
rasterized=True,
norm = mpl.colors.LogNorm(),
density=True,
cmap="jet"
)
self.corner_text(ax, self.title, "right", "top")

ax.set_xlabel(r"$(\mu - 1) / \sigma$")
ax.set_ylabel("normalized")
ax.set_yscale("log")
ax.set_xlim(bins[0], bins[-1])
ax.set_xlabel(r"$\sigma(\log w_\text{gen})$")
ax.set_ylabel(r"$w_\text{disc}$")

plt.savefig(file, bbox_inches="tight")
plt.close()
Expand Down

0 comments on commit eb173ab

Please sign in to comment.