-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_best_network.py
78 lines (57 loc) · 2.06 KB
/
eval_best_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import pytorch_lightning as pl
from problem import DataModule, EvalBestModel, BestModel
from evolution import Optimizer
def parse_args():
    """Build the command-line parser, parse the arguments and add derived sizes.

    Options are registered, in this order, by ``EvalBestModel``,
    ``pl.Trainer``, ``DataModule`` (general, then cache options),
    ``Optimizer`` and ``BestModel`` (model, then learning options), followed
    by this script's own ``--seed`` (default 42) and ``--early_stop``
    (default 0) flags.

    After parsing, these attributes are set on the namespace:

    * ``num_terminal``: ``num_main + 1``
    * ``l_main`` / ``l_adf``: ``h * (max_arity - 1) + 1`` for ``h_main`` /
      ``h_adf`` respectively
    * ``main_length`` / ``adf_length``: ``h_* + l_*``
    * ``chromosome_length`` (also stored as ``D``):
      ``num_main * main_length + num_adf * adf_length``
    * ``mutation_rate``: ``adf_length / chromosome_length``

    Returns:
        argparse.Namespace: the parsed options plus the derived attributes.
    """
    # NOTE(review): num_main, num_adf, h_main, h_adf and max_arity are
    # presumably declared by one of the registrars below -- confirm.
    registrars = (
        EvalBestModel.add_arguments,
        pl.Trainer.add_argparse_args,
        DataModule.add_argparse_args,
        DataModule.add_cache_arguments,
        Optimizer.add_optimizer_specific_args,
        BestModel.add_model_specific_args,
        BestModel.add_learning_specific_args,
    )
    parser = argparse.ArgumentParser()
    for register in registrars:
        # Each registrar returns the parser that the next one must extend.
        parser = register(parser)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--early_stop", type=int, default=0)

    opts = parser.parse_args()

    opts.num_terminal = opts.num_main + 1

    # Both l_* values follow the same formula: h * (max_arity - 1) + 1.
    opts.l_main = opts.h_main * (opts.max_arity - 1) + 1
    opts.l_adf = opts.h_adf * (opts.max_arity - 1) + 1

    main_length = opts.h_main + opts.l_main
    adf_length = opts.h_adf + opts.l_adf
    opts.main_length = main_length
    opts.adf_length = adf_length

    chromosome_length = opts.num_main * main_length + opts.num_adf * adf_length
    opts.chromosome_length = chromosome_length
    opts.D = chromosome_length

    opts.mutation_rate = adf_length / chromosome_length
    return opts
def main_source():
    """Parse the command line and run ``EvalBestModel.evaluate``.

    The evaluator is built from the parsed arguments, its ``progress_bar``
    is set to 10 and its ``weights_summary`` to ``"top"``. ``early_stop`` is
    overridden only when ``--early_stop`` is positive. ``evaluate`` is then
    called with ``None``.
    """
    # NOTE(review): --seed is parsed but not applied anywhere in this
    # function; presumably EvalBestModel consumes args.seed -- confirm.
    cli_args = parse_args()

    evaluator = EvalBestModel(cli_args)
    evaluator.progress_bar = 10
    evaluator.weights_summary = "top"

    # The default of 0 (or any non-positive value) leaves the evaluator's
    # own early_stop setting untouched.
    early_stop = cli_args.early_stop
    if early_stop > 0:
        evaluator.early_stop = early_stop

    evaluator.evaluate(None)
# def main_target():
# # get args
# args = parse_args()
# # load source models
# names, models = amt.util.load_models()
# # solve source problems
# problem = GLUEProblem(args)
# # create optimizer
# optimizer = Optimizer(args)
# # Optimize architecture
# population, fitness = optimizer.transfer_ga(problem, models)
# # build and save model
# lb, ub = problem.get_bounds()
# model = amt.MultinomialModel(population, lb, ub)
# amt.util.save_model(model, args.task_name)
def main():
    """Script entry point: run the source-problem evaluation."""
    main_source()
    # The main_target() pipeline is currently disabled.
    # main_target()
# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()