-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
94 lines (77 loc) · 4.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# -*- coding: utf-8 -*-
"""
Created on 4/4/2019
@author: RuihongQiu
"""
import os
import argparse
import logging
import torch
import time
from tqdm import tqdm
from dataset import MultiSessionsGraph
from torch_geometric.data import DataLoader
from model import GNNModel
# from sort_pooling_model import SortPoolModel
# from virtual_node_model import VirtualNodeModel
# from set2set_model import Set2SetModel
# from set2set_rnn_model import Set2SetATTModel
# from virtual_node_rnn_model import VirtualNodeRNNModel
from train import forward
from tensorboardX import SummaryWriter
# Logger configuration: each record shows a timestamp, the emitting file and
# its line number ahead of the message.
logging.basicConfig(
    format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
    level=logging.DEBUG,
)

# Command-line options as (flag, add_argument keyword arguments) pairs; they
# are registered on the parser in exactly this order.
_CLI_OPTIONS = (
    ('--dataset', {
        'default': 'diginetica',
        'help': 'dataset name: diginetica/yoochoose1_4/yoochoose1_64/sample',
    }),
    ('--batch_size', {'type': int, 'default': 100, 'help': 'input batch size'}),
    ('--hidden_size', {'type': int, 'default': 100, 'help': 'hidden state size'}),
    ('--epoch', {'type': int, 'default': 10, 'help': 'the number of epochs to train for'}),
    # author's note next to --lr: [0.001, 0.0005, 0.0001]
    ('--lr', {'type': float, 'default': 0.001, 'help': 'learning rate'}),
    ('--lr_dc', {'type': float, 'default': 0.1, 'help': 'learning rate decay rate'}),
    ('--lr_dc_step', {
        'type': int,
        'default': 3,
        'help': 'the number of steps after which the learning rate decay',
    }),
    # author's note next to --l2: [0.001, 0.0005, 0.0001, 0.00005, 0.00001]
    ('--l2', {'type': float, 'default': 1e-5, 'help': 'l2 penalty'}),
    ('--top_k', {'type': int, 'default': 20, 'help': 'top K indicator for evaluation'}),
    ('--momentum', {'type': float, 'default': 0.9, 'help': 'momentum for SGD'}),
)

parser = argparse.ArgumentParser()
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)

# Parsed once at import time; main() reads this module-level namespace.
opt = parser.parse_args()
logging.warning(opt)
def main():
    """Train and evaluate the model configured by the command-line options.

    Reads the module-level ``opt`` namespace. Builds the train and test
    loaders from ``<cwd>/datasets/<opt.dataset>``, writes tensorboardX events
    to ``<cwd>/log/<opt.dataset>/model_log``, and then, for ``opt.epoch``
    epochs, runs one training pass followed by one evaluation pass under
    ``torch.no_grad()``. The two values returned by the evaluation ``forward``
    call are printed after every epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cur_dir = os.getcwd()
    # custom dataset
    train_dataset = MultiSessionsGraph(cur_dir + '/datasets/' + opt.dataset, phrase='train')
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_dataset = MultiSessionsGraph(cur_dir + '/datasets/' + opt.dataset, phrase='test')
    test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False)

    # log_dir = cur_dir + '/log/' + str(opt.dataset) + '/' + str(opt) + '_s2s3_linear_gat8-1_noleaky_' + time.strftime(
    #     "%Y-%m-%d %H:%M:%S", time.localtime())
    log_dir = cur_dir + '/log/' + str(opt.dataset) + '/' + 'model_log'
    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(log_dir, exist_ok=True)
    logging.warning('logging to {}'.format(log_dir))
    writer = SummaryWriter(log_dir)

    # The finally clause closes the writer so buffered events reach disk even
    # if training raises part-way through.
    try:
        # n_node is handed to the model constructor; presumably the size of the
        # item vocabulary of each dataset -- confirm against model.GNNModel.
        if opt.dataset == 'diginetica':
            n_node = 43097
        elif opt.dataset in ('yoochoose1_64', 'yoochoose1_4'):
            n_node = 37483
        else:
            n_node = 309

        model = GNNModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)
        # model = SortPoolModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)
        # model = Set2SetModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)
        # model = GINSet2SetModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)
        # model = VirtualNodeModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)
        # model = Set2SetATTModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)
        # model = VirtualNodeRNNModel(hidden_size=opt.hidden_size, n_node=n_node).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.l2)
        # optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.l2, momentum=opt.momentum)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt.lr_dc_step, gamma=opt.lr_dc)

        logging.warning(model)

        for epoch in tqdm(range(opt.epoch)):
            _, _ = forward(model, train_loader, device, writer, epoch, top_k=opt.top_k, optimizer=optimizer,
                           train_flag=True)
            # Advance the schedule only after the epoch's optimizer updates.
            # Since PyTorch 1.1, stepping the scheduler first skips the initial
            # learning rate and moves every decay one epoch early.
            scheduler.step()
            with torch.no_grad():
                # h, m are presumably the hit-rate and MRR at top_k -- confirm
                # against train.forward.
                h, m = forward(model, test_loader, device, writer, epoch, top_k=opt.top_k, train_flag=False)
            print(h, m)
    finally:
        writer.close()
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()