diff --git a/README.md b/README.md index 5248fc3..23dbb4b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,9 @@ +# Discriminator Metric for Generative Models + +This tool makes it easy to train a classifier to reweight samples from a generative +model. It comes with an extensive plotting pipeline that allows to inspect the classifier +output to evaluate the performance of the generative network. + ## Usage Training a discriminator: @@ -17,6 +23,8 @@ python -m src --load_model --load_weights 20230303_100000_run_name ## Parameters +The following parameters can be set in the YAML parameter file. + ### Data loader Parameter | Description diff --git a/src/loaders/prec_inn.py b/src/loaders/prec_inn.py index 51b3500..39ef27c 100644 --- a/src/loaders/prec_inn.py +++ b/src/loaders/prec_inn.py @@ -6,6 +6,28 @@ from ..observable import Observable def load(params: dict) -> list[DiscriminatorData]: + """ + Loads the training, test and validation data (truth samples and generated samples) + and computes observables and preprocessed samples for the classifier training. + This loader is written for Z+{1,2,3}j events, as generated in the paper + Generative Networks for Precision Enthusiasts. + + Parameters: + truth_file: Path to an HDF5 file with the truth data + generated_file: Path to an HDF5 file with the generated data + train_split: Fraction of events to be used as training data + test_split: Fraction of events to be used as test data (the remaining events are used + for validation) + include_momenta: Include the full kinematic data in the training + append_mass: Append M_{mu mu} to the training data + append_delta_r: Append the jet delta R to the training data + + Args: + params: Dict with loader parameters (see above) + + Returns: + List of three DiscriminatorData objects, one for each jet multiplicity + """ true_momenta = pd.read_hdf(params["truth_file"]).to_numpy().reshape(-1, 5, 4) fake_momenta = pd.read_hdf(params["generated_file"]).to_numpy().reshape(-1, 5, 4) true_momenta = true_momenta[np.all(np.isfinite(true_momenta), axis=(1,2))]