-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_splits.py
67 lines (55 loc) · 2.33 KB
/
create_splits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import glob
import os
import random
import numpy as np
import shutil
def split(data_dir, create_test_ds=False, seed=None):
    """
    Create train/val(/test) splits from the processed ``.tfrecord`` files.

    The ``.tfrecord`` files found directly inside ``data_dir`` are shuffled and
    moved into folders named ``train``, ``val`` and (optionally) ``test``.
    These folders are created next to ``data_dir``, i.e. in its parent
    directory (for /home/workspace/data/waymo they become
    /home/workspace/data/train, .../val and .../test).

    Split ratios (fractions are truncated with ``int``):
      - with ``create_test_ds``: ~72% train, ~18% val, remainder test
      - otherwise: ~80% train, remainder val

    args:
        - data_dir [str]: data directory, e.g. /home/workspace/data/waymo.
          A trailing path separator is accepted.
        - create_test_ds [bool]: also create a ``test`` split.
        - seed [int | None]: seed for a reproducible shuffle. ``None`` uses
          the module-level ``random`` state (previous behaviour).
    """
    # abspath() also strips a trailing separator; without that,
    # dirname("data/waymo/") would be "data/waymo" and the split folders
    # would be created inside data_dir instead of beside it.
    data_dir = os.path.abspath(data_dir)

    # Sort so that a given seed yields the same split regardless of the
    # arbitrary order in which os.listdir returns entries.
    files = sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.endswith(".tfrecord")
    )
    if not files:
        print("No file found")
        return

    if seed is None:
        random.shuffle(files)
    else:
        random.Random(seed).shuffle(files)

    n = len(files)
    if create_test_ds:
        train_end, val_end = int(n * 0.72), int(n * 0.9)
        splits = {
            "train": files[:train_end],
            "val": files[train_end:val_end],
            "test": files[val_end:],
        }
    else:
        train_end = int(n * 0.80)
        splits = {"train": files[:train_end], "val": files[train_end:]}

    out_root = os.path.dirname(data_dir)
    for folder, paths in splits.items():
        _move_files(paths, os.path.join(out_root, folder))


def _move_files(paths, dest_dir):
    """Move every file in ``paths`` into ``dest_dir``, creating it if needed."""
    os.makedirs(dest_dir, exist_ok=True)
    for src in paths:
        dst = os.path.join(dest_dir, os.path.basename(src))
        if os.path.exists(dst):
            # shutil.move may overwrite an existing destination file
            # (os.rename semantics), so at least report which one.
            print("File already exists: " + dst)
        shutil.move(src, dst)
if __name__ == "__main__":
    def _str_to_bool(value):
        """Parse a command-line boolean string such as 'true', 'no' or '0'."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got %r" % value)

    parser = argparse.ArgumentParser(
        description='Split data into training / validation / testing')
    parser.add_argument('--data_dir', required=True,
                        help='data directory')
    # `--add_test_ds` with no value -> const (True); omitted -> default (True);
    # `--add_test_ds false` -> False. Without the type converter any supplied
    # value, even "false", would arrive as a non-empty (truthy) string.
    parser.add_argument('--add_test_ds', required=False, default=True,
                        nargs='?', const=True, type=_str_to_bool,
                        help='Split into 3 dataset train/validation/test instead train/validation')
    args = parser.parse_args()
    # Forward the flag; otherwise split() always falls back to its
    # create_test_ds=False default and the option has no effect.
    split(args.data_dir, create_test_ds=args.add_test_ds)