-
Notifications
You must be signed in to change notification settings - Fork 9
/
restrict_keys.py
49 lines (41 loc) · 1.46 KB
/
restrict_keys.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import json
import csv
from pathlib import Path
from tqdm import tqdm
import argtyped
class Arguments(argtyped.Arguments):
    """Command-line options for filtering a listing TSV by a set of keys.

    Every field is a path with a default, so the script can run without
    any argument.
    """

    # Text file with one key per line; each key is a run of integers joined
    # by "-", and the second integer is the photo id to keep.
    keys: Path = Path("indoor-keys.txt")
    # Tab-separated input read as (listing_id, photo_id, url, caption) rows.
    input: Path = Path("airbnb-train.tsv")
    # Tab-separated file that receives the rows whose photo id was listed.
    output: Path = Path("airbnb-train-indoor.tsv")
# Column layout shared by the input and output TSV files (no header row).
FIELDNAMES = ("listing_id", "photo_id", "url", "caption")


def load_photo_ids(lines):
    """Collect the photo ids named by an iterable of key lines.

    Each non-blank line must be a sequence of integers joined by "-"; the
    second integer is taken as the photo id (the first is presumably the
    listing id -- confirm against the tool that produces the keys file).

    Args:
        lines: Iterable of strings, e.g. an open text file.

    Returns:
        A set of photo ids as ints.

    Raises:
        ValueError: If any part of a key is not an integer.
        IndexError: If a key has no second part.
    """
    photo_ids = set()
    for line in lines:
        key = line.strip()
        if not key:
            # A blank line (typically a trailing one) is not a key; without
            # this guard int("") would abort the whole run.
            continue
        # Every part is converted, so a malformed key is still rejected
        # even when the malformed part is not the photo id.
        parts = [int(part) for part in key.split("-")]
        photo_ids.add(parts[1])
    return photo_ids


def count_rows(path):
    """Return the number of records the TSV at *path* will yield.

    The file is parsed with the same ``csv.DictReader`` settings as the
    filtering pass so the count matches the rows that pass iterates over
    (``DictReader`` skips completely empty lines).
    """
    with open(path, newline="") as fid:
        reader = csv.DictReader(fid, delimiter="\t", fieldnames=FIELDNAMES)
        return sum(1 for _ in reader)


def copy_matching_rows(rows, writer, photo_ids):
    """Write the rows whose photo id is in *photo_ids*, once per id.

    Args:
        rows: Iterable of dicts with a "photo_id" entry convertible to int.
        writer: Object with a ``writerow(row)`` method (``csv.DictWriter``).
        photo_ids: Photo ids to keep; the caller's set is left untouched.

    Returns:
        The number of rows written.
    """
    # Work on a copy: ids are dropped once written so that a photo id
    # repeated in the input is emitted only the first time it is seen.
    remaining = set(photo_ids)
    written = 0
    for row in rows:
        photo_id = int(row["photo_id"])
        if photo_id in remaining:
            writer.writerow(row)
            remaining.discard(photo_id)
            written += 1
    return written


def main():
    """Filter the input TSV down to the photos listed in the keys file."""
    args = Arguments()
    print(args.to_string(width=80))

    with open(args.keys) as fid:
        photo_ids = load_photo_ids(fid)

    # First pass only counts records so the progress bar has a total.
    total = count_rows(args.input)

    with open(args.input, newline="") as fid, open(args.output, "w", newline="") as out:
        reader = csv.DictReader(fid, delimiter="\t", fieldnames=FIELDNAMES)
        writer = csv.DictWriter(out, delimiter="\t", fieldnames=FIELDNAMES)
        counter = copy_matching_rows(tqdm(reader, total=total), writer, photo_ids)

    print("Add", counter, "items")


if __name__ == "__main__":
    main()