-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
160 lines (133 loc) · 5.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import argparse
import logging
import os
from typing import Any, Dict, cast
import pandas as pd
from io_utils import export_individual, parse_args, pass_run, validate
from ga import GA
from lib.history2vec import History2VecResult
from lib.julia_initializer import JuliaInitializer
def config_logging(target_data: str, mutation_rate: float, population_size: int, cross_rate: float) -> None:
    """Configure file logging for a GA run and log the start message.

    The log file is written to
    ``log/<target_data>_mutation_rate_<mutation_rate>_population_<population_size>_cross_rate_<cross_rate>.log``,
    relative to the current working directory. The parent directory of the
    log file is created if it does not exist yet; ``target_data`` may contain
    path separators (e.g. ``synthetic/rho1_nu2_sSSW``), in which case the
    intermediate directories are created as well.

    Args:
        target_data (str): Name of the target data; used as the log file prefix.
        mutation_rate (float): Mutation rate of the GA (used in the file name and start message).
        population_size (int): Population size of the GA (used in the file name and start message).
        cross_rate (float): Crossover rate of the GA (used in the file name and start message).
    """
    log_path = (
        f"log/{target_data}_mutation_rate_{mutation_rate}"
        f"_population_{population_size}_cross_rate_{cross_rate}.log"
    )
    # basicConfig opens the file right away, so a missing directory would
    # raise FileNotFoundError before anything could be logged.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        filename=log_path,
    )
    logging.info(
        f"Start GA with population_size={population_size}, mutation_rate={mutation_rate}, cross_rate={cross_rate}"
    )
def run(
    target: History2VecResult,
    target_data: str,
    num_generations: int,
    population_size: int,
    mutation_rate: float,
    cross_rate: float,
    jl_main: Any,
    thread_num: int,
    archive_dir: str,
) -> tuple:
    """Run the GA once and return the best individual it reports.

    Args:
        target (History2VecResult): The ten target metrics the GA should fit.
        target_data (str): Name of the target data; only used for logging here.
        num_generations (int): Number of generations.
        population_size (int): Number of individuals in the population.
        mutation_rate (float): Mutation rate.
        cross_rate (float): Crossover rate.
        jl_main (Any): Julia main handle passed through to ``GA``.
        thread_num (int): Number of Julia threads.
        archive_dir (str): Output directory for the archive.

    Returns:
        tuple: ``(min_fitness, target_vec, params, ten_metrics)`` as returned
        by ``GA.run()`` -- the fitness of the best individual, its history
        vector, its parameters, and its ten metrics.
    """
    logging.info(f"Target Data: {target_data}")
    logging.info(f"Target Metrics: {target}")

    ga = GA(
        target=target,
        num_generations=num_generations,
        population_size=population_size,
        mutation_rate=mutation_rate,
        cross_rate=cross_rate,
        jl_main=jl_main,
        thread_num=thread_num,
        archive_dir=archive_dir,
    )

    # Unpack explicitly so a change in GA.run()'s return shape fails here,
    # next to the call, rather than in the caller.
    min_fitness, target_vec, params, ten_metrics = ga.run()
    logging.info(f"min_fitness={min_fitness}, target_vec={target_vec}, params={params}, ten_metrics={ten_metrics}")

    # Only a single GA run is performed, so its result is the best one.
    return min_fitness, target_vec, params, ten_metrics
def main():
    """Load the target metrics and search for the best-fitting parameters with the GA.

    Steps:
        1. Parse and validate the command-line arguments.
        2. Build the target ``History2VecResult``: for ``target_data == "synthetic"``
           from the row of ``../data/synthetic_target.csv`` (averaged per
           ``(rho, nu, s)``) that matches the given arguments, otherwise from the
           first row of ``../data/<target_data>.csv``.
        3. Prepare ``./results/<target_data>/`` and configure logging.
        4. Skip the search when ``pass_run`` says so; otherwise run the GA and
           export the best individual to ``./results/<target_data>/best.csv``.

    Raises:
        IndexError: If no synthetic target row matches the given ``rho``, ``nu`` and ``s``.
    """
    # parse arguments
    parser = argparse.ArgumentParser()
    args = parse_args(parser)
    population_size = args.population_size
    mutation_rate = args.mutation_rate
    cross_rate = args.cross_rate
    validate(population_size, mutation_rate, cross_rate)
    target_data = args.target_data

    # read target data
    if target_data == "synthetic":
        os.makedirs(f"./log/{target_data}", exist_ok=True)
        rho, nu, s = args.rho, args.nu, args.s
        history2vec_results = pd.read_csv("../data/synthetic_target.csv").groupby(["rho", "nu", "s"]).mean()
        matches = history2vec_results.query(f"rho == {rho} and nu == {nu} and s == '{s}'")
        if matches.empty:
            # Same exception type that ``.iloc[0]`` would raise on an empty
            # frame, but with a message that names the offending arguments.
            raise IndexError(f"No row in ../data/synthetic_target.csv matches rho={rho}, nu={nu}, s={s!r}")
        row = matches.iloc[0]
        target = History2VecResult(
            gamma=row.gamma,
            no=row.no,
            nc=row.nc,
            oo=row.oo,
            oc=row.oc,
            c=row.c,
            y=row.y,
            g=row.g,
            r=row.r,
            h=row.h,
        )
        # From here on the name also identifies the parameter combination,
        # so results and logs of different combinations do not collide.
        target_data = f"synthetic/rho{rho}_nu{nu}_s{s}"
        num_generations = 100
    else:
        target_csv = f"../data/{target_data}.csv"
        target_metrics = cast(Dict[str, float], pd.read_csv(target_csv).iloc[0].to_dict())
        target = History2VecResult(**target_metrics)
        num_generations = 500

    # setting output directory
    output_base_dir = f"./results/{target_data}"
    archive_dir = os.path.join(output_base_dir, "archives")
    os.makedirs(archive_dir, exist_ok=True)
    output_fp = os.path.join(output_base_dir, "best.csv")

    # configure logging
    config_logging(target_data, mutation_rate, population_size, cross_rate)

    # check if the run is already finished
    if pass_run(args.force, output_fp):
        logging.info("GA is skipped.")
        print("GA search is skipped. Use --force option to run GA.")
        return
    if args.force:
        print("GA is forced to run. This may overwrite existing result.")
        logging.warning("GA is forced to run.")

    # Set Up Julia
    jl_main, thread_num = JuliaInitializer().initialize()

    min_distance, _, best_individual, _ = run(
        target=target,
        target_data=target_data,
        num_generations=num_generations,
        population_size=population_size,
        mutation_rate=mutation_rate,
        cross_rate=cross_rate,
        jl_main=jl_main,
        thread_num=thread_num,
        archive_dir=archive_dir,
    )

    export_individual(min_distance, best_individual, population_size, mutation_rate, cross_rate, output_fp)
    logging.info(f"Finished GA. Result is dumped to {output_fp}")
# Run the GA search only when this file is executed as a script, so that
# importing the module does not start a search.
if __name__ == "__main__":
    main()