From b0dfaf96b7a16f5c9ebb93f9b46a27c74142b3e5 Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Mon, 28 Aug 2023 16:42:57 +0200 Subject: [PATCH] train with ups dataset3 --- src/data_processing/check_dataset.py | 5 +++++ src/data_processing/upsample_dataset.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/src/data_processing/check_dataset.py b/src/data_processing/check_dataset.py index 0fb7c98..534213a 100644 --- a/src/data_processing/check_dataset.py +++ b/src/data_processing/check_dataset.py @@ -32,6 +32,11 @@ def filter_stationary_sequences_dataset(ds): ds = filter_stationary_sequences_dataset(ds) train_x = ds[:, :7] train_y = ds[:, 7:] +# check for nans/infs +print(np.any(np.isnan(train_x))) +print(np.any(np.isnan(train_y))) +print(np.any(np.isinf(train_x))) +print(np.any(np.isinf(train_y))) # filter_stationary_sequences(train_x, train_y) print(train_x.shape, train_y.shape) # restore mean and std diff --git a/src/data_processing/upsample_dataset.py b/src/data_processing/upsample_dataset.py index 5d1538d..f5a414a 100644 --- a/src/data_processing/upsample_dataset.py +++ b/src/data_processing/upsample_dataset.py @@ -58,6 +58,12 @@ def upsample_dataset( dataset._scale(standardize=True) data = dataset.train_x targets = dataset.train_y + + # shuffle the data + idx = np.random.permutation(len(data)) + data = data[idx] + targets = targets[idx] + return data, targets