-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
15 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,17 +12,6 @@ | |
import csv | ||
csv.field_size_limit(sys.maxsize) | ||
import collections | ||
import json | ||
|
||
## Import Personal Packages | ||
pathlist = os.getcwd().split(os.path.sep) | ||
ROOTindex = pathlist.index("xDTD_training_pipeline") | ||
ROOTPath = os.path.sep.join([*pathlist[:(ROOTindex + 1)]]) | ||
sys.path.append(os.path.join(ROOTPath, 'scripts')) | ||
import utils | ||
|
||
DEBUG = True | ||
|
||
|
||
class xDTDMappingDB(): | ||
|
||
|
@@ -39,12 +28,12 @@ def __init__(self, kgml_xdtd_data_path=None, tsv_path=None, database_name='xdtd_ | |
|
||
if mode == 'build': | ||
if not os.path.exists(kgml_xdtd_data_path): | ||
print(f"Error: The given path of kgml_xdtd_data_path '{kgml_xdtd_data_path}' doesn't exist.", flush=True) | ||
print(f"Error: The given path '{kgml_xdtd_data_path}' doesn't exist.", flush=True) | ||
raise | ||
else: | ||
self.kgml_xdtd_data_path = kgml_xdtd_data_path | ||
if not os.path.exists(tsv_path): | ||
print(f"Error: The given path of tsv_path '{tsv_path}' doesn't exist.", flush=True) | ||
print(f"Error: The given path '{tsv_path}' doesn't exist.", flush=True) | ||
raise | ||
else: | ||
self.tsv_path = tsv_path | ||
|
@@ -63,7 +52,7 @@ def __init__(self, kgml_xdtd_data_path=None, tsv_path=None, database_name='xdtd_ | |
self.success_con = self._connect(db_path) | ||
elif mode == 'run': | ||
if db_loc is None: | ||
print(f"Error: The given path of db_loc '{db_loc}' doesn't exist.", flush=True) | ||
print(f"Error: The given path '{db_loc}' doesn't exist.", flush=True) | ||
raise | ||
else: | ||
db_path = os.path.join(db_loc, database_name) | ||
|
@@ -116,8 +105,8 @@ def populate_table(self): | |
## load kgml_xdtd data | ||
print("Loading KGML-xDTD data...", flush=True) | ||
kgml_xdtd_graph_nodes = pd.read_csv(os.path.join(self.kgml_xdtd_data_path, 'entity2freq.txt'), sep='\t', header=None).drop(columns=[1]) | ||
# kgml_xdtd_graph_edges = pd.read_csv(os.path.join(self.kgml_xdtd_data_path, 'graph_edges.txt'), sep='\t', header=0) | ||
# kgml_xdtd_graph_edges_dict = {(row[0],row[2],row[1]):1 for row in kgml_xdtd_graph_edges.to_numpy()} | ||
kgml_xdtd_graph_edges = pd.read_csv(os.path.join(self.kgml_xdtd_data_path, 'graph_edges.txt'), sep='\t', header=0) | ||
kgml_xdtd_graph_edges_dict = {(row[0],row[2],row[1]):1 for row in kgml_xdtd_graph_edges.to_numpy()} | ||
# kgml_xdtd_graph_edges_dict = {} | ||
# for row in tqdm(kgml_xdtd_graph_edges.to_numpy()): | ||
# if row[2] == 'biolink:entity_regulates_entity': | ||
|
@@ -170,12 +159,10 @@ def populate_table(self): | |
data_reader = csv.reader(data_tsv, delimiter='\t') | ||
tsv_edge_df = pd.DataFrame([row for row in data_reader]) | ||
tsv_edge_df.columns = headers | ||
## filter out the 'domain_range_exclusion==True' edge | ||
tsv_edge_df = tsv_edge_df.loc[tsv_edge_df['domain_range_exclusion'] != 'True',:].reset_index(drop=True) | ||
tsv_edge_df = tsv_edge_df[['subject','object','predicate','primary_knowledge_source','publications','publications_info','kg2_ids']] | ||
tsv_edge_df = tsv_edge_df[['subject','object','predicate','knowledge_source','publications','publications_info','kg2_ids']] | ||
# Split 'knowledge_sources' on 'ǂ' and then explode it | ||
# tsv_edge_df['knowledge_source'] = tsv_edge_df['knowledge_source'].str.split('ǂ') | ||
# tsv_edge_df = tsv_edge_df.explode('knowledge_source') | ||
tsv_edge_df['knowledge_source'] = tsv_edge_df['knowledge_source'].str.split('ǂ') | ||
tsv_edge_df = tsv_edge_df.explode('knowledge_source') | ||
|
||
## Insert node information into database | ||
print("Inserting into NODE_MAPPING_TABLE...", flush=True) | ||
|
@@ -189,11 +176,11 @@ def populate_table(self): | |
## Intert edge information into database | ||
print("Inserting into EDGE_MAPPING_TABLE...", flush=True) | ||
for row in tqdm(tsv_edge_df.to_numpy()): | ||
# if (row[0], row[2], row[1]) in kgml_xdtd_graph_edges_dict: | ||
## intsert into database | ||
row = [f"{row[0]}--{row[2]}--{row[1]}"] + list(row) | ||
insert_command = f"INSERT INTO EDGE_MAPPING_TABLE values (?,?,?,?,?,?,?,?)" | ||
self.connection.execute(insert_command, tuple(row)) | ||
if (row[0], row[2], row[1]) in kgml_xdtd_graph_edges_dict: | ||
## intsert into database | ||
row = [f"{row[0]}--{row[2]}--{row[1]}"] + list(row) | ||
insert_command = f"INSERT INTO EDGE_MAPPING_TABLE values (?,?,?,?,?,?,?,?)" | ||
self.connection.execute(insert_command, tuple(row)) | ||
print(f"Inserting into EDGE_MAPPING_TABLE is completed", flush=True) | ||
self.connection.commit() | ||
|
||
|
@@ -225,19 +212,13 @@ def get_node_info(self, node_id = None, node_name = None): | |
query = f"SELECT * FROM NODE_MAPPING_TABLE WHERE id = '{node_id}'" | ||
cursor.execute(query) | ||
## create a named tuple | ||
temp_result = cursor.fetchone() | ||
if temp_result is None: | ||
return None | ||
res = res._make(temp_result) | ||
res = res._make(cursor.fetchone()) | ||
return res | ||
elif node_name is not None and type(node_name) == str: | ||
query = f"SELECT * FROM NODE_MAPPING_TABLE WHERE name = '{node_name}'" | ||
cursor.execute(query) | ||
## create a named tuple | ||
temp_result = cursor.fetchone() | ||
if temp_result is None: | ||
return None | ||
res = res._make(temp_result) | ||
res = res._make(cursor.fetchone()) | ||
return res | ||
else: | ||
return None | ||
|
@@ -261,12 +242,8 @@ def get_edge_info(self, triple_id = None, triple_name = None): | |
elif triple_name is not None and type(triple_name) == tuple: | ||
subject_name, predicate, object_name = triple_name | ||
subject_info = self.get_node_info(node_name=subject_name) | ||
if not subject_info: | ||
return [] | ||
subject_id = subject_info.id | ||
object_info = self.get_node_info(node_name=object_name) | ||
if not object_info: | ||
return [] | ||
object_id = object_info.id | ||
if predicate != 'SELF_LOOP_RELATION': | ||
query = f"SELECT * FROM EDGE_MAPPING_TABLE WHERE triple = '{subject_id}--{predicate}--{object_id}'" | ||
|
@@ -284,7 +261,6 @@ def get_edge_info(self, triple_id = None, triple_name = None): | |
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Tests or builds the KGML-xDTD model Mapping Database", formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
parser.add_argument("--db_config_path", type=str, help="path to database config file", default="../config_dbs.json") | ||
parser.add_argument('--build', action="store_true", required=False, help="If set, (re)build the index from scratch", default=False) | ||
parser.add_argument('--test', action="store_true", required=False, help="If set, run a test of database by doing several lookups", default=False) | ||
parser.add_argument('--tsv_path', type=str, required=True, help="Path to a folder containing the KG2 graph TSV files") | ||
|
@@ -299,18 +275,6 @@ def main(): | |
|
||
# To (re)build | ||
if args.build: | ||
## Download data from arax-databases.rtx.ai first | ||
with open(args.db_config_path, 'rb') as file_in: | ||
config_dbs = json.load(file_in) | ||
kg2name = config_dbs["neo4j"]["KG2c"].replace('c.rtx.ai','').upper().replace('-','.') | ||
if not os.path.exists(os.path.join(ROOTPath, "data", 'kg2c-tsv.tar.gz')): | ||
os.system(f"scp [email protected]:~/{kg2name}/extra_files/kg2c-tsv.tar.gz {os.path.join(ROOTPath, 'data', 'kg2c-tsv.tar.gz')}") | ||
## De-compress kg2c-tsv.tar.gz | ||
if not os.path.exists(os.path.join(ROOTPath, 'data', 'kg2c-tsv')): | ||
os.makedirs(os.path.join(ROOTPath, 'data', 'kg2c-tsv')) | ||
if not os.path.exists(os.path.join(ROOTPath, "data", 'kg2c-tsv', 'edges_c.tsv')): | ||
os.system(f"tar -zxvf {os.path.join(ROOTPath, 'data', 'kg2c-tsv.tar.gz')} -C {os.path.join(ROOTPath, 'data', 'kg2c-tsv')}") | ||
args.tsv_path = os.path.join(ROOTPath, 'data', 'kg2c-tsv') | ||
db = xDTDMappingDB(args.kgml_xdtd_data_path, args.tsv_path, args.database_name, args.outdir, mode='build', db_loc=None) | ||
print("==== Creating tables ====", flush=True) | ||
db.create_tables() | ||
|