Merge tag 'itrb-deployment-20240309' into production
edeutsch committed Apr 12, 2024
2 parents 67a1640 + 425f1ed commit 070a029
Showing 6 changed files with 277 additions and 60 deletions. Only 4 of the 6 changed files are shown below.
2 changes: 2 additions & 0 deletions code/ARAX/.gitignore
@@ -1 +1,3 @@
old/*

Testing/sample_kg2_queries
2 changes: 1 addition & 1 deletion code/ARAX/ARAXQuery/ARAX_query.py
@@ -835,7 +835,7 @@ def execute_processing_plan(self,input_operations_dict, mode='ARAX'):
return response

#### Immediately after resultify, run the experimental ranker
if action['command'] == 'resultify':
if action['command'] == 'resultify' and mode != 'RTXKG2':
response.info(f"Running experimental reranker on results")
try:
ranker = ARAXRanker()
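
The one-line change above gates the experimental reranker so that it still runs immediately after a 'resultify' step for ARAX-mode queries but is skipped when execute_processing_plan runs in RTXKG2 mode. A minimal sketch of that control flow follows; the helper name is invented, response and ranker_cls are stand-ins, and the aggregate_scores_dmk call is an assumption (the line after ranker = ARAXRanker() is not shown in the hunk) based on the ranker method modified later in this commit.

# Hypothetical sketch of the gating added above, not the project's actual code.
def maybe_rerank(action, mode, response, ranker_cls):
    # Run the experimental reranker right after 'resultify', except for plain KG2 queries.
    if action["command"] == "resultify" and mode != "RTXKG2":
        response.info("Running experimental reranker on results")
        ranker_cls().aggregate_scores_dmk(response)  # assumed call; see lead-in
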
49 changes: 40 additions & 9 deletions code/ARAX/ARAXQuery/ARAX_ranker.py
@@ -9,6 +9,7 @@
import ast
import re


from typing import Set, Union, Dict, List, Callable
from ARAX_response import ARAXResponse
from query_graph_info import QueryGraphInfo
@@ -19,6 +20,7 @@
from openapi_server.models.edge import Edge
from openapi_server.models.attribute import Attribute

edge_confidence_manual_agent = 0.999

def _get_nx_edges_by_attr(G: Union[nx.MultiDiGraph, nx.MultiGraph], key: str, val: str) -> Set[tuple]:
res_set = set()
@@ -46,6 +48,7 @@ def _get_weighted_graph_networkx_from_result_graph(kg_edge_id_to_edge: Dict[str,
qg_edge_key_to_edge_tuple = {edge_tuple[2]: edge_tuple for edge_tuple in qg_edge_tuples}
for analysis in result.analyses: # For now we only ever have one Analysis per Result
for key, edge_binding_list in analysis.edge_bindings.items():
edge_count = 0
for edge_binding in edge_binding_list:
kg_edge = kg_edge_id_to_edge[edge_binding.id]
kg_edge_conf = kg_edge.confidence
@@ -54,6 +57,9 @@
for qedge_key in qedge_keys:
qedge_tuple = qg_edge_key_to_edge_tuple[qedge_key]
res_graph[qedge_tuple[0]][qedge_tuple[1]][qedge_key]['weight'] += kg_edge_conf
edge_count += 1
if edge_count > 0:
res_graph[qedge_tuple[0]][qedge_tuple[1]][qedge_key]['weight'] /= edge_count
return res_graph


@@ -230,7 +236,7 @@ def result_confidence_maker(self, result):
if True:
result_confidence = 1 # everybody gets to start with a confidence of 1
for edge in result.edge_bindings:
kg_edge_id = edge.kg_id
kg_edge_id = edge.id
# TODO: replace this with the more intelligent function
# here we are just multiplying the edge confidences
# --- to see what info is going into each result: print(f"{result.essence}: {kg_edges[kg_edge_id].type}, {kg_edges[kg_edge_id].confidence}")
@@ -550,7 +556,7 @@ def aggregate_scores_dmk(self, response):
kg_edge_id_to_edge = self.kg_edge_id_to_edge
score_stats = self.score_stats
no_non_inf_float_flag = True
for edge_key,edge in message.knowledge_graph.edges.items():
for edge_key, edge in message.knowledge_graph.edges.items():
kg_edge_id_to_edge[edge_key] = edge
if edge.attributes is not None:
for edge_attribute in edge.attributes:
@@ -585,12 +591,20 @@ def aggregate_scores_dmk(self, response):
f"No non-infinite value was encountered in any edge attribute in the knowledge graph.")
response.info(f"Summary of available edge metrics: {score_stats}")

edge_ids_manual_agent = set()
# Loop over the entire KG and normalize and combine the score of each edge, place that information in the confidence attribute of the edge
for edge_key,edge in message.knowledge_graph.edges.items():
for edge_key, edge in message.knowledge_graph.edges.items():
if edge.attributes is not None:
edge_attributes = {x.original_attribute_name:x.value for x in edge.attributes}
for edge_attribute in edge.attributes:
if edge_attribute.attribute_type_id == "biolink:agent_type" and edge_attribute.value == "manual_agent":
edge_attributes['confidence'] = edge_confidence_manual_agent
edge.confidence = edge_confidence_manual_agent
edge_ids_manual_agent.add(edge_key)
break
else:
edge_attributes = {}

if edge_attributes.get("confidence", None) is not None:
#if False: # FIXME: there is no longer such an attribute. Stored as a generic attribute?
#if edge.confidence is not None:
@@ -607,12 +621,34 @@ def aggregate_scores_dmk(self, response):
# 2. number of edges in the results
# 3. possibly conflicting information, etc.

results = message.results

edge_set_to_high_confidence = set()
for result in results:
edge_bindings = result.analyses[0].edge_bindings
for qedge_key in message.query_graph.edges.keys():
all_edges_for_qedge_are_high_confidence = False
bound_edges = edge_bindings.get(qedge_key, [])
for edge_name in bound_edges:
edge_id = edge_name.id
if edge_id in edge_ids_manual_agent:
all_edges_for_qedge_are_high_confidence = True
break
if all_edges_for_qedge_are_high_confidence:
for edge_name in bound_edges:
edge_set_to_high_confidence.add(edge_name.id)

for edge_key, edge in message.knowledge_graph.edges.items():
if edge_key in edge_set_to_high_confidence:
print(f"setting max confidence for edge_key: {edge_key}")
edge.confidence = edge_confidence_manual_agent

###################################
# TODO: Replace this with a more "intelligent" separate function
# now we can loop over all the results, and combine their edge confidences (now populated)
qg_nx = _get_query_graph_networkx_from_query_graph(message.query_graph)
kg_edge_id_to_edge = self.kg_edge_id_to_edge
results = message.results

ranks_list = list(map(_quantile_rank_list,
map(lambda scorer_func: _score_result_graphs_by_networkx_graph_scorer(kg_edge_id_to_edge,
qg_nx,
@@ -644,11 +680,6 @@ def aggregate_scores_dmk(self, response):
for edge_attribute in edge_attributes:
if edge_attribute.original_attribute_name == 'probability_treats' and edge_attribute.value is not None:
result.analyses[0].score = float(edge_attribute.value)





# for result in message.results:
# self.result_confidence_maker(result)
###################################
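
Taken together, the ARAX_ranker.py hunks above make two changes to scoring: the weight a result graph accumulates for a qedge is now divided by the number of KG edges bound to that qedge (an average instead of a running sum), and any edge whose biolink:agent_type attribute equals "manual_agent" is pinned to a confidence of 0.999, with that confidence then propagated to every edge bound to the same qedge within a result. Below is a minimal standalone sketch of that logic, assuming plain dictionaries in place of the real ARAX/TRAPI model classes; the data shapes, function names, and IDs are illustrative, not the project's actual API.

# Sketch only: plain-dict stand-ins for the TRAPI KnowledgeGraph/Result objects.
EDGE_CONFIDENCE_MANUAL_AGENT = 0.999  # mirrors edge_confidence_manual_agent above


def average_qedge_weight(bound_edge_confidences):
    # Average the confidences of the KG edges bound to one qedge
    # (previously the per-qedge weight was just a sum of confidences).
    if not bound_edge_confidences:
        return 0.0
    return sum(bound_edge_confidences) / len(bound_edge_confidences)


def promote_manual_agent_edges(kg_edges, results):
    # Pin manual-agent edges to near-maximal confidence, then spread that
    # confidence to every edge bound to the same qedge within a result.
    manual_edge_ids = {
        edge_id for edge_id, edge in kg_edges.items()
        if any(attr.get("attribute_type_id") == "biolink:agent_type"
               and attr.get("value") == "manual_agent"
               for attr in edge.get("attributes", []))
    }
    promoted = set(manual_edge_ids)
    for result in results:
        for bound_edge_ids in result["edge_bindings"].values():
            if any(edge_id in manual_edge_ids for edge_id in bound_edge_ids):
                promoted.update(bound_edge_ids)
    for edge_id in promoted:
        kg_edges[edge_id]["confidence"] = EDGE_CONFIDENCE_MANUAL_AGENT
    return kg_edges


# Tiny usage example with made-up edge and qedge IDs:
kg = {"e1": {"attributes": [{"attribute_type_id": "biolink:agent_type",
                             "value": "manual_agent"}]},
      "e2": {"attributes": []}}
results = [{"edge_bindings": {"qe0": ["e1", "e2"]}}]
promote_manual_agent_edges(kg, results)
print(average_qedge_weight([0.5, 1.0]))                # 0.75
print(kg["e1"]["confidence"], kg["e2"]["confidence"])  # 0.999 0.999
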
160 changes: 160 additions & 0 deletions code/ARAX/Testing/sample_kg2_queries.py
@@ -0,0 +1,160 @@
"""
This script grabs and saves a sample of real queries sent to our live KG2 endpoints. It deduplicates all queries
submitted to KG2 instances over the last X hours and saves a random sample of N of those queries to individual
JSON files. It also saves a summary of metadata about the queries in the sample. All output files are saved in a subdir
called 'sample_kg2_queries'.
Usage: python sample_kg2_queries.py <last_n_hours_to_sample_from> <sample_size>
"""
import copy
import csv
import json
import os
import random
import sys

import argparse
from typing import Optional

sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../ARAXQuery/")
from ARAX_query_tracker import ARAXQueryTracker
sys.path.append(os.path.dirname(os.path.abspath(__file__))+"/../NodeSynonymizer/")
from node_synonymizer import NodeSynonymizer


SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def get_num_input_curies(query_message: dict) -> int:
qg = query_message["message"]["query_graph"]
num_qnodes_with_curies = sum([1 for qnode in qg["nodes"].values() if qnode.get("ids")])
if num_qnodes_with_curies == 1:
for qnode in qg["nodes"].values():
if qnode.get("ids"):
return len(qnode["ids"])
return 1


def get_query_hash_key(query_message: dict) -> Optional[str]:
qg = query_message["message"]["query_graph"]
qedge = next(qedge for qedge in qg["edges"].values()) if qg.get("edges") else None
if not qedge: # Invalid query; skip it
return None
subj_qnode_key = qedge["subject"]
obj_qnode_key = qedge["object"]

# Craft the subject portion
subj_qnode = qg["nodes"][subj_qnode_key]
subj_ids_str = ",".join(sorted(subj_qnode.get("ids", [])))
subj_categories_str = ",".join(sorted(subj_qnode.get("categories", [])))
subj_hash_str = f"({subj_ids_str}; {subj_categories_str})"

# Craft the edge portion
predicates_hash_str = ",".join(sorted(qedge.get("predicates", [])))
qualifier_hash_strs = []
for qualifier_blob in qedge.get("qualifier_constraints", []):
qualifier_set = qualifier_blob["qualifier_set"]
qual_predicate, obj_qual_direction, obj_qual_aspect = "", "", ""
for qualifier_item in qualifier_set:
if qualifier_item.get("qualifier_type_id") == "biolink:qualified_predicate":
qual_predicate = qualifier_item["qualifier_value"]
elif qualifier_item.get("qualifier_type_id") == "biolink:object_direction_qualifier":
obj_qual_direction = qualifier_item["qualifier_value"]
elif qualifier_item.get("qualifier_type_id") == "biolink:object_aspect_qualifier":
obj_qual_aspect = qualifier_item["qualifier_value"]
qualifier_hash_strs.append(f"{qual_predicate}--{obj_qual_direction}--{obj_qual_aspect}")
qualifier_hash_str = ",".join(sorted(qualifier_hash_strs))
edge_hash_str = f"[{predicates_hash_str}; {qualifier_hash_str}]"

# Craft the object portion
obj_qnode = qg["nodes"][obj_qnode_key]
obj_ids_str = ",".join(sorted(obj_qnode.get("ids", [])))
obj_categories_str = ",".join(sorted(obj_qnode.get("categories", [])))
obj_hash_str = f"({obj_ids_str}; {obj_categories_str})"

return f"{subj_hash_str}--{edge_hash_str}-->{obj_hash_str}"


def main():
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("last_n_hours", help="Number of hours prior to the present to select the sample from")
arg_parser.add_argument("sample_size", help="Number of KG2 queries to include in random sample")
args = arg_parser.parse_args()

qt = ARAXQueryTracker()

last_n_hours = float(args.last_n_hours)
sample_size = int(args.sample_size)

print(f"Getting all queries in last {last_n_hours} hours")
queries = qt.get_entries(last_n_hours=float(last_n_hours))
print(f"Got {len(queries)} queries back (for all instance types)")

# Filter down only to KG2 queries
kg2_queries = [query for query in queries if query.instance_name == "kg2"]
print(f"There were a total of {len(kg2_queries)} KG2 queries in the last {last_n_hours} hours")

# Deduplicate queries
print(f"Deduplicating KG2 queries..")
node_synonymizer = NodeSynonymizer()
deduplicated_queries = dict()
for query in kg2_queries:
# Canonicalize any input curies
canonicalized_query = copy.deepcopy(query.input_query)
for qnode in canonicalized_query["message"]["query_graph"]["nodes"].values():
qnode_ids = qnode.get("ids")
if qnode_ids:
canonicalized_ids_dict = node_synonymizer.get_canonical_curies(qnode_ids)
canonicalized_ids = set()
for input_id, canonicalized_info in canonicalized_ids_dict.items():
if canonicalized_info:
canonicalized_ids.add(canonicalized_info["preferred_curie"])
else:
canonicalized_ids.add(input_id) # Just send the ID as is if synonymizer doesn't recognize it
qnode["ids"] = list(canonicalized_ids)

# Figure out if we've seen this query before using hash keys
hash_key = get_query_hash_key(canonicalized_query)
if hash_key and hash_key not in deduplicated_queries:
deduplicated_queries[hash_key] = {"query_id": query.query_id,
"query_hash_key": hash_key,
"start_datetime": query.start_datetime,
"submitter": query.origin,
"instance_name": query.instance_name,
"domain": query.domain,
"elapsed": query.elapsed,
"message_code": query.message_code,
"input_query": query.input_query,
"input_query_canonicalized": canonicalized_query}
print(f"After deduplication, there were {len(deduplicated_queries)} unique KG2 queries "
f"in the last {last_n_hours} hours ({round(len(deduplicated_queries)/len(kg2_queries), 2)*100}%)")

# Create a subdir to save sample queries/metadata if one doesn't already exist
sample_subdir = f"{SCRIPT_DIR}/sample_kg2_queries"
if not os.path.exists(sample_subdir):
os.system(f"mkdir {sample_subdir}")

# Grab a random sample of queries from the deduplicated set and save them to json files
print(f"Saving a random sample of {sample_size} deduplicated KG2 queries..")
random_selection = random.sample(list(deduplicated_queries), sample_size)
for query_hash_key in random_selection:
query_dict = deduplicated_queries[query_hash_key]
with open(f"{sample_subdir}/query_{query_dict['query_id']}.json", "w+") as query_file:
json.dump(query_dict, query_file, indent=2)

# Save a summary of the sample of queries for easier analysis
print(f"Saving a summary of the query sample..")
summary_col_names = ["query_id", "submitter", "instance_name", "domain", "start_datetime", "elapsed",
"message_code", "query_hash_key"]
with open(f"{sample_subdir}/sample_summary.tsv", "w+") as summary_file:
tsv_writer = csv.writer(summary_file, delimiter="\t")
tsv_writer.writerow(summary_col_names)
for query_hash_key in random_selection:
query_dict = deduplicated_queries[query_hash_key]
row = [query_dict[col_name] for col_name in summary_col_names]
tsv_writer.writerow(row)

print(f"Done.")


if __name__ == "__main__":
main()
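
Deduplication in the script above keys on a structural hash of each canonicalized one-hop query graph, and the script can be run with something like python sample_kg2_queries.py 24 50 to save 50 unique queries from the last 24 hours (the numbers are just an example). As an illustration of the hash itself, below is a hypothetical TRAPI query and the key that get_query_hash_key builds for it, assuming that function from the file above is in scope; the CURIE and category are invented for the example.

# Illustrative only: a made-up one-hop TRAPI query and its deduplication hash key.
example_query = {
    "message": {
        "query_graph": {
            "nodes": {
                "n0": {"ids": ["CHEMBL.COMPOUND:CHEMBL25"]},
                "n1": {"categories": ["biolink:Disease"]},
            },
            "edges": {
                "e0": {"subject": "n0", "object": "n1",
                       "predicates": ["biolink:treats"]},
            },
        }
    }
}

print(get_query_hash_key(example_query))
# (CHEMBL.COMPOUND:CHEMBL25; )--[biolink:treats; ]-->(; biolink:Disease)
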