diff --git a/code/ARAX/ARAXQuery/result_transformer.py b/code/ARAX/ARAXQuery/result_transformer.py index ec45579cb..d11ea4417 100644 --- a/code/ARAX/ARAXQuery/result_transformer.py +++ b/code/ARAX/ARAXQuery/result_transformer.py @@ -65,6 +65,9 @@ def transform(response: ARAXResponse): # Refer to this aux graph from the current Result or Edge (if this is an Infer support graph) if group_id and group_id.startswith("creative_"): + # Figure out which creative tool/method we're dealing with (e.g. creative_DTD, creative_expand) + group_id_prefix = "_".join(group_id.split("_")[:2]) + # Create an attribute for the support graph that we'll tack onto the treats edge for this result support_graph_attribute = Attribute(attribute_type_id="biolink:support_graphs", value=[aux_graph_key], @@ -83,7 +86,8 @@ def transform(response: ARAXResponse): else: inferred_qedge_key = inferred_qedge_keys[0] inferred_edge_keys = {edge_binding.id for edge_binding in - result.analyses[0].edge_bindings[inferred_qedge_key] if "creative_" in edge_binding.id} + result.analyses[0].edge_bindings[inferred_qedge_key] + if group_id_prefix in edge_binding.id} # Refer to the support graph from the proper edge(s) for inferred_edge_key in inferred_edge_keys: inferred_edge = message.knowledge_graph.edges[inferred_edge_key] diff --git a/code/ARAX/test/test_ARAX_expand.py b/code/ARAX/test/test_ARAX_expand.py index f32c6328a..4d314b92a 100644 --- a/code/ARAX/test/test_ARAX_expand.py +++ b/code/ARAX/test/test_ARAX_expand.py @@ -1547,10 +1547,8 @@ def test_treats_patch_issue_2328_a(): support_edge_keys = set() for edge in creative_expand_treats_edges: aux_graph_keys = get_support_graphs_attribute(edge).value - creative_expand_aux_graph_keys = [aux_graph_key for aux_graph_key in aux_graph_keys - if "creative_expand" in aux_graph_key] - assert creative_expand_aux_graph_keys - for aux_graph_key in creative_expand_aux_graph_keys: + assert aux_graph_keys + for aux_graph_key in aux_graph_keys: aux_graph = message.auxiliary_graphs[aux_graph_key] support_edge_keys.update(set(aux_graph.edges)) support_edges = [message.knowledge_graph.edges[edge_key] for edge_key in support_edge_keys]