Skip to content

Commit

Permalink
more documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
theoheimel committed Mar 14, 2023
1 parent b7a85f1 commit 384497f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Discriminator Metric for Generative Models

This tool makes it easy to train a classifier to reweight samples from a generative
model. It comes with an extensive plotting pipeline that allows to inspect the classifier
output to evaluate the performance of the generative network.

## Usage

Training a discriminator:
Expand All @@ -17,6 +23,8 @@ python -m src --load_model --load_weights 20230303_100000_run_name

## Parameters

The following parameters can be set in the YAML parameter file.

### Data loader

Parameter | Description
Expand Down
22 changes: 22 additions & 0 deletions src/loaders/prec_inn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@
from ..observable import Observable

def load(params: dict) -> list[DiscriminatorData]:
"""
Loads the training, test and validation data (truth samples and generated samples)
and computes observables and preprocessed samples for the classifier training.
This loader is written for Z+{1,2,3}j events, as generated in the paper
Generative Networks for Precision Enthusiasts.
Parameters:
truth_file: Path to an HDF5 file with the truth data
generated_file: Path to an HDF5 file with the generated data
train_split: Fraction of events to be used as training data
test_split: Fraction of events to be used as test data (the remaining events are used
for validation)
include_momenta: Include the full kinematic data in the training
append_mass: Append M_{mu mu} to the training data
append_delta_r: Append the jet delta R to the training data
Args:
params: Dict with loader parameters (see above)
Returns:
List of three DiscriminatorData objects, one for each jet multiplicity
"""
true_momenta = pd.read_hdf(params["truth_file"]).to_numpy().reshape(-1, 5, 4)
fake_momenta = pd.read_hdf(params["generated_file"]).to_numpy().reshape(-1, 5, 4)
true_momenta = true_momenta[np.all(np.isfinite(true_momenta), axis=(1,2))]
Expand Down

0 comments on commit 384497f

Please sign in to comment.