Merge branch 'celltype_annotation_automl' of https://github.com/OmicsML/dance into celltype_annotation_automl
xingzhongyu committed Nov 8, 2024
2 parents 5d32d2e + 165e9f3 commit 5edd74d
Showing 2 changed files with 208 additions and 14 deletions.
92 changes: 92 additions & 0 deletions dataset_server.json
@@ -0,0 +1,92 @@
{
"cta_actinn": [
"01209dce-3575-4bed-b1df-129f57fbc031",
"055ca631-6ffb-40de-815e-b931e10718c0",
"2a498ace-872a-4935-984b-1afa70fd9886",
"2adb1f8a-a6b1-4909-8ee8-484814e2d4bf",
"3faad104-2ab8-4434-816d-474d8d2641db",
"471647b3-04fe-4c76-8372-3264feb950e8",
"4c4cd77c-8fee-4836-9145-16562a8782fe",
"84230ea4-998d-4aa8-8456-81dd54ce23af",
"8a554710-08bc-4005-87cd-da9675bdc2e7",
"ae29ebd0-1973-40a4-a6af-d15a5f77a80f",
"bc260987-8ee5-4b6e-8773-72805166b3f7",
"bc2a7b3d-f04e-477e-96c9-9d5367d5425c",
"d3566d6a-a455-4a15-980f-45eb29114cab",
"d9b4bc69-ed90-4f5f-99b2-61b0681ba436",
"eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569",
"c7775e88-49bf-4ba2-a03b-93f00447c958",
"456e8b9b-f872-488b-871d-94534090a865",
"738942eb-ac72-44ff-a64b-8943b5ecd8d9",
"a5d95a42-0137-496f-8a60-101e17f263c8",
"71be997d-ff75-41b9-8a9f-1288c865f921"
]
,
"cta_celltypist": [
"01209dce-3575-4bed-b1df-129f57fbc031",
"055ca631-6ffb-40de-815e-b931e10718c0",
"2a498ace-872a-4935-984b-1afa70fd9886",
"2adb1f8a-a6b1-4909-8ee8-484814e2d4bf",
"3faad104-2ab8-4434-816d-474d8d2641db",
"471647b3-04fe-4c76-8372-3264feb950e8",
"4c4cd77c-8fee-4836-9145-16562a8782fe",
"84230ea4-998d-4aa8-8456-81dd54ce23af",
"8a554710-08bc-4005-87cd-da9675bdc2e7",
"ae29ebd0-1973-40a4-a6af-d15a5f77a80f",
"bc260987-8ee5-4b6e-8773-72805166b3f7",
"bc2a7b3d-f04e-477e-96c9-9d5367d5425c",
"d3566d6a-a455-4a15-980f-45eb29114cab",
"d9b4bc69-ed90-4f5f-99b2-61b0681ba436",
"eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569",
"c7775e88-49bf-4ba2-a03b-93f00447c958",
"456e8b9b-f872-488b-871d-94534090a865",
"738942eb-ac72-44ff-a64b-8943b5ecd8d9",
"a5d95a42-0137-496f-8a60-101e17f263c8",
"71be997d-ff75-41b9-8a9f-1288c865f921"
],
"cta_scdeepsort": [
"01209dce-3575-4bed-b1df-129f57fbc031",
"055ca631-6ffb-40de-815e-b931e10718c0",
"2a498ace-872a-4935-984b-1afa70fd9886",
"2adb1f8a-a6b1-4909-8ee8-484814e2d4bf",
"3faad104-2ab8-4434-816d-474d8d2641db",
"471647b3-04fe-4c76-8372-3264feb950e8",
"4c4cd77c-8fee-4836-9145-16562a8782fe",
"84230ea4-998d-4aa8-8456-81dd54ce23af",
"8a554710-08bc-4005-87cd-da9675bdc2e7",
"ae29ebd0-1973-40a4-a6af-d15a5f77a80f",
"bc260987-8ee5-4b6e-8773-72805166b3f7",
"bc2a7b3d-f04e-477e-96c9-9d5367d5425c",
"d3566d6a-a455-4a15-980f-45eb29114cab",
"d9b4bc69-ed90-4f5f-99b2-61b0681ba436",
"eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569",
"c7775e88-49bf-4ba2-a03b-93f00447c958",
"456e8b9b-f872-488b-871d-94534090a865",
"738942eb-ac72-44ff-a64b-8943b5ecd8d9",
"a5d95a42-0137-496f-8a60-101e17f263c8",
"71be997d-ff75-41b9-8a9f-1288c865f921"
]
,
"cta_singlecellnet": [
"01209dce-3575-4bed-b1df-129f57fbc031",
"055ca631-6ffb-40de-815e-b931e10718c0",
"2a498ace-872a-4935-984b-1afa70fd9886",
"2adb1f8a-a6b1-4909-8ee8-484814e2d4bf",
"3faad104-2ab8-4434-816d-474d8d2641db",
"471647b3-04fe-4c76-8372-3264feb950e8",
"4c4cd77c-8fee-4836-9145-16562a8782fe",
"84230ea4-998d-4aa8-8456-81dd54ce23af",
"8a554710-08bc-4005-87cd-da9675bdc2e7",
"ae29ebd0-1973-40a4-a6af-d15a5f77a80f",
"bc260987-8ee5-4b6e-8773-72805166b3f7",
"bc2a7b3d-f04e-477e-96c9-9d5367d5425c",
"d3566d6a-a455-4a15-980f-45eb29114cab",
"d9b4bc69-ed90-4f5f-99b2-61b0681ba436",
"eeacb0c1-2217-4cf6-b8ce-1f0fedf1b569",
"c7775e88-49bf-4ba2-a03b-93f00447c958",
"456e8b9b-f872-488b-871d-94534090a865",
"738942eb-ac72-44ff-a64b-8943b5ecd8d9",
"a5d95a42-0137-496f-8a60-101e17f263c8",
"71be997d-ff75-41b9-8a9f-1288c865f921"
]
}
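
The mapping above pairs each tuning task folder (cta_actinn, cta_celltypist, cta_scdeepsort, cta_singlecellnet) with the dataset ids it covers; get_result_web.py below reads it to decide which results to collect. A minimal sketch of consuming the file (the json loading mirrors the script below; the duplicate check is an illustrative addition, not part of the commit):

import json

# Load the method -> dataset-id mapping defined above.
with open("dataset_server.json") as f:
    collect_datasets = json.load(f)

# Illustrative sanity check: every method should list unique dataset ids.
for method, dataset_ids in collect_datasets.items():
    assert len(dataset_ids) == len(set(dataset_ids)), f"duplicate ids for {method}"
    print(method, len(dataset_ids), "datasets")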
130 changes: 116 additions & 14 deletions get_result_web.py
@@ -3,21 +3,19 @@

import numpy as np
import pandas as pd
from natsort import os_sort_key
from omegaconf import OmegaConf
from sympy import im
from tqdm import tqdm

from dance.utils import try_import

wandb = try_import("wandb")
entity = "xzy11632"
project = "dance-dev"
collect_datasets = {
    "cta_celltypist": [
        "c7775e88-49bf-4ba2-a03b-93f00447c958", "456e8b9b-f872-488b-871d-94534090a865",
        "738942eb-ac72-44ff-a64b-8943b5ecd8d9", "a5d95a42-0137-496f-8a60-101e17f263c8",
        "71be997d-ff75-41b9-8a9f-1288c865f921"
    ]
}
file_root = "/egr/research-dselab/dingjia5/zhongyu/dance/examples/tuning"
with open("dataset_server.json") as f:
    collect_datasets = json.load(f)
file_root = "/home/zyxing/dance/examples/tuning"


def check_identical_strings(string_list):
@@ -52,24 +50,128 @@ def get_sweep_url(step_csv: pd.DataFrame, single=True):
    return sweep_url


import re


def spilt_web(url: str):
    pattern = r"https://wandb\.ai/([^/]+)/([^/]+)/sweeps/([^/]+)"

    match = re.search(pattern, url)

    if match:
        entity = match.group(1)
        project = match.group(2)
        sweep_id = match.group(3)

        return entity, project, sweep_id
    else:
        print(url)
        print("No match found")

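# Illustrative usage (not part of the commit): spilt_web pulls the entity,
# project, and sweep id out of a W&B sweep URL. The sweep id below is a
# made-up placeholder.
#   >>> spilt_web("https://wandb.ai/xzy11632/dance-dev/sweeps/abc123")
#   ('xzy11632', 'dance-dev', 'abc123')
# Note that when the URL does not match, the function only prints a message
# and implicitly returns None, so callers that unpack the result will fail.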

def get_best_method(urls, metric_col="test_acc"):
    all_best_run = None
    all_best_step_name = None
    step_names = ["step2", "step3_0", "step3_1", "step3_2"]

    def get_metric(run):
        if metric_col not in run.summary:
            return float('-inf')
        else:
            return run.summary[metric_col]

    for step_name, url in zip(step_names, urls):
        _, _, sweep_id = spilt_web(url)
        sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
        goal = sweep.config["metric"]["goal"]
        if goal == "maximize":
            best_run = max(sweep.runs, key=get_metric)
        elif goal == "minimize":
            best_run = min(sweep.runs, key=get_metric)
        else:
            raise RuntimeError("choose goal in ['minimize','maximize']")
        if metric_col not in best_run.summary:
            continue
        if all_best_run is None:
            all_best_run = best_run
            all_best_step_name = step_name
        elif all_best_run.summary[metric_col] < best_run.summary[metric_col] and goal == "maximize":
            all_best_run = best_run
            all_best_step_name = step_name
        elif all_best_run.summary[metric_col] > best_run.summary[metric_col] and goal == "minimize":
            all_best_run = best_run
            all_best_step_name = step_name
    return all_best_step_name, all_best_run

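# Illustrative note (not part of the commit): urls is expected in the order
# step2, step3_0, step3_1, step3_2 to line up with step_names, which is how
# write_ans calls it: get_best_method([step2_url] + step3_urls). Within each
# sweep the best run is picked according to the sweep's configured metric
# goal ("maximize" or "minimize"), and the per-step winners are then compared
# on metric_col (test_acc by default).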

def get_best_yaml(step_name, best_run, file_path):
    if step_name == "step2":
        conf = OmegaConf.load(f"{file_path}/pipeline_params_tuning_config.yaml")
        for i, fun in enumerate(conf["pipeline"]):
            if "include" not in fun:
                continue
            type_fun = fun["type"]
            prefix = f"pipeline.{i}.{type_fun}"
            # filtered_dict = {k: v for k, v in b_run.config.items() if k==prefix}.items()[0]
            fun_name = best_run.config[prefix]
            fun['target'] = fun_name
            if 'params' not in fun:
                fun['params'] = {}
            if "default_params" in fun and fun_name in fun["default_params"]:
                fun['params'].update(fun["default_params"][fun_name])
            del fun["include"]
            del fun["default_params"]
    else:
        step3_number = step_name.split("_")[1]
        conf = OmegaConf.load(f"{file_path}/config_yamls/params/{step3_number}_test_acc_params_tuning_config.yaml")
        for i, fun in enumerate(conf['pipeline']):
            if 'params_to_tune' not in fun:
                continue
            target = fun["target"]
            prefix = f"params.{i}.{target}"
            filtered_dict = {k: v for k, v in best_run.config.items() if k.startswith(prefix)}
            for k, v in filtered_dict.items():
                param_name = k.split(".")[-1]
                fun['params_to_tune'][param_name] = v
            if "params" not in fun:
                fun["params"] = {}
            fun["params"].update(fun['params_to_tune'])
            del fun["params_to_tune"]
    return OmegaConf.to_yaml(conf["pipeline"])

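# Illustrative note (not part of the commit): get_best_yaml turns the winning
# run's config back into a concrete pipeline. For step2 it resolves each
# "include" entry to the chosen target and merges that target's default_params;
# for step3_* it copies the tuned values from the best run's config into each
# function's params. The result is the YAML dump of conf["pipeline"].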

def check_exist(file_path):
    file_path = f"{file_path}/results/params/"
    if os.path.exists(file_path) and os.path.isdir(file_path):
        file_num = len(os.listdir(file_path))
        return file_num > 1
    else:
        return False

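# Illustrative note (not part of the commit): check_exist treats a run as
# finished only when <file_path>/results/params/ exists and holds more than
# one file; write_ans uses it below to skip incomplete dataset/method pairs.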

def write_ans():
    ans = []
    for method_folder in tqdm(collect_datasets):
        for dataset_id in collect_datasets[method_folder]:
            file_path = f"{file_root}/{method_folder}/{dataset_id}/results"
            step2_url = get_sweep_url(pd.read_csv(f"{file_path}/pipeline/best_test_acc.csv"))
            file_path = f"{file_root}/{method_folder}/{dataset_id}"
            if not check_exist(file_path):
                continue
            step2_url = get_sweep_url(pd.read_csv(f"{file_path}/results/pipeline/best_test_acc.csv"))
            step3_urls = []
            for i in range(3):
                file_csv = f"{file_path}/params/{i}_best_test_acc.csv"
                file_csv = f"{file_path}/results/params/{i}_best_test_acc.csv"
                if not os.path.exists(file_csv):  # no parameter file for this step
                    print(f"File {file_csv} does not exist, skipping.")
                    continue
                step3_urls.append(get_sweep_url(pd.read_csv(file_csv)))
            step3_str = ",".join(step3_urls)
            step_str = f"step2:{step2_url}|step3:{step3_str}"
            ans.append({"Dataset_id": dataset_id, method_folder: step_str})
            with open('temp_ans.json', 'w') as f:
                json.dump(ans, f, indent=4)
            step_name, best_run = get_best_method([step2_url] + step3_urls)
            best_yaml = get_best_yaml(step_name, best_run, file_path)
            ans.append({"Dataset_id": dataset_id, method_folder: step_str, "best_yaml": best_yaml})
            # with open('temp_ans.json', 'w') as f:
            # json.dump(ans, f,indent=4)
    pd.DataFrame(ans).to_csv("temp_ans.csv")


write_ans()

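For reference, a minimal sketch of inspecting the script's output; the file name temp_ans.csv and the column layout follow write_ans above, and running it assumes the script has finished:

import pandas as pd

# Each row holds the dataset id, a "step2:...|step3:..." string of sweep URLs
# under the method column, and the reconstructed best_yaml pipeline.
ans = pd.read_csv("temp_ans.csv", index_col=0)
print(ans.columns.tolist())
print(ans.head())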