-
Notifications
You must be signed in to change notification settings - Fork 4
/
create_nsynth_dataset_split.py
72 lines (61 loc) · 2.6 KB
/
create_nsynth_dataset_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
import pathlib
import json
import os
from collections import OrderedDict
from sklearn.model_selection import train_test_split
from GANsynth_pytorch.pytorch_nsynth_lib.nsynth import (
NSynth, WavToSpectrogramDataLoader)
REPRODUCIBLE_RANDOM_SEED = 20200117
if __name__ == '__main__':
    # Split one or more NSynth dataset directories into reproducible
    # train/valid subsets, writing one examples.json per split.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_paths', type=str, nargs='+')
    parser.add_argument('--output_directory', type=str, required=True)
    parser.add_argument('--valid_pitch_range', type=int, nargs=2,
                        default=(24, 84))
    parser.add_argument('--train_size', type=float, default=0.8)
    parser.add_argument('--random_seed', type=int,
                        default=REPRODUCIBLE_RANDOM_SEED)
    args = parser.parse_args()
    print(args)

    MAIN_OUTPUT_DIRECTORY = pathlib.Path(args.output_directory)
    # exist_ok=False: fail loudly rather than overwrite a previous split
    os.makedirs(MAIN_OUTPUT_DIRECTORY, exist_ok=False)
    # record the exact parameters used, for reproducibility
    with open(MAIN_OUTPUT_DIRECTORY / 'command_line_parameters.json', 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    # collect metadata from all requested dataset directories
    all_json_data = {}
    for dataset_path in args.dataset_paths:
        dataset_path = pathlib.Path(dataset_path).expanduser().absolute()
        audio_directory_path = dataset_path / 'audio/'
        json_data_path = dataset_path / 'examples.json'
        dataset = NSynth(audio_directory_paths=audio_directory_path,
                         json_data_path=json_data_path,
                         valid_pitch_range=args.valid_pitch_range)
        all_json_data.update(dataset.json_data)

    # sort metadata by sample name so the split is deterministic
    # for a given random seed
    all_json_data = OrderedDict(sorted(all_json_data.items()))

    # create the train/valid split over (sample_name, metadata) pairs
    print('Create json_data split')
    json_data_splits_as_lists = train_test_split(
        list(all_json_data.items()),
        train_size=args.train_size,
        random_state=args.random_seed
    )
    # convert the lists of pairs back to dictionaries
    json_data_train_split, json_data_valid_split = [
        dict(split_as_list)
        for split_as_list in json_data_splits_as_lists
    ]

    # dump each split as <output_directory>/<split_name>/examples.json
    print('Dump splits')
    split_names = ['train', 'valid']
    for split_name, json_data_split in zip(
            split_names,
            [json_data_train_split, json_data_valid_split]):
        split_directory = MAIN_OUTPUT_DIRECTORY / split_name
        os.makedirs(split_directory)
        with open(split_directory / 'examples.json', 'w') as f:
            json.dump(json_data_split, f, indent=4)