Commit

Merge branch 'main' of github.com:heidelberg-hepml/discriminator-metric
luigifvr committed Mar 29, 2023
2 parents e539685 + 4c446d0 · commit fcde05e
Showing 2 changed files with 17 additions and 11 deletions.
24 changes: 15 additions & 9 deletions src/loaders/prec_inn.py
@@ -32,8 +32,11 @@ def load(params: dict) -> list[DiscriminatorData]:
     generated = pd.read_hdf(params["generated_file"]).to_numpy()
     fake_momenta = generated[:,:20].reshape(-1, 5, 4)
     fake_log_weights = generated[:,20:] if generated.shape[1] > 20 else None
-    true_momenta = true_momenta[np.all(np.isfinite(true_momenta), axis=(1,2))]
-    fake_momenta = fake_momenta[np.all(np.isfinite(fake_momenta), axis=(1,2))]
+    true_mask = np.all(np.isfinite(true_momenta), axis=(1,2))
+    fake_mask = np.all(np.isfinite(fake_momenta), axis=(1,2))
+    true_momenta = true_momenta[true_mask]
+    fake_momenta = fake_momenta[fake_mask]
+    fake_log_weights = fake_log_weights[fake_mask] if fake_log_weights is not None else None
     multiplicity_true = np.sum(true_momenta[:,:,0] != 0., axis=1)
     multiplicity_fake = np.sum(fake_momenta[:,:,0] != 0., axis=1)
 
@@ -47,7 +50,6 @@ def load(params: dict) -> list[DiscriminatorData]:
         mult = subset["multiplicity"]
         subset_true = true_momenta[multiplicity_true == mult][:,:mult]
         subset_fake = fake_momenta[multiplicity_fake == mult][:,:mult]
-        subset_logw = fake_log_weights[multiplicity_fake == mult]
         train_true, test_true, val_true = split_data(
             subset_true,
             params["train_split"],
@@ -58,11 +60,15 @@ def load(params: dict) -> list[DiscriminatorData]:
             params["train_split"],
             params["test_split"]
         )
-        train_logw, test_logw, val_logw = split_data(
-            subset_logw,
-            params["train_split"],
-            params["test_split"]
-        )
+        if fake_log_weights is None:
+            test_logw = None
+        else:
+            subset_logw = fake_log_weights[multiplicity_fake == mult]
+            train_logw, test_logw, val_logw = split_data(
+                subset_logw,
+                params["train_split"],
+                params["test_split"]
+            )
         preproc_kwargs = {
             "norm": {},
             "include_momenta": params.get("include_momenta", True),
@@ -133,7 +139,7 @@ def compute_preprocessing(
         input_obs.append(mass)
 
     if append_delta_r:
-        drinv = lambda x: np.minimum(1/(x+1e-7), 20)
+        drinv = lambda x: x #np.minimum(1/(x+1e-7), 20)
         if mult > 3:
             input_obs.append(drinv(dr(obs.phi[:,2], obs.phi[:,3], obs.eta[:,2], obs.eta[:,3])))
         if mult > 4:
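
Note on the load() hunks above: the finite-event mask is now computed once and reused, so the generated momenta and their per-event log weights stay aligned, and the weight split is skipped entirely when the generated file carries no weight column. A minimal sketch of that masking pattern, using made-up toy arrays rather than the repository's event format:

import numpy as np

# Toy stand-ins for the arrays handled in load(); shapes are illustrative only.
fake_momenta = np.random.randn(6, 5, 4)
fake_momenta[2, 0, 0] = np.nan                     # one corrupted event
fake_log_weights = np.random.randn(6, 1)           # may also be None

# Build the finite-event mask once and apply it to every per-event array,
# so momenta and log weights keep referring to the same events.
fake_mask = np.all(np.isfinite(fake_momenta), axis=(1, 2))
fake_momenta = fake_momenta[fake_mask]
fake_log_weights = fake_log_weights[fake_mask] if fake_log_weights is not None else None

assert fake_log_weights is None or len(fake_log_weights) == len(fake_momenta)

Filtering the momenta with a freshly recomputed mask while leaving the weights untouched, as the removed lines did, would let the two arrays drift out of alignment whenever an event contains non-finite entries, which appears to be what this change addresses.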
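
The compute_preprocessing() hunk replaces the clipped inverse ΔR observable with the raw value, keeping the old expression as a comment. For reference only, a sketch that assumes dr() is the conventional ΔR distance in (η, φ); the repository's actual helper may be defined differently:

import numpy as np

# Hypothetical stand-in for the dr() helper; assumes the conventional
# Delta R = sqrt(dphi**2 + deta**2) with dphi wrapped into [-pi, pi].
def dr(phi1, phi2, eta1, eta2):
    dphi = (phi1 - phi2 + np.pi) % (2 * np.pi) - np.pi
    return np.sqrt(dphi**2 + (eta1 - eta2)**2)

# Observable transform before this commit (clipped inverse Delta R):
#     drinv = lambda x: np.minimum(1 / (x + 1e-7), 20)
# and after it (raw Delta R, passed through unchanged):
drinv = lambda x: x

Whether the identity transform is a deliberate feature change or a temporary switch is not stated in the commit.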
4 changes: 2 additions & 2 deletions src/train.py
@@ -202,8 +202,8 @@ def train(self):
             print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " +
                   f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True)
 
-            if val_bce_loss < best_val_loss:
-                best_val_loss = val_bce_loss
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
                 self.save("best")
 
             if checkpoint_interval is not None and (epoch+1) % checkpoint_interval == 0:
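
The train.py hunk changes which quantity drives the "best" checkpoint: the full validation loss rather than only its BCE component. A minimal, self-contained sketch of that bookkeeping pattern, with dummy helpers that are not taken from the repository:

import random

# Dummy stand-ins for one training/validation epoch and for checkpointing.
def run_training_epoch() -> float:
    return random.random()

def run_validation_epoch() -> float:
    return random.random()

def save_checkpoint(tag: str) -> None:
    print(f"saving checkpoint: {tag}")

best_val_loss = float("inf")
for epoch in range(5):
    train_loss = run_training_epoch()
    val_loss = run_validation_epoch()
    # After this commit the comparison uses the full validation loss
    # (val_loss) rather than only its BCE part.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint("best")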
