-
Notifications
You must be signed in to change notification settings - Fork 1
/
gen_graph.py
62 lines (45 loc) · 1.86 KB
/
gen_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
# Command-line options: the dataset directory name under DATA/, and whether
# every edge is additionally stored in the dst -> src direction.
parser = argparse.ArgumentParser()
parser.add_argument('--data', default='movie', type=str, help='dataset name')
# `--add_reverse` is on by default (a `store_true` flag with default=True can
# never turn it off), so an explicit opt-out flag writing to the same
# destination is provided; `--add_reverse` stays accepted for compatibility.
parser.add_argument('--add_reverse', default=True, action='store_true',
                    help='also store each edge as dst -> src (default)')
parser.add_argument('--no_add_reverse', dest='add_reverse',
                    action='store_false',
                    help='store only the src -> dst direction of each edge')
args = parser.parse_args()
print(args)
# Read the dataset's edge list. Node ids are used directly as array indices,
# so the node count is one past the largest id found in either endpoint column.
df = pd.read_csv('DATA/{}/edges.csv'.format(args.data))
num_nodes = 1 + max(int(df[column].max()) for column in ('src', 'dst'))
print('num_nodes: ', num_nodes)
def _build_csr(src, dst, ts, eid, num_nodes, add_reverse):
    """Group edges by their owning node into CSR (indptr / indices) arrays.

    Edge ``k`` contributes the entry ``(dst[k], ts[k], eid[k])`` to node
    ``src[k]``. When ``add_reverse`` is true it also contributes
    ``(src[k], ts[k], eid[k])`` to node ``dst[k]``. Within one node, entries
    appear in edge order, with an edge's forward copy before its reverse copy.

    Args:
        src, dst: 1-D integer arrays of endpoint node ids, one per edge.
        ts: 1-D array of edge timestamps; its dtype is kept as-is.
        eid: 1-D array of edge ids, one per edge.
        num_nodes: number of nodes, i.e. the length of ``indptr`` minus one.
        add_reverse: whether to also store each edge as ``dst -> src``.

    Returns:
        ``(indptr, indices, ts, eid)``. ``indptr`` is int32 with
        ``num_nodes + 1`` entries. Node ``i`` owns the slice
        ``[indptr[i], indptr[i + 1])`` of the other three (freshly allocated)
        arrays.

    Raises:
        ValueError: if the entry count does not fit an int32 ``indptr``, or
            (from ``np.bincount``) if a node id is negative.
    """
    if add_reverse:
        # Interleave the forward and reverse copy of each edge (edge k becomes
        # positions 2k and 2k + 1). After the stable grouping below, each node
        # then sees its entries in edge order, forward copy first.
        owner = np.stack([src, dst], axis=1).reshape(-1)
        nbr = np.stack([dst, src], axis=1).reshape(-1)
        ts = np.repeat(ts, 2)
        eid = np.repeat(eid, 2)
    else:
        owner, nbr = src, dst
    if owner.size > np.iinfo(np.int32).max:
        raise ValueError(
            'too many adjacency entries ({}) for an int32 indptr'.format(
                owner.size))
    counts = np.bincount(owner, minlength=num_nodes)
    indptr = np.zeros(num_nodes + 1, dtype=np.int32)
    indptr[1:] = np.cumsum(counts)
    # A stable sort groups entries by owner without disturbing their order.
    order = np.argsort(owner, kind='stable')
    return indptr, nbr[order], ts[order], eid[order]


# Build the adjacency column-wise. DataFrame.iterrows() returns each row as a
# Series with a single common dtype, so integer timestamps became float64
# whenever any column was float. It also looped in Python once per edge.
ext_full_indptr, ext_full_indices, ext_full_ts, ext_full_eid = _build_csr(
    df['src'].to_numpy().astype(np.int64),
    df['dst'].to_numpy().astype(np.int64),
    df['time'].to_numpy(),
    # Edge ids are the DataFrame's index labels (0..E-1 for a plain read_csv).
    df.index.to_numpy(),
    num_nodes,
    args.add_reverse,
)
print('Sorting...')
def tsort(i, indptr, indices, t, eid):
    """Sort node ``i``'s CSR segment by timestamp, in place.

    The slice ``indptr[i]:indptr[i + 1]`` of ``indices``, ``t`` and ``eid`` is
    reordered with one shared permutation, so the three arrays stay aligned.
    The sort is stable: entries with equal timestamps keep their current
    relative order (the order they were inserted in). The result is therefore
    deterministic and does not depend on NumPy's default, non-stable sort.

    Args:
        i: index of the node whose segment is sorted.
        indptr: CSR offsets; node ``i`` owns ``[indptr[i], indptr[i + 1])``.
        indices, t, eid: parallel 1-D arrays, modified in place.
    """
    beg, end = indptr[i], indptr[i + 1]
    order = np.argsort(t[beg:end], kind='stable')
    indices[beg:end] = indices[beg:end][order]
    t[beg:end] = t[beg:end][order]
    eid[beg:end] = eid[beg:end][order]
# Put every node's neighbour segment into timestamp order, one node at a time;
# indptr has one more entry than there are nodes.
for node_id in tqdm(range(len(ext_full_indptr) - 1)):
    tsort(node_id, ext_full_indptr, ext_full_indices, ext_full_ts,
          ext_full_eid)
print('saving...')
# np.savez stores each keyword argument as a separately named array in the .npz.
csr_arrays = {
    'indptr': ext_full_indptr,
    'indices': ext_full_indices,
    'ts': ext_full_ts,
    'eid': ext_full_eid,
}
np.savez('DATA/{}/ext_full.npz'.format(args.data), **csr_arrays)