Skip to content

Commit

Permalink
Update build_mapping_db.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chunyuma authored Jul 22, 2024
1 parent 5ebd34e commit 460b611
Showing 1 changed file with 15 additions and 51 deletions.
66 changes: 15 additions & 51 deletions code/ARAX/ARAXQuery/Infer/scripts/build_mapping_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@
import csv
csv.field_size_limit(sys.maxsize)
import collections
import json

## Import Personal Packages
pathlist = os.getcwd().split(os.path.sep)
ROOTindex = pathlist.index("xDTD_training_pipeline")
ROOTPath = os.path.sep.join([*pathlist[:(ROOTindex + 1)]])
sys.path.append(os.path.join(ROOTPath, 'scripts'))
import utils

DEBUG = True


class xDTDMappingDB():

Expand All @@ -39,12 +28,12 @@ def __init__(self, kgml_xdtd_data_path=None, tsv_path=None, database_name='xdtd_

if mode == 'build':
if not os.path.exists(kgml_xdtd_data_path):
print(f"Error: The given path of kgml_xdtd_data_path '{kgml_xdtd_data_path}' doesn't exist.", flush=True)
print(f"Error: The given path '{kgml_xdtd_data_path}' doesn't exist.", flush=True)
raise
else:
self.kgml_xdtd_data_path = kgml_xdtd_data_path
if not os.path.exists(tsv_path):
print(f"Error: The given path of tsv_path '{tsv_path}' doesn't exist.", flush=True)
print(f"Error: The given path '{tsv_path}' doesn't exist.", flush=True)
raise
else:
self.tsv_path = tsv_path
Expand All @@ -63,7 +52,7 @@ def __init__(self, kgml_xdtd_data_path=None, tsv_path=None, database_name='xdtd_
self.success_con = self._connect(db_path)
elif mode == 'run':
if db_loc is None:
print(f"Error: The given path of db_loc '{db_loc}' doesn't exist.", flush=True)
print(f"Error: The given path '{db_loc}' doesn't exist.", flush=True)
raise
else:
db_path = os.path.join(db_loc, database_name)
Expand Down Expand Up @@ -116,8 +105,8 @@ def populate_table(self):
## load kgml_xdtd data
print("Loading KGML-xDTD data...", flush=True)
kgml_xdtd_graph_nodes = pd.read_csv(os.path.join(self.kgml_xdtd_data_path, 'entity2freq.txt'), sep='\t', header=None).drop(columns=[1])
# kgml_xdtd_graph_edges = pd.read_csv(os.path.join(self.kgml_xdtd_data_path, 'graph_edges.txt'), sep='\t', header=0)
# kgml_xdtd_graph_edges_dict = {(row[0],row[2],row[1]):1 for row in kgml_xdtd_graph_edges.to_numpy()}
kgml_xdtd_graph_edges = pd.read_csv(os.path.join(self.kgml_xdtd_data_path, 'graph_edges.txt'), sep='\t', header=0)
kgml_xdtd_graph_edges_dict = {(row[0],row[2],row[1]):1 for row in kgml_xdtd_graph_edges.to_numpy()}
# kgml_xdtd_graph_edges_dict = {}
# for row in tqdm(kgml_xdtd_graph_edges.to_numpy()):
# if row[2] == 'biolink:entity_regulates_entity':
Expand Down Expand Up @@ -170,12 +159,10 @@ def populate_table(self):
data_reader = csv.reader(data_tsv, delimiter='\t')
tsv_edge_df = pd.DataFrame([row for row in data_reader])
tsv_edge_df.columns = headers
## filter out the 'domain_range_exclusion==True' edge
tsv_edge_df = tsv_edge_df.loc[tsv_edge_df['domain_range_exclusion'] != 'True',:].reset_index(drop=True)
tsv_edge_df = tsv_edge_df[['subject','object','predicate','primary_knowledge_source','publications','publications_info','kg2_ids']]
tsv_edge_df = tsv_edge_df[['subject','object','predicate','knowledge_source','publications','publications_info','kg2_ids']]
# Split 'knowledge_sources' on 'ǂ' and then explode it
# tsv_edge_df['knowledge_source'] = tsv_edge_df['knowledge_source'].str.split('ǂ')
# tsv_edge_df = tsv_edge_df.explode('knowledge_source')
tsv_edge_df['knowledge_source'] = tsv_edge_df['knowledge_source'].str.split('ǂ')
tsv_edge_df = tsv_edge_df.explode('knowledge_source')

## Insert node information into database
print("Inserting into NODE_MAPPING_TABLE...", flush=True)
Expand All @@ -189,11 +176,11 @@ def populate_table(self):
## Intert edge information into database
print("Inserting into EDGE_MAPPING_TABLE...", flush=True)
for row in tqdm(tsv_edge_df.to_numpy()):
# if (row[0], row[2], row[1]) in kgml_xdtd_graph_edges_dict:
## intsert into database
row = [f"{row[0]}--{row[2]}--{row[1]}"] + list(row)
insert_command = f"INSERT INTO EDGE_MAPPING_TABLE values (?,?,?,?,?,?,?,?)"
self.connection.execute(insert_command, tuple(row))
if (row[0], row[2], row[1]) in kgml_xdtd_graph_edges_dict:
## intsert into database
row = [f"{row[0]}--{row[2]}--{row[1]}"] + list(row)
insert_command = f"INSERT INTO EDGE_MAPPING_TABLE values (?,?,?,?,?,?,?,?)"
self.connection.execute(insert_command, tuple(row))
print(f"Inserting into EDGE_MAPPING_TABLE is completed", flush=True)
self.connection.commit()

Expand Down Expand Up @@ -225,19 +212,13 @@ def get_node_info(self, node_id = None, node_name = None):
query = f"SELECT * FROM NODE_MAPPING_TABLE WHERE id = '{node_id}'"
cursor.execute(query)
## create a named tuple
temp_result = cursor.fetchone()
if temp_result is None:
return None
res = res._make(temp_result)
res = res._make(cursor.fetchone())
return res
elif node_name is not None and type(node_name) == str:
query = f"SELECT * FROM NODE_MAPPING_TABLE WHERE name = '{node_name}'"
cursor.execute(query)
## create a named tuple
temp_result = cursor.fetchone()
if temp_result is None:
return None
res = res._make(temp_result)
res = res._make(cursor.fetchone())
return res
else:
return None
Expand All @@ -261,12 +242,8 @@ def get_edge_info(self, triple_id = None, triple_name = None):
elif triple_name is not None and type(triple_name) == tuple:
subject_name, predicate, object_name = triple_name
subject_info = self.get_node_info(node_name=subject_name)
if not subject_info:
return []
subject_id = subject_info.id
object_info = self.get_node_info(node_name=object_name)
if not object_info:
return []
object_id = object_info.id
if predicate != 'SELF_LOOP_RELATION':
query = f"SELECT * FROM EDGE_MAPPING_TABLE WHERE triple = '{subject_id}--{predicate}--{object_id}'"
Expand All @@ -284,7 +261,6 @@ def get_edge_info(self, triple_id = None, triple_name = None):

def main():
parser = argparse.ArgumentParser(description="Tests or builds the KGML-xDTD model Mapping Database", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--db_config_path", type=str, help="path to database config file", default="../config_dbs.json")
parser.add_argument('--build', action="store_true", required=False, help="If set, (re)build the index from scratch", default=False)
parser.add_argument('--test', action="store_true", required=False, help="If set, run a test of database by doing several lookups", default=False)
parser.add_argument('--tsv_path', type=str, required=True, help="Path to a folder containing the KG2 graph TSV files")
Expand All @@ -299,18 +275,6 @@ def main():

# To (re)build
if args.build:
## Download data from arax-databases.rtx.ai first
with open(args.db_config_path, 'rb') as file_in:
config_dbs = json.load(file_in)
kg2name = config_dbs["neo4j"]["KG2c"].replace('c.rtx.ai','').upper().replace('-','.')
if not os.path.exists(os.path.join(ROOTPath, "data", 'kg2c-tsv.tar.gz')):
os.system(f"scp [email protected]:~/{kg2name}/extra_files/kg2c-tsv.tar.gz {os.path.join(ROOTPath, 'data', 'kg2c-tsv.tar.gz')}")
## De-compress kg2c-tsv.tar.gz
if not os.path.exists(os.path.join(ROOTPath, 'data', 'kg2c-tsv')):
os.makedirs(os.path.join(ROOTPath, 'data', 'kg2c-tsv'))
if not os.path.exists(os.path.join(ROOTPath, "data", 'kg2c-tsv', 'edges_c.tsv')):
os.system(f"tar -zxvf {os.path.join(ROOTPath, 'data', 'kg2c-tsv.tar.gz')} -C {os.path.join(ROOTPath, 'data', 'kg2c-tsv')}")
args.tsv_path = os.path.join(ROOTPath, 'data', 'kg2c-tsv')
db = xDTDMappingDB(args.kgml_xdtd_data_path, args.tsv_path, args.database_name, args.outdir, mode='build', db_loc=None)
print("==== Creating tables ====", flush=True)
db.create_tables()
Expand Down

0 comments on commit 460b611

Please sign in to comment.