diff --git a/src/loaders/prec_inn.py b/src/loaders/prec_inn.py index 505835b..310f83f 100644 --- a/src/loaders/prec_inn.py +++ b/src/loaders/prec_inn.py @@ -32,8 +32,11 @@ def load(params: dict) -> list[DiscriminatorData]: generated = pd.read_hdf(params["generated_file"]).to_numpy() fake_momenta = generated[:,:20].reshape(-1, 5, 4) fake_log_weights = generated[:,20:] if generated.shape[1] > 20 else None - true_momenta = true_momenta[np.all(np.isfinite(true_momenta), axis=(1,2))] - fake_momenta = fake_momenta[np.all(np.isfinite(fake_momenta), axis=(1,2))] + true_mask = np.all(np.isfinite(true_momenta), axis=(1,2)) + fake_mask = np.all(np.isfinite(fake_momenta), axis=(1,2)) + true_momenta = true_momenta[true_mask] + fake_momenta = fake_momenta[fake_mask] + fake_log_weights = fake_log_weights[fake_mask] if fake_log_weights is not None else None multiplicity_true = np.sum(true_momenta[:,:,0] != 0., axis=1) multiplicity_fake = np.sum(fake_momenta[:,:,0] != 0., axis=1) @@ -47,7 +50,6 @@ def load(params: dict) -> list[DiscriminatorData]: mult = subset["multiplicity"] subset_true = true_momenta[multiplicity_true == mult][:,:mult] subset_fake = fake_momenta[multiplicity_fake == mult][:,:mult] - subset_logw = fake_log_weights[multiplicity_fake == mult] train_true, test_true, val_true = split_data( subset_true, params["train_split"], @@ -58,11 +60,15 @@ def load(params: dict) -> list[DiscriminatorData]: params["train_split"], params["test_split"] ) - train_logw, test_logw, val_logw = split_data( - subset_logw, - params["train_split"], - params["test_split"] - ) + if fake_log_weights is None: + test_logw = None + else: + subset_logw = fake_log_weights[multiplicity_fake == mult] + train_logw, test_logw, val_logw = split_data( + subset_logw, + params["train_split"], + params["test_split"] + ) preproc_kwargs = { "norm": {}, "include_momenta": params.get("include_momenta", True), @@ -133,7 +139,7 @@ def compute_preprocessing( input_obs.append(mass) if append_delta_r: - drinv = lambda x: np.minimum(1/(x+1e-7), 20) + drinv = lambda x: x #np.minimum(1/(x+1e-7), 20) if mult > 3: input_obs.append(drinv(dr(obs.phi[:,2], obs.phi[:,3], obs.eta[:,2], obs.eta[:,3]))) if mult > 4: diff --git a/src/train.py b/src/train.py index 7dcc993..c084ee1 100644 --- a/src/train.py +++ b/src/train.py @@ -202,8 +202,8 @@ def train(self): print(f" Epoch {epoch:3d}: train loss {train_loss:.6f}, " + f"val loss {val_loss:.6f}, LR {epoch_lr:.3e}", flush=True) - if val_bce_loss < best_val_loss: - best_val_loss = val_bce_loss + if val_loss < best_val_loss: + best_val_loss = val_loss self.save("best") if checkpoint_interval is not None and (epoch+1) % checkpoint_interval == 0: