Skip to content

Commit

Permalink
bayesian and plotting improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 17, 2023
1 parent 0b467e0 commit 1ae7b08
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 38 deletions.
8 changes: 4 additions & 4 deletions src/loaders/prec_inn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def load(params: dict) -> list[DiscriminatorData]:
multiplicity_fake = np.sum(fake_momenta[:,:,0] != 0., axis=1)

subsets = [
#{"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"},
#{"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"},
{"multiplicity": 3, "label": "$Z+1j$", "suffix": "z1j"},
{"multiplicity": 4, "label": "$Z+2j$", "suffix": "z2j"},
{"multiplicity": 5, "label": "$Z+3j$", "suffix": "z3j"},
]
datasets = []
Expand Down Expand Up @@ -189,8 +189,8 @@ def compute_observables(true_data: np.ndarray, fake_data: np.ndarray) -> list[Ob
fake_data = obs_two_fake[(0,1)].m,
tex_label = r"M_{\mu\mu}",
bins = np.linspace(
np.quantile(obs_two_true[(0,1)].m, 0.005),
np.quantile(obs_two_true[(0,1)].m, 0.995),
70,
110,
50
),
unit = "GeV"
Expand Down
29 changes: 16 additions & 13 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ def __init__(self, input_dim: int, params: dict):
self.bayesian_layers = []

if params["bayesian"]:
layer_class = VBLinear
layer_kwargs = {
"prior_prec": params.get("prior_prec", 1.0),
"std_init": params.get("std_init", -9)
}
else:
layer_class = nn.Linear
layer_kwargs = {}
def make_bayesian_layer(n_in, n_out):
layer = VBLinear(
n_in,
n_out,
prior_prec = params.get("prior_prec", 1.0),
std_init = params.get("std_init", -9)
)
self.bayesian_layers.append(layer)
return layer

activation = {
"relu": nn.ReLU,
Expand All @@ -40,17 +41,19 @@ def __init__(self, input_dim: int, params: dict):
layer_size = input_dim
for i in range(params["layers"] - 1):
hidden_size = params["hidden_size"]
layer = layer_class(layer_size, hidden_size, **layer_kwargs)
if params["bayesian"]:
self.bayesian_layers.append(layer)
if params["bayesian"] and i >= params.get("skip_bayesian_layers", 0):
layer = make_bayesian_layer(layer_size, hidden_size)
else:
layer = nn.Linear(layer_size, hidden_size)
layers.append(layer)
if dropout > 0:
layers.append(nn.Dropout(p=dropout))
layers.append(activation())
layer_size = hidden_size
layer = layer_class(layer_size, 1, **layer_kwargs)
if params["bayesian"]:
self.bayesian_layers.append(layer)
layer = make_bayesian_layer(layer_size, 1)
else:
layer = nn.Linear(layer_size, 1)
layers.append(layer)
self.layers = nn.Sequential(*layers)

Expand Down
50 changes: 36 additions & 14 deletions src/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

Line = namedtuple(
"Line",
["y", "y_err", "y_ref", "y_orig", "label", "color"],
defaults = [None, None, None, None, None]
["y", "y_err", "y_ref", "y_orig", "label", "color", "fill"],
defaults = [None, None, None, None, None, False]
)

class Plots:
Expand Down Expand Up @@ -514,6 +514,7 @@ def plot_single_clustering(
for threshold in upper_thresholds:
masks.append(weights_fake > threshold)
labels.append(f"$w > {threshold}$")
true_hist, _ = np.histogram(observable.true_data, bins=bins, density=True)
hists = [
np.histogram(
observable.fake_data[self.fake_mask][mask],
Expand All @@ -523,8 +524,16 @@ def plot_single_clustering(
for mask in masks
]
lines = [
Line(y=hist, label=label, color=color)
for hist, label, color in zip(hists, labels, self.colors)
Line(
y = true_hist,
label = "Truth",
color = "k",
fill = True
),
*[
Line(y=hist, label=label, color=color)
for hist, label, color in zip(hists, labels, self.colors)
]
]
self.hist_plot(pdf, lines, bins, observable, show_ratios=False, show_weights=False)

Expand Down Expand Up @@ -575,7 +584,8 @@ def hist_plot(
line.y * scale,
line.y_err * scale if line.y_err is not None else None,
label=line.label,
color=line.color
color=line.color,
fill=line.fill
)

ratio_panels = []
Expand Down Expand Up @@ -626,7 +636,8 @@ def hist_line(
y: np.ndarray,
y_err: np.ndarray,
label: str,
color: str
color: str,
fill: bool = False
):
"""
Plot a stepped line for a histogram, optionally with error bars.
Expand All @@ -638,18 +649,29 @@ def hist_line(
y_err: Y errors for the bins
label: Label of the line
color: Color of the line
fill: Filled histogram
"""

dup_last = lambda a: np.append(a, a[-1])

ax.step(
bins,
dup_last(y),
label = label,
color = color,
linewidth = 1.0,
where = "post",
)
if fill:
ax.fill_between(
bins,
dup_last(y),
label = label,
facecolor = color,
step = "post",
alpha = 0.2
)
else:
ax.step(
bins,
dup_last(y),
label = label,
color = color,
linewidth = 1.0,
where = "post",
)
if y_err is not None:
ax.step(
bins,
Expand Down
22 changes: 15 additions & 7 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,23 @@ def train(self):
epoch_lr = self.optimizer.param_groups[0]["lr"]
self.losses["lr"].append(epoch_lr)
if self.bayesian:
self.losses["train_bce_loss"].append(torch.stack(epoch_bce_losses).mean().item())
self.losses["train_kl_loss"].append(torch.stack(epoch_kl_losses).mean().item())
train_bce_loss = torch.stack(epoch_bce_losses).mean().item()
train_kl_loss = torch.stack(epoch_kl_losses).mean().item()
self.losses["train_bce_loss"].append(train_bce_loss)
self.losses["train_kl_loss"].append(train_kl_loss)
self.losses["val_bce_loss"].append(val_bce_loss.item())
self.losses["val_kl_loss"].append(val_kl_loss.item())
print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True)

if val_loss < best_val_loss:
best_val_loss = val_loss
print(f" Epoch {epoch:3d}: train loss {train_loss:.6f} " +
f"(BCE {train_bce_loss:.6f}, KL {train_kl_loss:.6f}), " +
f"val loss {val_loss:.6f} " +
f"(BCE {val_bce_loss:.6f}, KL {val_kl_loss:.6f}), " +
f"LR {epoch_lr:.3e}", flush=True)
else:
print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True)

if val_bce_loss < best_val_loss:
best_val_loss = val_bce_loss
self.save("best")

if checkpoint_interval is not None and (epoch+1) % checkpoint_interval == 0:
Expand Down

0 comments on commit 1ae7b08

Please sign in to comment.