Skip to content

Commit

Permalink
make return wave samples not contain zeros
Browse files Browse the repository at this point in the history
which make training unstable
  • Loading branch information
yoyolicoris authored May 27, 2020
1 parent 1ede586 commit 8fead47
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions data_loader/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def __init__(self,
self.waves = []
self.sr = None
self.files = []
self.file_lengths = []

file_lengths = []

def get_nframes(info_str):
try:
Expand All @@ -96,26 +97,26 @@ def get_nframes(info_str):
filename = os.path.join(self.data_path, f)
f_obj = sf.SoundFile(filename)
self.files.append(f_obj)
self.file_lengths.append(get_nframes(f_obj.extra_info))
file_lengths.append(max(0, get_nframes(f_obj.extra_info) - segment) + 1)

if not self.sr:
self.sr = f_obj.samplerate
else:
assert f_obj.samplerate == self.sr
self.file_lengths = np.array(self.file_lengths)
self.boundaries = np.cumsum(self.file_lengths) / (self.file_lengths.sum() - 1)
self.file_lengths = np.array(file_lengths)
self.boundaries = np.cumsum(self.file_lengths) / self.file_lengths.sum()

# normalization value based on each file
# will updated on the fily
self.max_values = np.zeros_like(self.boundaries)
self.max_values = np.zeros(self.file_lengths.shape).astype(np.float32)

def __len__(self):
return self.size

def __getitem__(self, index):
index = np.digitize(random.uniform(0, 1), self.boundaries)
index = np.digitize(random.uniform(0, 1), self.boundaries, right=True)
f, length = self.files[index], self.file_lengths[index]
pos = random.randint(0, length - 1)
pos = random.randrange(0, length)
f.seek(pos)
x = f.read(self.segment, dtype='float32', always_2d=True, fill_value=0.).mean(1)
max_abs = np.abs(x).max()
Expand Down

0 comments on commit 8fead47

Please sign in to comment.