From 8fead47cc027e586f8538cda8673d849acda9f97 Mon Sep 17 00:00:00 2001 From: chin yun yu Date: Wed, 27 May 2020 11:15:22 +0800 Subject: [PATCH] make return wave samples not contain zeros which make training unstable --- data_loader/data_loaders.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/data_loader/data_loaders.py b/data_loader/data_loaders.py index dcc616c..92b4b53 100644 --- a/data_loader/data_loaders.py +++ b/data_loader/data_loaders.py @@ -79,7 +79,8 @@ def __init__(self, self.waves = [] self.sr = None self.files = [] - self.file_lengths = [] + + file_lengths = [] def get_nframes(info_str): try: @@ -96,26 +97,26 @@ def get_nframes(info_str): filename = os.path.join(self.data_path, f) f_obj = sf.SoundFile(filename) self.files.append(f_obj) - self.file_lengths.append(get_nframes(f_obj.extra_info)) + file_lengths.append(max(0, get_nframes(f_obj.extra_info) - segment) + 1) if not self.sr: self.sr = f_obj.samplerate else: assert f_obj.samplerate == self.sr - self.file_lengths = np.array(self.file_lengths) - self.boundaries = np.cumsum(self.file_lengths) / (self.file_lengths.sum() - 1) + self.file_lengths = np.array(file_lengths) + self.boundaries = np.cumsum(self.file_lengths) / self.file_lengths.sum() # normalization value based on each file # will updated on the fily - self.max_values = np.zeros_like(self.boundaries) + self.max_values = np.zeros(self.file_lengths.shape).astype(np.float32) def __len__(self): return self.size def __getitem__(self, index): - index = np.digitize(random.uniform(0, 1), self.boundaries) + index = np.digitize(random.uniform(0, 1), self.boundaries, right=True) f, length = self.files[index], self.file_lengths[index] - pos = random.randint(0, length - 1) + pos = random.randrange(0, length) f.seek(pos) x = f.read(self.segment, dtype='float32', always_2d=True, fill_value=0.).mean(1) max_abs = np.abs(x).max()