diff --git a/src/plots.py b/src/plots.py index 9aa91f7..9a5ee4e 100644 --- a/src/plots.py +++ b/src/plots.py @@ -140,7 +140,7 @@ def plot_roc(self, file: str): Args: file: Output file name """ - scores = np.concatenate((self.weights_true, self.weights_fake), axis=0) + scores = -np.concatenate((1/self.weights_true, self.weights_fake), axis=0) labels = np.concatenate(( np.ones_like(self.weights_true), np.zeros_like(self.weights_fake)