-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_rwe_multi_obj.py
64 lines (47 loc) · 1.95 KB
/
main_rwe_multi_obj.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import pytorch_lightning as pl
from problem import NLPProblemRWEMultiObj, DataModule, LightningRecurrentRWE
from evolution import MultiObjectiveOptimizer
import logging
# Disable every logging call of severity CRITICAL and below across all loggers
# in the process. This silences all standard-level log output, including output
# from any imported library that uses the `logging` module.
logging.disable(logging.CRITICAL)
def parse_args(argv=None):
    """Build the CLI parser, parse the arguments and derive chromosome-layout fields.

    Options are contributed by ``NLPProblemRWEMultiObj``, the PyTorch Lightning
    ``Trainer``, ``DataModule`` (including its cache options),
    ``LightningRecurrentRWE`` (model- and learning-specific options) and
    ``MultiObjectiveOptimizer``, plus a local ``--seed`` option (default 42).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which matches the previous
            behaviour. Passing an explicit list lets tests or other scripts
            reuse this function.

    Returns:
        argparse.Namespace: the parsed arguments, extended with the derived
        attributes ``num_terminal``, ``l_main``, ``l_adf``, ``main_length``,
        ``adf_length``, ``chromosome_length``, ``D`` and ``mutation_rate``.

    Raises:
        SystemExit: via ``parser.error`` if the derived chromosome length is
            not positive (e.g. ``num_main`` and ``num_adf`` both 0). Without
            this check, computing ``mutation_rate`` would fail with
            ``ZeroDivisionError``.
    """
    parser = argparse.ArgumentParser()
    # Each component registers its own options on the shared parser and
    # returns it.
    parser = NLPProblemRWEMultiObj.add_arguments(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = DataModule.add_argparse_args(parser)
    parser = DataModule.add_cache_arguments(parser)
    parser = LightningRecurrentRWE.add_model_specific_args(parser)
    parser = LightningRecurrentRWE.add_learning_specific_args(parser)
    parser = MultiObjectiveOptimizer.add_optimizer_specific_args(parser)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args(argv)

    args.num_terminal = args.num_main + 1

    # Segment sizes: for main and ADF genes, l_* = h_* * (max_arity - 1) + 1,
    # and each gene has length h_* + l_*. (Presumably head/tail lengths of a
    # gene-expression-programming encoding; confirm against the problem.)
    args.l_main = args.h_main * (args.max_arity - 1) + 1
    args.l_adf = args.h_adf * (args.max_arity - 1) + 1
    args.main_length = args.h_main + args.l_main
    args.adf_length = args.h_adf + args.l_adf

    # The whole chromosome is num_main main genes followed by num_adf ADF genes.
    args.chromosome_length = (
        args.num_main * args.main_length + args.num_adf * args.adf_length
    )
    if args.chromosome_length <= 0:
        parser.error(
            "derived chromosome length must be positive "
            f"(got {args.chromosome_length}); check --num_main/--num_adf "
            "and the head/arity settings"
        )

    # The optimizer's decision-space dimensionality equals the chromosome length.
    args.D = args.chromosome_length
    # Mutation rate is the ratio of one ADF gene's length to the full
    # chromosome length.
    args.mutation_rate = args.adf_length / args.chromosome_length
    return args
def main():
    """Run the multi-objective architecture search and report the final population.

    Parses the CLI, then constructs ``NLPProblemRWEMultiObj`` and
    ``MultiObjectiveOptimizer`` from the same arguments and runs
    ``optimizer.ga(problem)``. For each returned individual it prints the
    objective values and the symbolic chromosome, then calls
    ``problem.make_graph`` with the prefix ``"<task_name>.idv_<k>"``, where k
    is 1-based.
    """
    args = parse_args()

    # Build the search problem and the optimizer from the same parsed args.
    problem = NLPProblemRWEMultiObj(args)
    optimizer = MultiObjectiveOptimizer(args)

    # Optimize the architecture. Presumably objs[k] holds the objective values
    # of population[k]; confirm against MultiObjectiveOptimizer.ga.
    population, objs = optimizer.ga(problem)

    for number, (idv, obj) in enumerate(zip(population, objs), start=1):
        # Only the symbolic chromosome is needed here; the other two return
        # values are ignored.
        symbols, _, _ = problem.replace_value_with_symbol(idv)
        print(f"Individual {number}: {obj}, chromosome: {symbols}")
        problem.make_graph(idv, prefix=f"{args.task_name}.idv_{number}")

    # TODO: persisting the final population is not implemented. An earlier
    # draft used problem.get_bounds() with amt.MultinomialModel and
    # amt.util.save_model, but `amt` is not imported by this script.
# Run the search only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()