-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathparse_hyperparameters.py
47 lines (38 loc) · 1.48 KB
/
parse_hyperparameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import json
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
def remove_parameter_name_prefix(parameter_name: str) -> str:
    """Return *parameter_name* without its leading ``"params_"`` marker.

    Only a single leading occurrence is removed; names that do not begin
    with the marker are returned unchanged.
    """
    marker = "params_"
    has_marker = parameter_name.startswith(marker)
    return parameter_name[len(marker):] if has_marker else parameter_name


def _split_result_file_name(tuning_result_file: str) -> tuple:
    """Extract ``(language, track, model)`` from a tuning-result file name.

    The ``.csv`` suffix is removed. The segment after the first ``=`` (up
    to any further ``=``) is taken, and it must split on ``-`` into exactly
    three fields.

    Raises:
        ValueError: if the name does not match that pattern. The message
            names the offending file instead of a bare unpacking error.
    """
    stem = tuning_result_file[: -len(".csv")]
    try:
        language, track, model = stem.split("=")[1].split("-")
    except (IndexError, ValueError) as err:
        # NOTE(review): a language/track/model containing "-" also lands
        # here; confirm the naming scheme before relaxing the split.
        raise ValueError(
            f"Cannot parse language/track/model from {tuning_result_file!r}; "
            "expected '<prefix>=<language>-<track>-<model>.csv'"
        ) from err
    return language, track, model


def _best_trial_hyperparameters(results: pd.DataFrame, source: str) -> dict:
    """Return the hyperparameters of the best-scoring trial in *results*.

    The best trial is the first row (in file order) whose ``value`` equals
    the column maximum, so a higher ``value`` is treated as better. The
    ``number``, ``state`` and ``value`` columns are excluded, and the
    ``params_`` prefix is stripped from the remaining column names.

    Args:
        results: one tuning-result table as read from CSV.
        source: file name, used only in error messages.

    Raises:
        KeyError: if the ``number``, ``state`` or ``value`` column is missing.
        ValueError: if no row has a non-NaN ``value``.
    """
    results = results.drop(["number", "state"], axis=1)
    best = results[results["value"] == results["value"].max()].head(1)
    if best.empty:
        raise ValueError(f"{source}: no trial has a non-NaN 'value'")
    # to_dict(orient="records") converts each column separately, so integer
    # hyperparameters stay ints. Taking iloc[0] (a row Series) would upcast
    # an all-numeric row to float64 and write e.g. 3 as 3.0 in the JSON.
    record = best.to_dict(orient="records")[0]
    # Filter on the raw column name, so that a hyperparameter whose stripped
    # name is "value" (or "index") is not discarded by accident.
    return {
        remove_parameter_name_prefix(column): value
        for column, value in record.items()
        if column != "value"
    }


def parse_tuning_results(
    result_path: str = "./tuning",
    output_path: str = "best_hyperparameters.json",
) -> None:
    """Collect the best hyperparameters from every tuning CSV into one JSON file.

    Each ``*.csv`` file in *result_path* is read with its first column as
    the index, and its best trial (highest ``value``) is selected.
    NOTE(review): this assumes every study maximized its objective.
    The trial's hyperparameters are written to *output_path* as
    ``{language: {track: {model: {name: value}}}}``, with the three keys
    parsed from the file name.

    Args:
        result_path: directory scanned for ``.csv`` files.
        output_path: JSON file to (over)write.

    Raises:
        ValueError: if a file name cannot be parsed, or if a file has no
            trial with a non-NaN ``value``.
        KeyError: if a file lacks the ``number``, ``state`` or ``value``
            column.
    """
    best_hyperparameters = defaultdict(lambda: defaultdict(dict))
    # Sorted so the JSON key order does not depend on os.listdir order.
    tuning_result_files = sorted(
        file for file in os.listdir(result_path) if file.endswith(".csv")
    )
    for tuning_result_file in tqdm(tuning_result_files):
        # Parse the name first so a badly named file fails before any I/O.
        language, track, model = _split_result_file_name(tuning_result_file)
        results = pd.read_csv(
            os.path.join(result_path, tuning_result_file), index_col=0
        )
        best_hyperparameters[language][track][model] = (
            _best_trial_hyperparameters(results, tuning_result_file)
        )
    with open(output_path, "w", encoding="utf-8") as hsf:
        json.dump(best_hyperparameters, hsf)


if __name__ == "__main__":
    # Aggregate the tuning results only when run as a script, not on import.
    parse_tuning_results()