Merge branch 'master' into itrb-test #2259
sundareswarpullela committed Apr 16, 2024
2 parents dd847bd + b954bc3; commit 1a59c03
Showing 67 changed files with 22,349 additions and 8,350 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
@@ -4,14 +4,14 @@ on:
workflow_dispatch:

push:
branches: [ master, production, itrb-test ]
branches: [ master, production, itrb-test, dev ]
paths:
- 'code/**'
- 'DockerBuild/**'
- 'requirements.txt'
- '.github/workflows/pytest.yml'
pull_request:
branches: [ master, production, itrb-test ]
branches: [ master, production, itrb-test, dev ]
paths:
- 'code/**'
- 'DockerBuild/**'
8 changes: 3 additions & 5 deletions ISSUE_TEMPLATES/kg2rollout.md
@@ -78,10 +78,8 @@ Host arax.ncats.io
The following databases should be rebuilt and copies of them should be put in `/home/rtxconfig/KG2.X.Y` on `arax-databases.rtx.ai`. Please use this kind of naming format: `mydatabase_v1.0_KG2.X.Y.sqlite`.

- [ ] NGD database (how-to is [here](https://github.com/RTXteam/RTX/blob/master/code/ARAX/ARAXQuery/Overlay/ngd/README.md))
- [ ] refreshed DTD @chunyuma
- [ ] DTD model @chunyuma _(may be skipped - depends on the changes in this KG2 version)_
- [ ] DTD database @chunyuma _(may be skipped - depends on the changes in this KG2 version)_
- [ ] XDTD database @chunyuma
- [ ] refreshed XDTD database @chunyuma
- [ ] XDTD database @chunyuma _(may be skipped - depends on the changes in this KG2 version)_

**NOTE**: As databases are rebuilt, `RTX/code/config_dbs.json` will need to be updated to point to their new paths! Push these changes to the branch for this KG2 version, unless the rollout of this KG2 version has already occurred, in which case you should push to `master` (but first follow the steps described [here](https://github.com/RTXteam/RTX/wiki/Config,-databases,-and-SFTP#config_dbsjson)).

@@ -153,7 +151,7 @@ Before rolling out, we need to pre-upload the new databases (referenced in `conf
- [ ] `ssh [email protected]`
- [ ] `cd RTX`
- [ ] `git pull origin master`
- [ ] If there have been changes to `requirements.txt`, make sure to do `~/venv3.9/bin/pip3 install -r code/requirements.txt`
- [ ] If there have been changes to `requirements.txt`, make sure to do `~/venv3.9/bin/pip3 install -r requirements.txt`
- [ ] `sudo bash`
- [ ] `mkdir -m 777 /mnt/data/orangeboard/databases/KG2.X.Y`
- [ ] `exit`
340 changes: 119 additions & 221 deletions code/ARAX/ARAXQuery/ARAX_connect.py

Large diffs are not rendered by default.

48 changes: 32 additions & 16 deletions code/ARAX/ARAXQuery/ARAX_expander.py
@@ -368,9 +368,7 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
response.update_query_plan(qedge_key, 'edge_properties', 'status', 'Expanding')
for kp in kp_selector.valid_kps:
response.update_query_plan(qedge_key, kp, 'Waiting', 'Prepping query to send to KP')
message.query_graph.edges[qedge_key].filled = True # Mark as expanded in overarching QG #1848
qedge = query_graph.edges[qedge_key]
qedge.filled = True # Also mark as expanded in local QG #1848

# Create a query graph for this edge (that uses curies found in prior steps)
one_hop_qg = self._get_query_graph_for_edge(qedge_key, query_graph, overarching_kg, log)
@@ -400,6 +398,10 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
if log.status != 'OK':
return response

# Mark this qedge as 'filled', but only AFTER pruning back prior node(s) as needed
message.query_graph.edges[qedge_key].filled = True # Mark as expanded in overarching QG #1848
qedge.filled = True # Also mark as expanded in local QG #1848

# Figure out which KPs would be best to expand this edge with (if no KP was specified)
if not user_specified_kp:
if mode == "RTXKG2":
@@ -528,11 +530,19 @@ def apply(self, response, input_parameters, mode: str = "ARAX"):
# Declare that we are done expanding this qedge
response.update_query_plan(qedge_key, 'edge_properties', 'status', 'Done')

# Make sure we found at least SOME answers for this edge
# Make sure we have at least SOME answers for all (regular) qedges expanded so far..
# TODO: Should this really just return response here? What about returning partial KG?
if not eu.qg_is_fulfilled(one_hop_qg, overarching_kg) and not qedge.exclude and not qedge.option_group_id:
log.warning(f"No paths were found in any KPs satisfying qedge {qedge_key}. KPs used were: "
f"{kps_to_query}")
is_fulfilled, unfulfilled_qedge_keys = eu.qg_is_fulfilled(query_graph,
overarching_kg,
enforce_required_only=True,
enforce_expanded_only=True,
return_unfulfilled_qedges=True)
if not is_fulfilled:
if qedge.exclude:
log.warning(f"After processing 'exclude=True' edge {qedge_key}, "
f"no paths remain from any KPs that satisfy qedge(s) {unfulfilled_qedge_keys}.")
else:
log.warning(f"No paths were found in any KPs satisfying qedge {unfulfilled_qedge_keys}.")
return response

# Expand any specified nodes
@@ -1116,17 +1126,20 @@ def _apply_any_kryptonite_edges(organized_kg: QGOrganizedKnowledgeGraph, full_qu
for edge_key in edge_keys_to_remove:
organized_kg.edges_by_qg_id[qedge_key].pop(edge_key)

if not organized_kg.edges_by_qg_id[qedge_key]:
log.warning(f"All {qedge_key} edges have been deleted due to an Exclude=true (i.e., 'kryptonite') edge!")

@staticmethod
def _prune_kg(qnode_key_to_prune: str, prune_threshold: int, kg: QGOrganizedKnowledgeGraph,
qg: QueryGraph, log: ARAXResponse) -> QGOrganizedKnowledgeGraph:
log.info(f"Pruning back {qnode_key_to_prune} nodes because there are more than {prune_threshold}")
kg_copy = copy.deepcopy(kg)
qg_expanded_thus_far = eu.get_qg_expanded_thus_far(qg, kg)
qg_expanded_thus_far.nodes[qnode_key_to_prune].is_set = False # Necessary for assessment of answer quality
qg_for_resultify = copy.deepcopy(qg)
qg_for_resultify.nodes[qnode_key_to_prune].is_set = False # Necessary for assessment of answer quality
num_edges_in_kg = sum([len(edges) for edges in kg.edges_by_qg_id.values()])
overlay_fet = True if num_edges_in_kg < 100000 else False
# Use fisher exact test and the ranker to prune down answers for this qnode
intermediate_results_response = eu.create_results(qg_expanded_thus_far, kg_copy, log,
intermediate_results_response = eu.create_results(qg_for_resultify, kg_copy, log,
rank_results=True, overlay_fet=overlay_fet,
qnode_key_to_prune=qnode_key_to_prune)
log.debug(f"A total of {len(intermediate_results_response.envelope.message.results)} "
@@ -1143,12 +1156,15 @@ def _prune_kg(qnode_key_to_prune: str, prune_threshold: int, kg: QGOrganizedKnow
scores.append(current_result.analyses[0].score)
kept_nodes.update({binding.id for binding in current_result.node_bindings[qnode_key_to_prune]})
counter += 1
log.info(f"Kept top {len(kept_nodes)} answers for {qnode_key_to_prune}. "
f"Best score was {round(max(scores), 5)}, worst kept was {round(min(scores), 5)}.")
if kept_nodes:
log.info(f"Kept top {len(kept_nodes)} answers for {qnode_key_to_prune}. "
f"Best score was {round(max(scores), 5)}, worst kept was {round(min(scores), 5)}.")
else:
log.error(f"All nodes were pruned out for {qnode_key_to_prune}! Shouldn't be possible",
error_code="PruneError")
# Actually eliminate them from the KG

nodes_to_delete = set(kg.nodes_by_qg_id[qnode_key_to_prune]).difference(kept_nodes)
kg.remove_nodes(nodes_to_delete, qnode_key_to_prune, qg_expanded_thus_far)
kg.remove_nodes(nodes_to_delete, qnode_key_to_prune, qg_for_resultify)
else:
log.error(f"Ran into an issue using Resultify when trying to prune {qnode_key_to_prune} answers: "
f"{intermediate_results_response.show()}", error_code="PruneError")
@@ -1163,10 +1179,10 @@ def _remove_dead_end_paths(expands_qg: QueryGraph, kg: QGOrganizedKnowledgeGraph
found in the last expansion will connect to edges in the next one)
"""
log.debug(f"Pruning any paths that are now dead ends (with help of Resultify)")
qg_expanded_thus_far = eu.get_qg_expanded_thus_far(expands_qg, kg)
for qnode in qg_expanded_thus_far.nodes.values():
is_set_true_qg = copy.deepcopy(expands_qg)
for qnode in is_set_true_qg.nodes.values():
qnode.is_set = True # This makes resultify run faster and doesn't hurt in this case
resultify_response = eu.create_results(qg_expanded_thus_far, kg, log)
resultify_response = eu.create_results(is_set_true_qg, kg, log)
if resultify_response.status == "OK":
pruned_kg = eu.convert_standard_kg_to_qg_organized_kg(resultify_response.envelope.message.knowledge_graph)
else:
56 changes: 42 additions & 14 deletions code/ARAX/ARAXQuery/ARAX_query.py
@@ -131,7 +131,7 @@ def query_return_stream(self, query, mode='ARAX'):
while i_message < n_messages:
with self.lock:
i_message_obj = self.response.messages[i_message].copy()
yield(json.dumps(i_message_obj) + "\n")
yield(json.dumps(i_message_obj, allow_nan=False) + "\n")
i_message += 1
idle_ticks = 0.0

@@ -147,7 +147,7 @@ def query_return_stream(self, query, mode='ARAX'):
query_plan_counter = self_query_plan_counter
with self.lock:
self_response_query_plan = self.response.query_plan.copy()
yield(json.dumps(self_response_query_plan, sort_keys=True) + "\n")
yield(json.dumps(self_response_query_plan, allow_nan=False, sort_keys=True) + "\n")
idle_ticks = 0.0
time.sleep(0.2)
idle_ticks += 0.2
@@ -161,14 +161,14 @@ def query_return_stream(self, query, mode='ARAX'):
# #### If there are any more logging messages in the queue, send them first
n_messages = len(self.response.messages)
while i_message < n_messages:
yield(json.dumps(self.response.messages[i_message]) + "\n")
yield(json.dumps(self.response.messages[i_message], allow_nan=False) + "\n")
i_message += 1

#### Also emit any updates to the query_plan
self_response_query_plan_counter = self.response.query_plan['counter']
if query_plan_counter < self_response_query_plan_counter:
query_plan_counter = self_response_query_plan_counter
yield(json.dumps(self.response.query_plan, sort_keys=True) + "\n")
yield(json.dumps(self.response.query_plan, allow_nan=False, sort_keys=True) + "\n")

# Remove the little DONE flag the other thread used to signal this thread that it is done
self.response.status = re.sub('DONE,', '', self.response.status)
@@ -178,7 +178,21 @@ def query_return_stream(self, query, mode='ARAX'):
self.response.envelope.status = 'Success'

# Stream the resulting message back to the client
yield(json.dumps(self.response.envelope.to_dict(), sort_keys=True) + "\n")
try:
msg_str = json.dumps(self.response.envelope.to_dict(),
allow_nan=False,
sort_keys=True) + "\n"
except ValueError as v:
self.response.envelope.message.results = []
self.response.envelope.message.auxiliary_graphs = None
self.response.envelope.message.knowledge_graph = {'edges': dict(), 'nodes': dict()}
self.response.envelope.status = 'ERROR'
error_message_str = f"error dumping result to JSON: {str(v)}"
self.response.error(error_message_str)
eprint(error_message_str)
msg_str = json.dumps(self.response.envelope.to_dict(),
sort_keys=True) + "\n"
yield msg_str

# Wait until both threads rejoin here and the return
main_query_thread.join()
@@ -483,7 +497,7 @@ def validate_incoming_query_graph(self,message):
response.info(f"Validating the input query graph")

# Define allowed qnode and qedge attributes to check later
allowed_qnode_attributes = { 'ids': 1, 'categories':1, 'is_set': 1, 'option_group_id': 1, 'name': 1, 'constraints': 1 }
allowed_qnode_attributes = { 'ids': 1, 'categories':1, 'is_set': 1, 'set_interpretation': 1, 'set_id': 1, 'option_group_id': 1, 'name': 1, 'constraints': 1 }
allowed_qedge_attributes = { 'predicates': 1, 'subject': 1, 'object': 1, 'option_group_id': 1, 'exclude': 1, 'relation': 1, 'attribute_constraints': 1, 'qualifier_constraints': 1, 'knowledge_type': 1 }

#### Loop through nodes checking the attributes
@@ -1060,10 +1074,10 @@ def main():
"add_qnode(name=acetaminophen, key=n0)",
"add_qnode(categories=biolink:Protein, key=n1)",
"add_qedge(subject=n0, object=n1, key=e0)",
"expand(edge_key=e0)",
"expand(edge_key=e0, kp=infores:rtx-kg2)",
"overlay(action=compute_ngd, virtual_relation_label=N1, subject_qnode_key=n0, object_qnode_key=n1)",
"resultify(ignore_edge_direction=true)",
"filter_results(action=limit_number_of_results, max_results=10)",
"#filter_results(action=limit_number_of_results, max_results=5)",
"return(message=true, store=true)",
]}}

@@ -1816,6 +1830,17 @@ def main():
"scoreless_resultify(ignore_edge_direction=true)",
"rank_results()"
]}}
elif params.example_number == 2262:
query = {"operations": {"actions": [
"create_message",
"add_qnode(name=DOID:1227, key=n00)",
"add_qnode(categories=biolink:ChemicalEntity, key=n01)",
"add_qedge(subject=n01, object=n00, key=e00, predicates=biolink:treats)",
"expand(edge_key=e00, kp=infores:rtx-kg2)",
"filter_kg(action=remove_edges_by_predicate, edge_predicate=biolink:treats, remove_connected_nodes=t, qedge_keys=[e00])",
"resultify(ignore_edge_direction=true)",
"return(message=true, store=false)"
]}}
else:
eprint(f"Invalid test number {params.example_number}. Try 1 through 17")
return
@@ -1838,10 +1863,12 @@ def main():


#### Print out the logging stream
print(response.show(level=ARAXResponse.DEBUG))
#if verbose:
# print(response.show(level=ARAXResponse.DEBUG))

#### Print out the message that came back
print(json.dumps(ast.literal_eval(repr(envelope)), sort_keys=True, indent=2))
#if verbose:
# print(json.dumps(ast.literal_eval(repr(envelope)), sort_keys=True, indent=2))

#### Other stuff that could be dumped
#print(json.dumps(message.to_dict(),sort_keys=True,indent=2))
@@ -1858,10 +1885,11 @@ def main():
#print(f"Essence names in the answers: {[x.essence for x in message.results]}")
print("Results:")
for result in message.results:
confidence = result.confidence
if confidence is None:
confidence = 0.0
print(" -" + '{:6.3f}'.format(confidence) + f"\t{result.essence}")
analysis = result.analyses[0]
score = analysis.score
if score is None:
score = 0.0
print(" -" + '{:6.3f}'.format(score) + f"\t{result.essence}")

# print the response id at the bottom for convenience too:
print(f"Returned response id: {envelope.id}")
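The new `allow_nan=False` arguments in `query_return_stream()` above make `json.dumps` raise `ValueError` whenever a NaN or Infinity value sneaks into the envelope, rather than emitting JSON that strict parsers reject; the added try/except then falls back to an empty, serializable error envelope. A minimal standalone sketch of that behavior (illustrative only, not part of this diff):

```python
import json

payload = {"score": float("nan")}  # a NaN that strict JSON parsers would reject

try:
    msg_str = json.dumps(payload, allow_nan=False, sort_keys=True)
except ValueError as v:
    # Fall back to a payload that is guaranteed to serialize, mirroring the
    # pattern used in query_return_stream() above.
    msg_str = json.dumps({"status": "ERROR",
                          "description": f"error dumping result to JSON: {v}"},
                         sort_keys=True)

print(msg_str)  # prints the ERROR fallback payload instead of invalid JSON
```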
39 changes: 37 additions & 2 deletions code/ARAX/ARAXQuery/ARAX_ranker.py
@@ -168,6 +168,33 @@ def _score_result_graphs_by_networkx_graph_scorer(kg_edge_id_to_edge: Dict[str,
return nx_graph_scorer(result_graphs_nx)


def _break_ties_and_preserve_order(scores):
adjusted_scores = scores.copy()
n = len(scores)
# if there are more than 1,000 scores, apply the fix to the first 1000 scores and ignore the rest
if n > 1000:
n = 1000

for i in range(n):
if i > 0 and adjusted_scores[i] >= adjusted_scores[i - 1]:
# Calculate the decrement such that it makes this score slightly less than the previous,
# maintaining the descending order.
decrement = round(adjusted_scores[i - 1] - adjusted_scores[i], 3) - 0.001
adjusted_scores[i] = adjusted_scores[i - 1] - max(decrement, 0.001)

# Ensure the adjusted score doesn't become lower than the next score
if i < n - 1 and adjusted_scores[i] <= adjusted_scores[i + 1]:
# Adjust the next score to be slightly less than the current score
increment = round(adjusted_scores[i] - adjusted_scores[i + 1], 3) - 0.001
adjusted_scores[i + 1] = adjusted_scores[i] - max(increment, 0.001)

# round all scores to 3 decimal places
adjusted_scores = [round(score, 3) for score in adjusted_scores]
# make sure no scores are below 0
adjusted_scores = [max(score, 0) for score in adjusted_scores]
return adjusted_scores


class ARAXRanker:

# #### Constructor
@@ -657,11 +684,12 @@ def aggregate_scores_dmk(self, response):
[_score_networkx_graphs_by_max_flow,
_score_networkx_graphs_by_longest_path,
_score_networkx_graphs_by_frobenius_norm])))
#print(ranks_list)
#print(float(len(ranks_list)))


result_scores = sum(ranks_list)/float(len(ranks_list))
#print(result_scores)


# Replace Inferred Results Score with Probability score calculated by xDTD model
inferred_qedge_keys = [qedge_key for qedge_key, qedge in message.query_graph.edges.items()
if qedge.knowledge_type == "inferred"]
@@ -699,6 +727,13 @@ def aggregate_scores_dmk(self, response):

# Re-sort the final results
message.results.sort(key=lambda result: result.analyses[0].score, reverse=True)
# break ties and preserve order, round to 3 digits and make sure none are < 0
scores_with_ties = [result.analyses[0].score for result in message.results]
scores_without_ties = _break_ties_and_preserve_order(scores_with_ties)
# reinsert these scores into the results
for result, score in zip(message.results, scores_without_ties):
result.analyses[0].score = score
result.row_data[0] = score
response.debug("Results have been ranked and sorted")


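For reference, a quick usage sketch of the new `_break_ties_and_preserve_order()` helper added to `ARAX_ranker.py` above (illustrative only; the import path is an assumption and presumes `code/ARAX/ARAXQuery` is on `sys.path`):

```python
# Demonstrates how tied, descending-sorted scores are nudged apart.
from ARAX_ranker import _break_ties_and_preserve_order  # assumed import path

scores = [0.9, 0.9, 0.9, 0.5]   # descending, with ties
adjusted = _break_ties_and_preserve_order(scores)

# Each tied score is pushed ~0.001 below its predecessor, everything is
# rounded to 3 decimals, and nothing drops below 0:
print(adjusted)                 # [0.9, 0.899, 0.898, 0.5]
```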