-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
37 lines (31 loc) · 964 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import sys
from src.mgr import manager
from src.dataset_loader import init_dataset
from src.services import init_service
import torch
import random
import os
import ray
# Ray Initialization
# num_cpus = os.cpu_count()
# ray.init(num_cpus = num_cpus)
def _parse_args(argv):
    """Turn the five positional CLI arguments into ``manager`` inputs.

    Args:
        argv: Command-line arguments *excluding* the program name.

    Returns:
        A ``(runName, newRun, serviceType, randomRun, ablationType)`` tuple.
        ``runName`` is ``None`` when the literal "none" (any case, surrounding
        whitespace ignored) was given. ``newRun`` is ``True`` only for "true"
        (any case). The other three values are passed through unchanged as
        strings.

    Raises:
        SystemExit: If the number of arguments is not exactly five.
    """
    args = list(argv)
    if len(args) != 5:
        # Fail with a usage line instead of an opaque tuple-unpacking ValueError.
        sys.exit(
            "usage: python main.py <runName|none> <newRun:true|false> "
            f"<serviceType> <randomRun> <ablationType> (got {len(args)} argument(s))"
        )
    runName, newRun, serviceType, randomRun, ablationType = args
    newRun = newRun.strip().lower() == "true"
    if runName.strip().lower() == "none":
        runName = None
    return runName, newRun, serviceType, randomRun, ablationType


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with *seed*."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every CUDA device; PyTorch documents this as a silent
    # no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


def main(argv=None):
    """Configure the run from CLI arguments, seed all RNGs, then run the service.

    Expects five positional arguments:
    ``runName newRun serviceType randomRun ablationType``.

    The seed is read from ``manager.settingsConfig.train.seed`` after
    ``manager(...)`` has been called. Presumably that call populates the
    configuration (confirm in ``src.mgr``). Any randomness consumed inside
    ``manager(...)`` therefore happens before seeding.

    Args:
        argv: Optional argument list excluding the program name. Defaults to
            ``sys.argv[1:]``, which preserves the original CLI behaviour.

    Raises:
        SystemExit: If the wrong number of arguments is supplied.
    """
    if argv is None:
        argv = sys.argv[1:]
    runName, newRun, serviceType, randomRun, ablationType = _parse_args(argv)
    manager(runName, newRun, serviceType, randomRun, ablationType)
    # Fix randomness using the seed from the loaded settings.
    _seed_everything(manager.settingsConfig.train.seed)
    dataset_loader = init_dataset(manager.dataConfig.loaderName)
    service = init_service(serviceType, manager.service_name, dataset_loader)
    service()
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    main()