-
Notifications
You must be signed in to change notification settings - Fork 0
/
Knowledge_Base_utils.py
154 lines (127 loc) · 5.6 KB
/
Knowledge_Base_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import re
from itertools import islice
from typing import List

from hdt import HDTDocument
from rdflib import Graph
from rdflib_hdt import HDTStore
class Knowledge_Graph:
    """Query helper over an HDT-compressed knowledge base.

    The same HDT file is opened twice: once as a raw ``HDTDocument`` for
    triple-pattern lookups (``search_triples``) and once wrapped in an rdflib
    ``Graph`` (through ``HDTStore``) so SPARQL queries can be executed.
    Entities and relations are handled as the plain strings the HDT document
    returns (presumably full IRIs without angle brackets -- confirm against
    the KB being loaded).

    Note: in HDT an empty string in ``search_triples`` is a wildcard, so the
    1-hop helpers treat an empty entity as "no entity" instead of scanning
    the whole KB.
    """

    def __init__(self, kb_path: str):
        """Open the HDT file at ``kb_path`` for triple search and SPARQL."""
        self.kb = HDTDocument(kb_path)
        self.graph = Graph(store=HDTStore(kb_path))
        print(f"KB loaded, it has {self.kb.total_triples} triples")

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _is_variable(term: str) -> bool:
        """Return True if ``term`` is a SPARQL variable (``?x`` or ``$x``)."""
        return term.startswith(("?", "$"))

    @staticmethod
    def _to_sparql_term(term: str) -> str:
        """Render ``term`` so it can be embedded in a SPARQL query.

        Bare IRIs (containing ``://`` and not already bracketed) are wrapped in
        ``<...>``; bracketed IRIs, variables, literals and prefixed names such
        as ``dbr:X`` are returned unchanged.
        """
        if "://" in term and not term.startswith(("<", "?", "$", '"')):
            return f"<{term}>"
        return term

    @staticmethod
    def _fill_variables(pattern, binding) -> tuple:
        """Replace bound variables in a triple ``pattern`` with their values.

        ``binding`` maps variable names WITHOUT the leading ``?``/``$`` (the
        shape rdflib's ``ResultRow.asdict()`` produces) to RDF terms. Only
        whole terms equal to a bound variable are substituted -- a substring
        replace would also leave the ``?`` behind and corrupt unrelated terms.
        Unbound variables are kept as-is. Values are rendered with ``str()``.
        """
        filled = []
        for term in pattern:
            if term.startswith(("?", "$")) and term[1:] in binding:
                filled.append(str(binding[term[1:]]))
            else:
                filled.append(term)
        return tuple(filled)

    def _relation_set(self, entity: str) -> set:
        """Return the set of relations on triples where ``entity`` is subject or object."""
        if not entity:
            return set()
        forward, _ = self.kb.search_triples(entity, "", "")
        backward, _ = self.kb.search_triples("", "", entity)
        relations = {rel for _, rel, _ in forward}
        relations.update(rel for _, rel, _ in backward)
        return relations

    def _neighbor_set(self, entity: str) -> set:
        """Return the set of entities linked to ``entity`` by a single triple (either direction)."""
        if not entity:
            return set()
        forward, _ = self.kb.search_triples(entity, "", "")
        backward, _ = self.kb.search_triples("", "", entity)
        neighbors = {obj for _, _, obj in forward}
        neighbors.update(subj for subj, _, _ in backward)
        return neighbors

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def get_relations_1hop(self, entity: str, limit: int = 1000) -> List[str]:
        """Return relations that are one hop away from ``entity``.

        Both outgoing (``entity`` as subject) and incoming (``entity`` as
        object) relations are included. The result is sorted so that
        truncation to ``limit`` items is deterministic.
        """
        return sorted(self._relation_set(entity))[:limit]

    def get_neighbors_1hop(self, entity: str, limit: int = 1000) -> List[str]:
        """Return entities that are one hop away from ``entity``.

        Objects of outgoing triples and subjects of incoming triples are
        included. The result is sorted and truncated to ``limit`` items.
        """
        return sorted(self._neighbor_set(entity))[:limit]

    def get_frequency(self, entity: str) -> int:
        """Return how many triples contain ``entity`` as subject or object.

        The count comes from the cardinalities reported by ``search_triples``.
        An empty entity yields 0 rather than a count of the whole KB.
        """
        if not entity:
            return 0
        _, as_subject = self.kb.search_triples(entity, "", "")
        _, as_object = self.kb.search_triples("", "", entity)
        return as_subject + as_object

    def get_connection_1hop(self, src: str, trg: str) -> List[tuple]:
        """Return every triple directly linking ``src`` and ``trg``, in either direction."""
        if not src or not trg:
            return []
        src_to_trg, _ = self.kb.search_triples(src, "", trg)
        trg_to_src, _ = self.kb.search_triples(trg, "", src)
        return list(src_to_trg) + list(trg_to_src)

    def get_connection_2hop(self, src: str, trg: str, max_frequency: int = 100000) -> List[List[tuple]]:
        """Return all two-hop connections between ``src`` and ``trg``.

        A bridge entity is any neighbour shared by ``src`` and ``trg`` (other
        than the endpoints themselves). Bridges appearing in more than
        ``max_frequency`` triples are skipped, since very common hubs rarely
        give meaningful paths and are expensive to expand.

        Returns:
            One list per bridge, in sorted bridge order. Each list holds the
            ``src``-bridge triples followed by the bridge-``trg`` triples.
            Returns an empty list when no bridge exists.
        """
        bridges = (self._neighbor_set(src) & self._neighbor_set(trg)) - {src, trg}
        connections = []
        for bridge in sorted(bridges):
            if self.get_frequency(bridge) > max_frequency:
                continue
            connections.append(
                self.get_connection_1hop(src, bridge) + self.get_connection_1hop(bridge, trg)
            )
        return connections

    def get_subgraph_triplets(self, topic_entity: str, curr_path: List[str], limit: int = 1000) -> List[tuple]:
        """Return triples reachable from ``topic_entity`` along a relation path.

        Args:
            topic_entity: starting entity (as stored in the HDT file).
            curr_path: one or two relations to follow outward from the entity.
            limit: maximum number of matching triples (1-hop) or result rows
                (2-hop) to return.

        Returns:
            For a 1-relation path, the matching ``(topic_entity, rel, obj)``
            triples. For a 2-relation path, two triples per SPARQL result row
            with the intermediate/final variables filled in.

        Raises:
            ValueError: if ``curr_path`` does not contain exactly 1 or 2 relations.
        """
        if len(curr_path) not in (1, 2):
            raise ValueError(
                f"curr_path must contain 1 or 2 relations, got {len(curr_path)}"
            )
        if not topic_entity:
            return []

        if len(curr_path) == 1:
            matches, _ = self.kb.search_triples(topic_entity, curr_path[0], "")
            return list(islice(matches, limit))

        first_rel, second_rel = curr_path
        subject_term = self._to_sparql_term(topic_entity)
        first_term = self._to_sparql_term(first_rel)
        second_term = self._to_sparql_term(second_rel)
        query = f"""
        SELECT DISTINCT * WHERE {{
            {subject_term} {first_term} ?obj1 .
            ?obj1 {second_term} ?obj2 .
        }} LIMIT {int(limit)}
        """
        pattern = [(topic_entity, first_rel, "?obj1"), ("?obj1", second_rel, "?obj2")]
        triplets = []
        for row in self.graph.query(query):
            binding = row.asdict()
            triplets.extend(self._fill_variables(triple, binding) for triple in pattern)
        return triplets

    def execute_sparql(self, query: str):
        """Run ``query`` against the rdflib graph and return the raw result object."""
        res = self.graph.query(query)
        return res

    def deduce_triplets_from_sparql(self, query: str) -> set:
        """Return the triples a SPARQL query touches, with variables resolved.

        The basic triple patterns inside the query's ``{...}`` block are
        parsed; anything from ``FILTER`` onward is ignored for parsing but
        kept when the query is executed. Patterns without variables are
        returned as written. Patterns with a variable (subject, relation or
        object) are re-executed as ``SELECT *``, and one filled triple per
        result row is produced. Useful to derive triple sets from KBQA
        datasets such as LC-QuAD.

        Raises:
            ValueError: if the query has no ``{...}`` graph pattern.
        """
        prefixes = """
        PREFIX dbr: <http://dbpedia.org/resource/>
        PREFIX dbo: <http://dbpedia.org/ontology/>
        PREFIX dbp: <http://dbpedia.org/property/>
        PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
        SELECT * WHERE {
        """
        body_match = re.search(r"\{(.*)\}", query, re.DOTALL)
        if body_match is None:
            raise ValueError(f"No {{...}} graph pattern found in query: {query!r}")
        body = body_match.group(1)
        # Only the part before FILTER contains plain triple patterns.
        patterns_text = re.split(r"\bFILTER\b", body, maxsplit=1, flags=re.IGNORECASE)[0]

        triplets = set()
        unknown_patterns = []
        # Patterns are separated by "." followed by whitespace (incl. newlines).
        for part in re.split(r"\.\s+", patterns_text.strip()):
            tokens = part.split()
            if len(tokens) < 3:
                # Empty / trailing fragments such as a lone "." are not triples.
                continue
            subj, rel, obj = tokens[:3]
            if any(self._is_variable(term) for term in (subj, rel, obj)):
                unknown_patterns.append((subj, rel, obj))
            else:
                triplets.add((subj, rel, obj))

        if unknown_patterns:
            for row in self.graph.query(prefixes + body + "}"):
                binding = row.asdict()
                triplets.update(
                    self._fill_variables(pattern, binding) for pattern in unknown_patterns
                )
        return triplets