From 1ae7b084156bd00b0aab271ec6dff06125e6e367 Mon Sep 17 00:00:00 2001
From: Theo Heimel
Date: Fri, 17 Mar 2023 14:27:25 +0100
Subject: [PATCH] Bayesian and plotting improvements

---
 src/loaders/prec_inn.py |  8 +++----
 src/model.py            | 29 +++++++++++++-----------
 src/plots.py            | 50 +++++++++++++++++++++++++++++------------
 src/train.py            | 24 ++++++++++++++------
 4 files changed, 73 insertions(+), 38 deletions(-)

diff --git a/src/loaders/prec_inn.py b/src/loaders/prec_inn.py
index e5886e6..c09935a 100644
--- a/src/loaders/prec_inn.py
+++ b/src/loaders/prec_inn.py
@@ -36,8 +36,8 @@ def load(params: dict) -> list[DiscriminatorData]:
     multiplicity_fake = np.sum(fake_momenta[:,:,0] != 0., axis=1)
 
     subsets = [
-        #{"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"},
-        #{"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"},
+        {"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"},
+        {"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"},
         {"multiplicity": 5, "label": "$Z+3j$", "suffix": "z3j"},
     ]
     datasets = []
@@ -189,8 +189,8 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob
         fake_data = obs_two_fake[(0,1)].m,
         tex_label = r"M_{\mu\mu}",
         bins = np.linspace(
-            np.quantile(obs_two_true[(0,1)].m, 0.005),
-            np.quantile(obs_two_true[(0,1)].m, 0.995),
+            70,
+            110,
             50
         ),
         unit = "GeV"
diff --git a/src/model.py b/src/model.py
index c542cb9..c5a9723 100644
--- a/src/model.py
+++ b/src/model.py
@@ -20,14 +20,15 @@ def __init__(self, input_dim: int, params: dict):
         self.bayesian_layers = []
 
         if params["bayesian"]:
-            layer_class = VBLinear
-            layer_kwargs = {
-                "prior_prec": params.get("prior_prec", 1.0),
-                "std_init": params.get("std_init", -9)
-            }
-        else:
-            layer_class = nn.Linear
-            layer_kwargs = {}
+            def make_bayesian_layer(n_in, n_out):
+                layer = VBLinear(
+                    n_in,
+                    n_out,
+                    prior_prec = params.get("prior_prec", 1.0),
+                    std_init = params.get("std_init", -9)
+                )
+                self.bayesian_layers.append(layer)
+                return layer
 
         activation = {
             "relu": nn.ReLU,
@@ -40,17 +41,19 @@ def __init__(self, input_dim: int, params: dict):
         layer_size = input_dim
         for i in range(params["layers"] - 1):
             hidden_size = params["hidden_size"]
-            layer = layer_class(layer_size, hidden_size, **layer_kwargs)
-            if params["bayesian"]:
-                self.bayesian_layers.append(layer)
+            if params["bayesian"] and i >= params.get("skip_bayesian_layers", 0):
+                layer = make_bayesian_layer(layer_size, hidden_size)
+            else:
+                layer = nn.Linear(layer_size, hidden_size)
             layers.append(layer)
             if dropout > 0:
                 layers.append(nn.Dropout(p=dropout))
             layers.append(activation())
             layer_size = hidden_size
-        layer = layer_class(layer_size, 1, **layer_kwargs)
         if params["bayesian"]:
-            self.bayesian_layers.append(layer)
+            layer = make_bayesian_layer(layer_size, 1)
+        else:
+            layer = nn.Linear(layer_size, 1)
         layers.append(layer)
         self.layers = nn.Sequential(*layers)
 
diff --git a/src/plots.py b/src/plots.py
index 9a5ee4e..2317014 100644
--- a/src/plots.py
+++ b/src/plots.py
@@ -9,8 +9,8 @@
 
 Line = namedtuple(
     "Line",
-    ["y", "y_err", "y_ref", "y_orig", "label", "color"],
-    defaults = [None, None, None, None, None]
+    ["y", "y_err", "y_ref", "y_orig", "label", "color", "fill"],
+    defaults = [None, None, None, None, None, False]
 )
 
 class Plots:
@@ -514,6 +514,7 @@ def plot_single_clustering(
         for threshold in upper_thresholds:
             masks.append(weights_fake > threshold)
             labels.append(f"$w > {threshold}$")
+        true_hist, _ = np.histogram(observable.true_data, bins=bins, density=True)
         hists = [
             np.histogram(
                 observable.fake_data[self.fake_mask][mask],
@@ -523,8 +524,16 @@ def plot_single_clustering(
             for mask in masks
         ]
         lines = [
-            Line(y=hist, label=label, color=color)
-            for hist, label, color in zip(hists, labels, self.colors)
+            Line(
+                y = true_hist,
+                label = "Truth",
+                color = "k",
+                fill = True
+            ),
+            *[
+                Line(y=hist, label=label, color=color)
+                for hist, label, color in zip(hists, labels, self.colors)
+            ]
         ]
         self.hist_plot(pdf, lines, bins, observable, show_ratios=False, show_weights=False)
 
@@ -575,7 +584,8 @@ def hist_plot(
                 line.y * scale,
                 line.y_err * scale if line.y_err is not None else None,
                 label=line.label,
-                color=line.color
+                color=line.color,
+                fill=line.fill
             )
 
         ratio_panels = []
@@ -626,7 +636,8 @@ def hist_line(
         y: np.ndarray,
         y_err: np.ndarray,
         label: str,
-        color: str
+        color: str,
+        fill: bool = False
     ):
         """
         Plot a stepped line for a histogram, optionally with error bars.
@@ -638,18 +649,29 @@ def hist_line(
             y_err: Y errors for the bins
             label: Label of the line
             color: Color of the line
+            fill: If True, draw a filled histogram instead of a stepped line
        """
 
        dup_last = lambda a: np.append(a, a[-1])
 
-        ax.step(
-            bins,
-            dup_last(y),
-            label = label,
-            color = color,
-            linewidth = 1.0,
-            where = "post",
-        )
+        if fill:
+            ax.fill_between(
+                bins,
+                dup_last(y),
+                label = label,
+                facecolor = color,
+                step = "post",
+                alpha = 0.2
+            )
+        else:
+            ax.step(
+                bins,
+                dup_last(y),
+                label = label,
+                color = color,
+                linewidth = 1.0,
+                where = "post",
+            )
         if y_err is not None:
             ax.step(
                 bins,
diff --git a/src/train.py b/src/train.py
index 10354aa..7dcc993 100644
--- a/src/train.py
+++ b/src/train.py
@@ -187,15 +187,25 @@ def train(self):
         epoch_lr = self.optimizer.param_groups[0]["lr"]
         self.losses["lr"].append(epoch_lr)
         if self.bayesian:
-            self.losses["train_bce_loss"].append(torch.stack(epoch_bce_losses).mean().item())
-            self.losses["train_kl_loss"].append(torch.stack(epoch_kl_losses).mean().item())
+            train_bce_loss = torch.stack(epoch_bce_losses).mean().item()
+            train_kl_loss = torch.stack(epoch_kl_losses).mean().item()
+            self.losses["train_bce_loss"].append(train_bce_loss)
+            self.losses["train_kl_loss"].append(train_kl_loss)
             self.losses["val_bce_loss"].append(val_bce_loss.item())
             self.losses["val_kl_loss"].append(val_kl_loss.item())
-        print(f"    Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
-              f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True)
-
-        if val_loss < best_val_loss:
-            best_val_loss = val_loss
+            print(f"    Epoch {epoch:3d}: train loss {train_loss:.6f} " +
+                  f"(BCE {train_bce_loss:.6f}, KL {train_kl_loss:.6f}), " +
+                  f"val loss {val_loss:.6f} " +
+                  f"(BCE {val_bce_loss:.6f}, KL {val_kl_loss:.6f}), " +
+                  f"LR {epoch_lr:.3e}", flush=True)
+        else:
+            print(f"    Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
+                  f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True)
+
+        # select on the BCE part when available; without a KL term it equals the total loss
+        best_metric = val_bce_loss if self.bayesian else val_loss
+        if best_metric < best_val_loss:
+            best_val_loss = best_metric
             self.save("best")
 
         if checkpoint_interval is not None and (epoch+1) % checkpoint_interval == 0:
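
Note on the new `skip_bayesian_layers` option: its effect is easy to check in
isolation. The sketch below is not code from the repository; it only mirrors
the layer-selection loop of `Discriminator.__init__` in `src/model.py`,
recording the layer class names as strings (the hypothetical helper
`layer_types` and its `params` dict assume the same keys the patch reads).

    def layer_types(params: dict) -> list[str]:
        """Trace which layers the constructor would make Bayesian."""
        types = []
        for i in range(params["layers"] - 1):
            # hidden layers with index below skip_bayesian_layers stay deterministic
            if params["bayesian"] and i >= params.get("skip_bayesian_layers", 0):
                types.append("VBLinear")
            else:
                types.append("nn.Linear")
        # the output layer is always Bayesian when Bayesian mode is on
        types.append("VBLinear" if params["bayesian"] else "nn.Linear")
        return types

    assert layer_types(
        {"layers": 4, "bayesian": True, "skip_bayesian_layers": 2}
    ) == ["nn.Linear", "nn.Linear", "VBLinear", "VBLinear"]

With `skip_bayesian_layers: 2`, only the last hidden layer and the output
layer are variational, which reduces the number of variational parameters
(and hence the KL contribution tracked in src/train.py) while still giving
an uncertainty estimate on the network output.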