diff --git a/.github/workflows/runTest.yml b/.github/workflows/runTest.yml index 08540d2..5d3dc40 100644 --- a/.github/workflows/runTest.yml +++ b/.github/workflows/runTest.yml @@ -22,7 +22,7 @@ jobs: run: pytest tests/test_utils.py - name: test-workflow run: pytest tests/test_workflow.py - - name: unit-tests - run: pytest -vv tests/unittests.py - name: integration-tests - run: pytest -vv tests/integration_tests.py \ No newline at end of file + run: pytest -vv tests/integration_tests.py + - name: unit-tests + run: pytest -vv tests/unittests.py \ No newline at end of file diff --git a/README.md b/README.md index 190eba6..ce07417 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ sourmash sketch fromfile ref_paths.csv -p dna,k=31,scaled=1000,abund -o ref.sig. python ../make_training_data_from_sketches.py --ref_file ref.sig.zip --ksize 31 --num_threads ${NUM_THREADS} --ani_thresh 0.95 --prefix 'demo_ani_thresh_0.95' --outdir ./ --force # run YACHT algorithm to check the presence of reference genomes in the query sample (inference step) -python ../run_YACHT.py --json demo_ani_thresh_0.95_config.json --sample_file sample.sig.zip --significance 0.99 --num_threads ${NUM_THREADS} --min_coverage_list 1 0.6 0.2 0.1 --out ./result.xlsx +python ../run_YACHT.py --json demo_ani_thresh_0.95_config.json --sample_file sample.sig.zip --significance 0.99 --num_threads ${NUM_THREADS} --min_coverage_list 1 0.6 0.2 0.1 --out_filename result.xlsx # convert result to CAMI profile format (Optional) python ../srcs/standardize_yacht_output.py --yacht_output result.xlsx --sheet_name min_coverage0.2 --genome_to_taxid toy_genome_to_taxid.tsv --mode cami --sample_name 'MySample' --outfile_prefix cami_result --outdir ./ @@ -179,8 +179,8 @@ The most important parameter of this script is `--ani_thresh`: this is average n | File (names starting with prefix) | Content | | ------------------------------------- | ------------------------------------------------------------ | | _config.json | A JSON file stores the required information needed to run the next YACHT algorithm | -| _manifest.tsv | A TSV file contains organisms and their relevant info after removing the similar ones | -| _removed_orgs_to_corr_orgas_mapping.tsv | A TSV file with two columns: removed organism names ('removed_org') and their similar genomes ('corr_orgs')| +| _manifest.tsv | A TSV file containing organisms and their relevant info after removing the similar ones | +| _rep_to_corr_orgas_mapping.tsv | A TSV file containing representative organisms and the similar organisms that have been removed |
@@ -190,7 +190,7 @@ The most important parameter of this script is `--ani_thresh`: this is average n After this, you are ready to perform the hypothesis test for each organism in your reference database. This can be accomplished with something like: ```bash -python run_YACHT.py --json 'gtdb_ani_thresh_0.95_config.json' --sample_file 'sample.sig.zip' --num_threads 32 --keep_raw --significance 0.99 --min_coverage_list 1 0.5 0.1 0.05 0.01 --out ./result.xlsx +python run_YACHT.py --json 'gtdb_ani_thresh_0.95_config.json' --sample_file 'sample.sig.zip' --num_threads 32 --keep_raw --significance 0.99 --min_coverage_list 1 0.5 0.1 0.05 0.01 --outdir ./ ``` #### Parameter @@ -207,7 +207,8 @@ The `--min_coverage_list` parameter dictates a list of `min_coverage` which indi | --keep_raw | keep the raw result (i.e. `min_coverage=1`) no matter if the user specifies it | | --show_all | Show all organisms (no matter if present) | | --min_coverage_list | a list of `min_coverage` values, see more detailed description above (default: 1, 0.5, 0.1, 0.05, 0.01) | -| --out | path to output excel result (default: './result.xlsx') | +| --out_filename | filename of output excel result (default: 'result.xlsx') | +| --outdir | the path to the output directory where the results and intermediate files will be generated | #### Output diff --git a/make_training_data_from_sketches.py b/make_training_data_from_sketches.py index 038d091..2eaa04e 100644 --- a/make_training_data_from_sketches.py +++ b/make_training_data_from_sketches.py @@ -8,7 +8,6 @@ import srcs.utils as utils from loguru import logger import json -import shutil logger.remove() logger.add(sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO") @@ -49,11 +48,6 @@ path_to_temp_dir = os.path.join(outdir, prefix+'_intermediate_files') if os.path.exists(path_to_temp_dir) and not force: raise ValueError(f"Temporary directory {path_to_temp_dir} already exists. Please remove it or given a new prefix name using parameter '--prefix'.") - else: - # remove the temporary directory if it exists - if os.path.exists(path_to_temp_dir): - logger.warning(f"Temporary directory {path_to_temp_dir} already exists. 
Removing it.") - shutil.rmtree(path_to_temp_dir) os.makedirs(path_to_temp_dir, exist_ok=True) # unzip the sourmash signature file to the temporary directory @@ -73,12 +67,12 @@ # Find the close related genomes with ANI > ani_thresh from the reference database logger.info("Find the close related genomes with ANI > ani_thresh from the reference database") - multisearch_result = utils.run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir) + sig_same_genoms_dict = utils.run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir) # remove the close related organisms: any organisms with ANI > ani_thresh # pick only the one with largest number of unique kmers from all the close related organisms logger.info("Removing the close related organisms with ANI > ani_thresh") - remove_corr_df, manifest_df = utils.remove_corr_organisms_from_ref(sig_info_dict, multisearch_result) + rep_remove_dict, manifest_df = utils.remove_corr_organisms_from_ref(sig_info_dict, sig_same_genoms_dict) # write out the manifest file logger.info("Writing out the manifest file") @@ -87,19 +81,22 @@ # write out a mapping dataframe from representative organism to the close related organisms logger.info("Writing out a mapping dataframe from representative organism to the close related organisms") - if len(remove_corr_df) == 0: - logger.warning("No close related organisms found.") - remove_corr_df_indicator = "" + if len(rep_remove_dict) == 0: + logger.warning("No close related organisms found. No mapping dataframe is written.") + rep_remove_df = pd.DataFrame(columns=['rep_org', 'corr_orgs']) + rep_remove_df_path = os.path.join(outdir, f'{prefix}_rep_to_corr_orgas_mapping.tsv') + rep_remove_df.to_csv(rep_remove_df_path, sep='\t', index=None) else: - remove_corr_df_path = os.path.join(outdir, f'{prefix}_removed_orgs_to_corr_orgas_mapping.tsv') - remove_corr_df.to_csv(remove_corr_df_path, sep='\t', index=None) - remove_corr_df_indicator = remove_corr_df_path + rep_remove_df = pd.DataFrame([(rep_org, ','.join(corr_org_list)) for rep_org, corr_org_list in rep_remove_dict.items()]) + rep_remove_df.columns = ['rep_org', 'corr_orgs'] + rep_remove_df_path = os.path.join(outdir, f'{prefix}_rep_to_corr_orgas_mapping.tsv') + rep_remove_df.to_csv(rep_remove_df_path, sep='\t', index=None) # save the config file logger.info("Saving the config file") json_file_path = os.path.join(outdir, f'{prefix}_config.json') json.dump({'manifest_file_path': manifest_file_path, - 'remove_cor_df_path': remove_corr_df_indicator, + 'rep_remove_df_path': rep_remove_df_path, 'intermediate_files_dir': path_to_temp_dir, 'scale': scale, 'ksize': ksize, diff --git a/run_YACHT.py b/run_YACHT.py index fe7d99a..28efc3a 100644 --- a/run_YACHT.py +++ b/run_YACHT.py @@ -10,7 +10,6 @@ import json import warnings import zipfile -from pathlib import Path warnings.filterwarnings("ignore") from tqdm import tqdm from loguru import logger @@ -33,7 +32,7 @@ 'Each value should be between 0 and 1, with 0 being the most sensitive (and least ' 'precise) and 1 being the most precise (and least sensitive).', required=False, default=[1, 0.5, 0.1, 0.05, 0.01]) - parser.add_argument('--out', type=str, help='path to output excel file', required=False, default=os.path.join(os.getcwd(), 'result.xlsx')) + parser.add_argument('--out_filename', help='Full path of output filename', required=False, default='result.xlsx') # parse the arguments args = parser.parse_args() @@ -44,13 +43,7 @@ keep_raw = args.keep_raw # Keep raw results in output file. 
show_all = args.show_all # Show all organisms (no matter if present) in output file. min_coverage_list = args.min_coverage_list # a list of percentages of unique k-mers covered by reads in the sample. - out = str(Path(args.out).absolute()) # full path to output excel file - outdir = os.path.dirname(out) # path to output directory - out_filename = os.path.basename(out) # output filename - - # check if the output filename is valid - if os.path.splitext(out_filename)[1] != '.xlsx': - raise ValueError(f'Output filename {out} is not a valid excel file. Please use .xlsx as the extension.') + out_filename = args.out_filename # output filename # check if the json file exists utils.check_file_existence(json_file_path, f'Config file {json_file_path} does not exist. ' @@ -64,8 +57,8 @@ ani_thresh = config['ani_thresh'] # Make sure the output can be written to - if not os.access(outdir, os.W_OK): - raise FileNotFoundError(f"Cannot write to the location: {outdir}.") + if not os.access(os.path.abspath(os.path.dirname(out_filename)), os.W_OK): + raise FileNotFoundError(f"Cannot write to the location: {os.path.abspath(os.path.dirname(out_filename))}.") # check if min_coverage is between 0 and 1 for x in min_coverage_list: @@ -105,6 +98,10 @@ # check that the sample scale factor is the same as the genome scale factor for all organisms if scale != sample_sig_info[4]: raise ValueError(f'Sample scale factor does not equal genome scale factor. Please check your input.') + + # check if the output filename is valid + if not isinstance(out_filename, str) and out_filename != '': + out_filename = 'result.xlsx' # compute hypothesis recovery logger.info('Computing hypothesis recovery.') @@ -129,9 +126,9 @@ manifest_list = temp_manifest_list # save the results into Excel file - logger.info(f'Saving results to {outdir}.') + logger.info(f'Saving results to {os.path.dirname(out_filename)}.') # save the results with different min_coverage - with pd.ExcelWriter(out, engine='openpyxl', mode='w') as writer: + with pd.ExcelWriter(out_filename, engine='openpyxl', mode='w') as writer: # save the raw results (i.e., min_coverage=1.0) if keep_raw: temp_mainifest = manifest_list[0].copy() diff --git a/srcs/utils.py b/srcs/utils.py index 4b7d974..376b949 100644 --- a/srcs/utils.py +++ b/srcs/utils.py @@ -85,8 +85,9 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int, :param ksize: int (size of kmer) :param scale: int (scale factor) :param path_to_temp_dir: string (path to the folder to store the intermediate files) - :return: a dataframe with symmetric pairwise multisearch result (query_name, match_name) + :return: a dictionary mapping signature name to a list of its close related genomes (ANI > ani_thresh) """ + results = {} # run the sourmash multisearch # save signature files to a text file @@ -95,7 +96,7 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int, sig_files.to_csv(sig_files_path, header=False, index=False) # convert ani threshold to containment threshold - containment_thresh = (ani_thresh ** ksize) + containment_thresh = 0.9*(ani_thresh ** ksize) cmd = f"sourmash scripts multisearch {sig_files_path} {sig_files_path} -k {ksize} -s {scale} -c {num_threads} -t {containment_thresh} -o {os.path.join(path_to_temp_dir, 'training_multisearch_result.csv')}" logger.info(f"Running sourmash multisearch with command: {cmd}") exit_code = os.system(cmd) @@ -104,62 +105,82 @@ def run_multisearch(num_threads: int, ani_thresh: float, ksize: int, scale: int, # read the 
multisearch result multisearch_result = pd.read_csv(os.path.join(path_to_temp_dir, 'training_multisearch_result.csv'), sep=',', header=0) + multisearch_result = multisearch_result.drop_duplicates().reset_index(drop=True) multisearch_result = multisearch_result.query('query_name != match_name').reset_index(drop=True) - - # because the multisearch result is not symmetric, that is - # we have: A B score but not B A score - # we need to make it symmetric - A_TO_B = multisearch_result[['query_name','match_name']].drop_duplicates().reset_index(drop=True) - B_TO_A = A_TO_B[['match_name','query_name']].rename(columns={'match_name':'query_name','query_name':'match_name'}) - multisearch_result = pd.concat([A_TO_B, B_TO_A]).drop_duplicates().reset_index(drop=True) - return multisearch_result + for query_name, match_name in tqdm(multisearch_result[['query_name', 'match_name']].to_numpy()): + if str(query_name) not in results: + results[str(query_name)] = [str(match_name)] + else: + results[str(query_name)].append(str(match_name)) + + return results -def remove_corr_organisms_from_ref(sig_info_dict: Dict[str, Tuple[str, float, int, int]], multisearch_result: pd.DataFrame) -> Tuple[Dict[str, List[str]], pd.DataFrame]: +def remove_corr_organisms_from_ref(sig_info_dict: Dict[str, Tuple[str, float, int, int]], sig_same_genoms_dict: Dict[str, List[str]]) -> Tuple[Dict[str, List[str]], pd.DataFrame]: """ Helper function that removes the close related organisms from the reference matrix. :param sig_info_dict: a dictionary mapping all signature name from reference data to a tuple (md5sum, minhash mean abundance, minhash hashes length, minhash scaled) - :param multisearch_result: a dataframe with symmetric pairwise multisearch result (query_name, match_name) + :param sig_same_genoms_dict: a dictionary mapping signature name to a list of its close related genomes (ANI > ani_thresh) :return - remove_corr_df: a dataframe with two columns: removed organism name and its close related organisms + rep_remove_dict: a dictionary with key as representative signature name and value as a list of signatures to be removed manifest_df: a dataframe containing the processed reference signature information """ - # extract organisms that have close related organisms and their number of unique kmers - # sort name in order to better check the removed organisms - corr_organisms = sorted([str(query_name) for query_name in multisearch_result['query_name'].unique()]) - sizes = np.array([sig_info_dict[organism][2] for organism in corr_organisms]) - # sort organisms by size in ascending order, so we keep the largest organism, discard the smallest - bysize = np.argsort(sizes) - corr_organisms_bysize = np.array(corr_organisms)[bysize].tolist() - - # use dictionary to store the removed organisms and their close related organisms - # key: removed organism name - # value: a set of close related organisms - mapping = multisearch_result.groupby('query_name')['match_name'].agg(set).to_dict() - - # remove the sorted organisms until all left genomes are distinct (e.g., ANI <= ani_thresh) + # for each genome with close related genomes, pick the one with largest number of unique kmers + rep_remove_dict = {} temp_remove_set = set() - # loop through the organisms size in ascending order - for organism in tqdm(corr_organisms_bysize, desc='Removing close related organisms'): - ## for a given organism check its close related organisms, see if there are any organisms left after removing those in the remove set - ## if so, put this organism in the remove 
set - left_corr_orgs = mapping[organism].difference(temp_remove_set) - if len(left_corr_orgs) > 0: - temp_remove_set.add(organism) - - # generate a dataframe with two columns: removed organism name and its close related organisms - logger.info('Generating a dataframe with two columns: removed organism name and its close related organisms.') - remove_corr_list = [(organism, ','.join(list(mapping[organism]))) for organism in tqdm(temp_remove_set)] - remove_corr_df = pd.DataFrame(remove_corr_list, columns=['removed_org', 'corr_orgs']) + manifest_df = [] + for genome, same_genomes in tqdm(sig_same_genoms_dict.items()): + # skip if the genome has been removed + if genome in temp_remove_set: + continue + # keep same genome if it is not in the remove set + same_genomes = list(set(same_genomes).difference(temp_remove_set)) + # get the number of unique kmers for each genome + unique_kmers = np.array([sig_info_dict[genome][2]] + [sig_info_dict[same_genome][2] for same_genome in same_genomes]) + # get the index of the genome with largest number of unique kmers + rep_idx = np.argmax(unique_kmers) + # get the representative genome + rep_genome = genome if rep_idx == 0 else same_genomes[rep_idx-1] + # get the list of genomes to be removed + remove_genomes = same_genomes if rep_idx == 0 else [genome] + same_genomes[:rep_idx-1] + same_genomes[rep_idx:] + # update remove set + temp_remove_set.update(remove_genomes) + if len(remove_genomes) > 0: + rep_remove_dict[rep_genome] = remove_genomes # remove the close related organisms from the reference genome list - manifest_df = [] for sig_name, (md5sum, minhash_mean_abundance, minhash_hashes_len, minhash_scaled) in tqdm(sig_info_dict.items()): if sig_name not in temp_remove_set: manifest_df.append((sig_name, md5sum, minhash_hashes_len, get_num_kmers(minhash_mean_abundance, minhash_hashes_len, minhash_scaled, False), minhash_scaled)) manifest_df = pd.DataFrame(manifest_df, columns=['organism_name', 'md5sum', 'num_unique_kmers_in_genome_sketch', 'num_total_kmers_in_genome_sketch', 'genome_scale_factor']) - return remove_corr_df, manifest_df + return rep_remove_dict, manifest_df + +# def compute_sample_vector(sample_hashes, hash_to_idx): +# """ +# Helper function that computes the sample vector for a given sample signature. 
+# :param sample_hashes: hashes in the sample signature +# :param hash_to_idx: dictionary mapping hashes to indices in the training dictionary +# :return: numpy array (sample vector) +# """ +# # total number of hashes in the training dictionary +# hash_to_idx_keys = set(hash_to_idx.keys()) + +# # total number of hashes in the sample +# sample_hashes_keys = set(sample_hashes.keys()) + +# # initialize the sample vector +# sample_vector = np.zeros(len(hash_to_idx_keys)) + +# # get the hashes that are in both the sample and the training dictionary +# sample_intersect_training_hashes = hash_to_idx_keys.intersection(sample_hashes_keys) + +# # fill in the sample vector +# for sh in tqdm(sample_intersect_training_hashes): +# sample_vector[hash_to_idx[sh]] = sample_hashes[sh] + +# return sample_vector + class Prediction: """ diff --git a/tests/integration_tests.py b/tests/integration_tests.py index d42c12d..479a942 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -37,26 +37,28 @@ def test_make_training_data_from_sketches(): ref_file = 'tests/testdata/20_genomes_sketches.zip' ksize = '31' ani_thresh = '0.95' - out_prefix = 'gtdb_ani_thresh_0.95' - config_file = f'{out_prefix}_config.json' - hash_to_col_idx_file = f'{out_prefix}_hash_to_col_idx.pkl' - processed_org_idx_file = f'{out_prefix}_processed_org_idx.csv' - ref_matrix_processed_file = f'{out_prefix}_ref_matrix_processed.npz' + prefix = 'gtdb_ani_thresh_0.95' + config_file = f'{prefix}_config.json' + processed_manifest_file = f'{prefix}_processed_manifest.tsv' + rep_to_corr_orgas_mapping_file = f'{prefix}_rep_to_corr_orgas_mapping.tsv' + intermediate_files_dir = f'{prefix}_intermediate_files' command = [ 'python', 'make_training_data_from_sketches.py', '--ref_file', ref_file, '--ksize', ksize, + '--prefix', prefix, '--ani_thresh', ani_thresh, - '--out_prefix', out_prefix + '--outdir', './', + '--force', ] subprocess.run(command) assert os.path.isfile(config_file) - assert os.path.isfile(hash_to_col_idx_file) - assert os.path.isfile(processed_org_idx_file) - assert os.path.isfile(ref_matrix_processed_file) + assert os.path.isfile(processed_manifest_file) + assert os.path.isfile(rep_to_corr_orgas_mapping_file) + assert os.path.isdir(intermediate_files_dir) with open(config_file, 'r') as f: config = json.load(f) @@ -64,31 +66,9 @@ def test_make_training_data_from_sketches(): assert config['ani_thresh'] == float(ani_thresh) def test_run_yacht(): - script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - test_dir = os.path.join(script_dir, 'tests') - data_dir = os.path.join(test_dir, 'testdata') - out_prefix = "integration_test" - full_out_prefix = os.path.join(data_dir, out_prefix) - abundance_file = full_out_prefix + "recovered_abundance.xlsx" - reference_sketches = os.path.join(data_dir, "20_genomes_sketches.zip") - sample_sketches = os.path.join(data_dir, "sample.sig") - - cmd = f"python {os.path.join(script_dir, 'make_training_data_from_sketches.py')} --ref_file {reference_sketches}" \ - f" --out_prefix {full_out_prefix} --ksize 31" - res = subprocess.run(cmd, shell=True, check=True) - assert res.returncode == 0 - - assert exists(full_out_prefix + "_hash_to_col_idx.pkl") - assert exists(full_out_prefix + "_processed_org_idx.csv") - assert exists(full_out_prefix + "_ref_matrix_processed.npz") - - if exists(abundance_file): - os.remove(abundance_file) - - cmd = f"python {os.path.join(script_dir, 'run_YACHT.py')} --json {os.path.join(script_dir, 'gtdb_ani_thresh_0.95_config.json')} " \ - 
f"--sample_file {sample_sketches} --significance 0.99 --min_coverage 1 --outdir {data_dir} --out_filename {abundance_file}" + cmd = "python run_YACHT.py --json gtdb_ani_thresh_0.95_config.json --sample_file 'tests/testdata/sample.sig.zip' --significance 0.99 --min_coverage_list 1 0.6 0.2 0.1" res = subprocess.run(cmd, shell=True, check=True) assert res.returncode == 0 - assert exists(abundance_file) + assert exists('result.xlsx') diff --git a/tests/test_workflow.py b/tests/test_workflow.py index fb7406f..1074caf 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -13,6 +13,7 @@ def test_full_workflow(): test_dir = os.path.join(script_dir, 'tests') data_dir = os.path.join(test_dir, 'testdata') out_prefix = "20_genomes_trained" + full_out_prefix = os.path.join(data_dir, out_prefix) abundance_file = os.path.join(data_dir, "result.xlsx") reference_sketches = os.path.join(data_dir, "20_genomes_sketches.zip") sample_sketches = os.path.join(data_dir, "sample.sig.zip") @@ -30,8 +31,8 @@ def test_full_workflow(): # Remove the intermediate folder shutil.rmtree(os.path.join(data_dir, intermediate_dir), ignore_errors=True) # python ../make_training_data_from_sketches.py --ref_file testdata/20_genomes_sketches.zip --ksize 31 --prefix 20_genomes_trained --outdir testdata/ - cmd = f"python {os.path.join(script_dir, 'make_training_data_from_sketches.py')} --force --ref_file {reference_sketches}" \ - f" --prefix {out_prefix} --ksize 31 --outdir {data_dir}" + cmd = f"python {os.path.join(script_dir, 'make_training_data_from_sketches.py')} --ref_file {reference_sketches}" \ + f" --prefix {full_out_prefix} --ksize 31 --outdir {data_dir}" res = subprocess.run(cmd, shell=True, check=True) # check that no errors were raised assert res.returncode == 0 @@ -40,11 +41,12 @@ def test_full_workflow(): assert exists(f) # check that the files are big enough for f in expected_files: - assert os.stat(f).st_size > 300 + assert os.stat(f).st_size > 400 # then do the presence/absence estimation if exists(abundance_file): os.remove(abundance_file) - cmd = f"python {os.path.join(script_dir, 'run_YACHT.py')} --json {os.path.join(data_dir, '20_genomes_trained_config.json')} --sample_file {sample_sketches} --significance 0.99 --min_coverage 0.001 --out {os.path.join(data_dir,abundance_file)} --show_all" + # python ../run_YACHT.py --json testdata/20_genomes_trained_config.json --sample_file testdata/sample.sig.zip --out_file result.xlsx + cmd = f"python {os.path.join(script_dir, 'run_YACHT.py')} --json {os.path.join(data_dir, '20_genomes_trained_config.json')} --sample_file {sample_sketches} --significance 0.99 --min_coverage 0.001 --out_file {os.path.join(data_dir,abundance_file)} --show_all" res = subprocess.run(cmd, shell=True, check=True) # check that no errors were raised assert res.returncode == 0 @@ -67,21 +69,19 @@ def test_incorrect_workflow1(): cmd = f"python run_YACHT.py --json {demo_dir}/demo_ani_thresh_0.95_config.json --sample_file {demo_dir}/ref.sig.zip" res = subprocess.run(cmd, shell=True, check=False) # this should fail - assert res.returncode != 0 + assert res.returncode == 1 def test_demo_workflow(): - script_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - demo_dir = os.path.join(script_dir, "demo") - cmd = f"cd {demo_dir}; sourmash sketch dna -f -p k=31,scaled=1000,abund -o sample.sig.zip query_data/query_data.fq" + cmd = "cd demo; sourmash sketch dna -f -p k=31,scaled=1000,abund -o sample.sig.zip query_data/query_data.fq" _ = subprocess.run(cmd, shell=True, check=True) - 
cmd = f"cd {demo_dir}; sourmash sketch fromfile ref_paths.csv -p dna,k=31,scaled=1000,abund -o ref.sig.zip --force-output-already-exists" + cmd = "cd demo; sourmash sketch fromfile ref_paths.csv -p dna,k=31,scaled=1000,abund -o ref.sig.zip --force-output-already-exists" _ = subprocess.run(cmd, shell=True, check=True) - cmd = f"cd {demo_dir}; python ../make_training_data_from_sketches.py --force --ref_file ref.sig.zip --ksize 31 --num_threads 1 --ani_thresh 0.95 --prefix 'demo_ani_thresh_0.95' --outdir ./" + cmd = "cd demo; python ../make_training_data_from_sketches.py --force --ref_file ref.sig.zip --ksize 31 --num_threads 1 --ani_thresh 0.95 --prefix 'demo_ani_thresh_0.95' --outdir ./" _ = subprocess.run(cmd, shell=True, check=True) - cmd = f"cd {demo_dir}; python ../run_YACHT.py --json demo_ani_thresh_0.95_config.json --sample_file sample.sig.zip --significance 0.99 --num_threads 1 --min_coverage_list 1 0.6 0.2 0.1 --out ./result.xlsx" + cmd = "cd demo; python ../run_YACHT.py --json demo_ani_thresh_0.95_config.json --sample_file sample.sig.zip --significance 0.99 --num_threads 1 --min_coverage_list 1 0.6 0.2 0.1 --out_filename result.xlsx" _ = subprocess.run(cmd, shell=True, check=True) - cmd = f"cd {demo_dir}; python ../srcs/standardize_yacht_output.py --yacht_output result.xlsx --sheet_name min_coverage0.2 --genome_to_taxid toy_genome_to_taxid.tsv --mode cami --sample_name 'MySample' --outfile_prefix cami_result --outdir ./" + cmd = "cd demo; python ../srcs/standardize_yacht_output.py --yacht_output result.xlsx --sheet_name min_coverage0.2 --genome_to_taxid toy_genome_to_taxid.tsv --mode cami --sample_name 'MySample' --outfile_prefix cami_result --outdir ./" _ = subprocess.run(cmd, shell=True, check=True) diff --git a/tests/unittests.py b/tests/unittests.py index c5a163d..38e1809 100644 --- a/tests/unittests.py +++ b/tests/unittests.py @@ -1,17 +1,14 @@ +import json import pytest -import sourmash -import pickle -import numpy as np import pandas as pd -import csv import os +import tempfile +import gzip import sys -from scipy.sparse import csc_matrix +import shutil sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from srcs.utils import * -from srcs.hypothesis_recovery_src import * - -'''srcs/utils ↓↓↓''' +from srcs.hypothesis_recovery_src import single_hyp_test, get_alt_mut_rate +from srcs.utils import remove_corr_organisms_from_ref, check_file_existence, get_cami_profile, get_column_indices, get_info_from_single_sig, collect_signature_info, run_multisearch @pytest.fixture def test_output_files(): @@ -20,150 +17,21 @@ def test_output_files(): if os.path.exists(filename): os.remove(filename) -def test_write_processed_indices(test_output_files): - signature1 = sourmash.SourmashSignature(name="organism1", minhash=sourmash.MinHash(n=20, ksize=31)) - signature2 = sourmash.SourmashSignature(name="organism2", minhash=sourmash.MinHash(n=20, ksize=21)) - - signatures = [signature1, signature2] - uncorr_org_idx = [0, 1] - - write_processed_indices(test_output_files + '.csv', signatures, uncorr_org_idx) - - with open(test_output_files + '.csv', 'r', newline='', encoding='utf-8') as f: - reader = csv.reader(f) - header = next(reader) - rows = list(reader) - - expected_header = ['organism_name', 'original_index', 'processed_index', 'num_unique_kmers_in_genome_sketch', 'num_total_kmers_in_genome_sketch', 'genome_scale_factor'] - expected_rows = [ - ['organism1', '0', '0', str(len(signature1.minhash.hashes)), str(len(signature1.minhash.hashes)), 
str(signature1.minhash.scaled)], - ['organism2', '1', '1', str(len(signature2.minhash.hashes)), str(len(signature2.minhash.hashes)), str(signature2.minhash.scaled)] - ] - - assert header == expected_header - assert rows == expected_rows - -def test_write_processed_indices_empty(test_output_files): - signatures = [] - uncorr_org_idx = [] - - write_processed_indices(test_output_files + '.csv', signatures, uncorr_org_idx) - - with open(test_output_files + '.csv', 'r', newline='', encoding='utf-8') as f: - reader = csv.reader(f) - header = next(reader) - rows = list(reader) - - expected_header = ['organism_name', 'original_index', 'processed_index', 'num_unique_kmers_in_genome_sketch', 'num_total_kmers_in_genome_sketch', 'genome_scale_factor'] - expected_rows = [] - - assert header == expected_header - assert rows == expected_rows - -def test_write_hashes(test_output_files): - hashes = {'hash1': 1, 'hash2': 2, 'hash3': 3} - write_hashes(test_output_files + '.pickle', hashes) - - assert os.path.exists(test_output_files + '.pickle') - - with open(test_output_files + '.pickle', 'rb') as fid: - loaded_hashes = pickle.load(fid) - - assert loaded_hashes == hashes - os.remove(test_output_files + '.pickle') - -def test_write_empty_hashes(test_output_files): - hashes = {} - write_hashes(test_output_files + '.pickle', hashes) - - assert os.path.exists(test_output_files + '.pickle') - - with open(test_output_files + '.pickle', 'rb') as fid: - loaded_hashes = pickle.load(fid) - - assert loaded_hashes == hashes - os.remove(test_output_files + '.pickle') - -def test_write_overwrite_file(test_output_files): - initial_hashes = {'hash1': 1, 'hash2': 2} - new_hashes = {'hash3': 3, 'hash4': 4} - - write_hashes(test_output_files + '.pickle', initial_hashes) - write_hashes(test_output_files + '.pickle', new_hashes) - - assert os.path.exists(test_output_files + '.pickle') - - with open(test_output_files + '.pickle', 'rb') as fid: - loaded_hashes = pickle.load(fid) - - assert loaded_hashes == new_hashes - os.remove(test_output_files + '.pickle') - -def test_signatures_to_ref_matrix(): - signature1 = sourmash.SourmashSignature(name="organism1", minhash=sourmash.MinHash(n=0, ksize=31, scaled=5)) - signature2 = sourmash.SourmashSignature(name="organism2", minhash=sourmash.MinHash(n=0, ksize=31, scaled=7)) - real_signatures = [signature1, signature2] - - ksize = 31 - signature_count = 2 - - signature_list, ref_matrix, hash_to_idx, is_mismatch = signatures_to_ref_matrix(real_signatures, ksize, signature_count) - - assert isinstance(ref_matrix, csc_matrix) - assert isinstance(hash_to_idx, dict) - assert ref_matrix.shape[1] == len(real_signatures) - assert is_mismatch is False - -def test_signatures_to_ref_matrix_empty(): - empty_signatures = [] - ksize = 31 - signature_count = 0 - - signature_list, ref_matrix, hash_to_idx, is_mismatch = signatures_to_ref_matrix(empty_signatures, ksize, signature_count) - - assert ref_matrix.shape == (0, 0) - assert len(hash_to_idx) == 0 - assert is_mismatch is False - tmp_dir = "tests/unittests_data/test_tmp" hashes_data = {'hash1': 1, 'hash2': 2, 'hash3': 3} ksize = 31 -def test_load_hashes_to_index(): - with open(os.path.join(tmp_dir, "hash_to_col_idx.pkl"), 'wb') as fid: - pickle.dump(hashes_data, fid) - loaded_hashes = load_hashes_to_index(os.path.join(tmp_dir, "hash_to_col_idx.pkl")) - assert loaded_hashes == hashes_data - - empty_hashes_data = {} - with open(os.path.join(tmp_dir, "empty_hash_to_col_idx.pkl"), 'wb') as fid: - pickle.dump(empty_hashes_data, fid) - loaded_hashes = 
load_hashes_to_index(os.path.join(tmp_dir, "empty_hash_to_col_idx.pkl")) - assert loaded_hashes == empty_hashes_data - -def test_get_num_kmers(): - minhash = sourmash.MinHash(n=20, ksize=ksize) - signature = sourmash.SourmashSignature(name="test_signature", minhash=minhash) - expected_num_kmers = 0 - num_kmers = get_num_kmers(signature) - assert num_kmers == expected_num_kmers - - minhash = sourmash.MinHash(n=0, ksize=ksize, max_hash=2**64) - signature = sourmash.SourmashSignature(name="test_signature", minhash=minhash) - expected_num_kmers = 0 - num_kmers = get_num_kmers(signature) - assert num_kmers == expected_num_kmers - def test_check_file_existence(): + dont_exist = 'File does not exist' existing_file = os.path.join(tmp_dir, "existing_file.txt") with open(existing_file, 'w') as f: f.write("Test content") - assert check_file_existence(existing_file, "File does not exist") is None + assert check_file_existence(existing_file, dont_exist) is None non_existing_file = os.path.join(tmp_dir, "non_existing_file.txt") - with pytest.raises(ValueError, match="File does not exist"): - check_file_existence(non_existing_file, "File does not exist") + with pytest.raises(ValueError, match=dont_exist): + check_file_existence(non_existing_file, dont_exist) def test_get_column_indices(): column_name_to_index = { @@ -212,41 +80,6 @@ def test_get_cami_profile(): assert prediction2.taxpath == '2' assert prediction2.taxpathsn == 'Bacteria' -ZIP_PATH = 'tests/testdata/20_genomes_sketches.zip' - -def test_count_files_in_zip(): - num_files = count_files_in_zip(ZIP_PATH) - assert num_files == 21 - -'''srcs/hypothesis_recovery_srcs ↓↓↓''' - -def test_get_nontrivial_idx_empty_input(): - A = csc_matrix([]) - y = np.array([]) - result = get_nontrivial_idx(A, y) - expected_result = [] - assert result == expected_result - -def test_get_nontrivial_idx_no_nontrivial_indices(): - A = csc_matrix([[0, 0], [0, 0]]) - y = np.array([0, 0]) - result = get_nontrivial_idx(A, y) - expected_result = [] - assert result.tolist() == expected_result - -def test_get_nontrivial_idx_with_nontrivial_indices(): - A = csc_matrix([[0, 1, 0], [1, 0, 1], [0, 0, 1]]) - y = np.array([1, 0, 1]) - result = get_nontrivial_idx(A, y) - expected_result = [1, 2] - assert result.tolist() == expected_result - -def test_get_exclusive_indicators(): - A = csc_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - result = get_exclusive_indicators(A) - expected_result = [[0], [1], [2]] - assert result == expected_result - def test_get_alt_mut_rate(): nu = 10 thresh = 5 @@ -273,38 +106,119 @@ def test_get_alt_mut_rate_large_thresh(): result = get_alt_mut_rate(nu, thresh, ksize, significance) expected_result = -1 assert result == expected_result - -def test_hypothesis_recovery(): - A = csc_matrix([[1, 0, 1], [0, 1, 0], [1, 0, 0]]) - y = np.array([1, 0, 1]) - ksize = 31 - significance = 0.99 - ani_thresh = 0.95 - min_coverage = 0.5 - result = hypothesis_recovery(A, y, ksize, significance, ani_thresh, min_coverage) - assert len(result) == 2 -def test_hypothesis_recovery_empty_input(): - A = csc_matrix([]) - y = np.array([]) - ksize = 31 - significance = 0.99 +def test_get_info_from_single_sig(): + sig_list_file = 'gtdb_ani_thresh_0.95_intermediate_files/training_sig_files.txt' + + with open(sig_list_file, 'r') as file: + lines = file.readlines() + if lines: + sig_file_path = lines[0].strip() + else: + raise IOError("Signature list file is empty") + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_sig_file = os.path.join(tmpdir, os.path.basename(sig_file_path)) + + 
with gzip.open(sig_file_path, 'rb') as f_in: + with open(tmp_sig_file, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + + ksize = 0 + result = get_info_from_single_sig(tmp_sig_file, ksize) + + expected_name = "VIKJ01000003.1 Chitinophagaceae bacterium isolate X1_MetaBAT.39 scaffold_1008, whole genome shotgun sequence" + expected_md5sum = "96cb85214535b0f9723a6abc17097821" + expected_mean_abundance = 1.0 + expected_hashes_len = 1984 + expected_scaled = 1000 + + assert result[0] == expected_name + assert result[1] == expected_md5sum + assert abs(result[2] - expected_mean_abundance) < 0.01 + assert result[3] == expected_hashes_len + assert result[4] == expected_scaled + +def test_collect_signature_info(): + num_threads = 2 + ksize = 0 + path_to_temp_dir = 'gtdb_ani_thresh_0.95_intermediate_files/' + + result = collect_signature_info(num_threads, ksize, path_to_temp_dir) + + with open('tests/unittests_data/test_collect_signature_info_data.json', 'r') as file: + expectations = json.load(file) + + for expectation in expectations.keys(): + assert expectation in result + actual_info = result[expectation] + assert expectations[expectation] == list(actual_info) + +def test_run_multisearch(): + num_threads = 32 ani_thresh = 0.95 - min_coverage = 0.5 - result = hypothesis_recovery(A, y, ksize, significance, ani_thresh, min_coverage) - expected_p_vals = [] - assert result[1].tolist() == expected_p_vals + ksize = 31 + scale = 1000 + path_to_temp_dir = 'gtdb_ani_thresh_0.95_intermediate_files/' + + expected_results = {} + + result = run_multisearch(num_threads, ani_thresh, ksize, scale, path_to_temp_dir) + + for signature_name, expected_related_genomes in expected_results.items(): + assert signature_name in result + actual_related_genomes = result[signature_name] + assert set(actual_related_genomes) == set(expected_related_genomes) + +def test_remove_corr_organisms_from_ref(): + sig_info_dict = { + "signature_1": ("md5sum_1", 1.0, 100, 1000), + "related_genome_1": ("md5sum_3", 1.5, 150, 1500), + "related_genome_2": ("md5sum_4", 1.3, 120, 1200), + "related_genome_3": ("md5sum_5", 1.7, 170, 1700), + "related_genome_4": ("md5sum_6", 1.8, 180, 1800), + } + + sig_same_genoms_dict = { + "signature_1": ["related_genome_1", "related_genome_2"], + "related_genome_1": ["related_genome_2"], + "related_genome_3": ["related_genome_4"], + } + + expected_rep_remove_dict = { + 'related_genome_1': ['signature_1', 'related_genome_2'], + 'related_genome_4': ['related_genome_3'] + } -def test_hypothesis_recovery_single_organism(): - A = csc_matrix([[1, 0, 1], [0, 0, 0], [0, 0, 0]]) - y = np.array([1, 0, 1]) + expected_manifest_df = pd.DataFrame([ + ("related_genome_1", "md5sum_3", 150, 225, 1500), + ("related_genome_4", "md5sum_6", 180, 324, 1800), + ], columns=['organism_name', 'md5sum', 'num_unique_kmers_in_genome_sketch', 'num_total_kmers_in_genome_sketch', 'genome_scale_factor']) + + rep_remove_dict, manifest_df = remove_corr_organisms_from_ref(sig_info_dict, sig_same_genoms_dict) + + assert rep_remove_dict == expected_rep_remove_dict + pd.testing.assert_frame_equal(manifest_df, expected_manifest_df) + + +def test_single_hyp_test(): + exclusive_hashes_info_org = (100, 90) ksize = 31 - significance = 0.99 - ani_thresh = 0.95 - min_coverage = 0.5 - result = hypothesis_recovery(A, y, ksize, significance, ani_thresh, min_coverage) - expected_p_vals = [1.0, 0.0, 1.0] - assert result[1].tolist() == expected_p_vals + + result = single_hyp_test(exclusive_hashes_info_org, ksize) + + in_sample_est, p_val, num_exclusive_kmers, 
num_exclusive_kmers_coverage, num_matches, \ + acceptance_threshold_with_coverage, actual_confidence_with_coverage, alt_confidence_mut_rate_with_coverage = result + + assert isinstance(in_sample_est, int) + assert isinstance(p_val, float) + assert isinstance(num_exclusive_kmers, int) + assert isinstance(num_exclusive_kmers_coverage, int) + assert isinstance(num_matches, int) + assert isinstance(acceptance_threshold_with_coverage, float) + assert isinstance(actual_confidence_with_coverage, float) + assert isinstance(alt_confidence_mut_rate_with_coverage, float) + if __name__ == '__main__': pytest.main() \ No newline at end of file diff --git a/tests/unittests_data/test_collect_signature_info_data.json b/tests/unittests_data/test_collect_signature_info_data.json new file mode 100644 index 0000000..c2cd60c --- /dev/null +++ b/tests/unittests_data/test_collect_signature_info_data.json @@ -0,0 +1 @@ +{"VMDJ01000165.1 Gammaproteobacteria bacterium isolate 27_1 c_000000000223, whole genome shotgun sequence": ["b7f087146f5cc3121477c29ff003e3d0", 1.0036337209302326, 1376, 1000], "VSSA01000053.1 Nocardioides sp. BGMRC 2183 Scaffold102_1, whole genome shotgun sequence": ["1a121dca600c6504e88252e81004f0cf", 1.0034972677595628, 4575, 1000], "VIKI01000038.1 Comamonadaceae bacterium isolate X1_MetaBAT.31 scaffold_1017, whole genome shotgun sequence": ["0661ecab88c3d65d0f10e599a5ba1654", 1.0016663580818368, 5401, 1000], "VIKE01000141.1 Rhodocyclaceae bacterium isolate X1_MetaBAT.22 scaffold_10076, whole genome shotgun sequence": ["8fb9b1a69838a58cc4f31c1e42a5f189", 1.0003899395593683, 5129, 1000], "SHMW01000001.1 Candidatus Lokiarchaeota archaeon isolate BC3 1189800001, whole genome shotgun sequence": ["ce54d962851b0fdeefc624300036a133", 1.000951701165834, 4203, 1000], "VIKH01000154.1 Gallionellaceae bacterium isolate X1_MetaBAT.29 scaffold_10015, whole genome shotgun sequence": ["a136145bee08846ed94c0406df3da2d4", 1.002787456445993, 2870, 1000], "VMDH01000017.1 Gammaproteobacteria bacterium isolate 24_2 c_000000000070, whole genome shotgun sequence": ["92fb1b3e4baa6c474aff3efb84957687", 1.0, 916, 1000], "VMDI01000049.1 Gammaproteobacteria bacterium isolate 24_3 c_000000000093, whole genome shotgun sequence": ["188d55801a78d4773cf6c0b46bca96ba", 1.0030864197530864, 972, 1000], "VMDK01000027.1 Sphingobacteriia bacterium isolate 28_1 c_000000000062, whole genome shotgun sequence": ["04212e93c2172d4df49dc5d8c2973d8b", 1.003282724661469, 2437, 1000], "SSEB01000012.1 Sinobacteraceae bacterium isolate Bin_35_3 c_000000004113, whole genome shotgun sequence": ["c9eb6a9d058df8036ad93bc45d5bf260", 1.001261829652997, 3170, 1000], "SHMX01000001.1 Candidatus Thorarchaeota archaeon isolate BC 1189500001, whole genome shotgun sequence": ["45f2675c1dca4ef1a24a05f5b268adbb", 1.0019261637239165, 3115, 1000], "SSEF01000018.1 Candidatus Moranbacteria bacterium isolate Bin_68_2 c_000000001403, whole genome shotgun sequence": ["c39c52d2d088348c950c2afe503b405b", 1.0050454086781029, 991, 1000], "VPFC01000001.1 [Empedobacter] haloabium strain ATCC 31962 contig1, whole genome shotgun sequence": ["04f2b0e94f2d1f1f5b8355114b70274e", 1.0508555893171279, 6253, 1000], "SSEL01000090.1 Rhodospirillaceae bacterium isolate Bin_26_3 c_000000001054, whole genome shotgun sequence": ["06ebe48d527882bfa9505aba8e31ae23", 1.0010533707865168, 5696, 1000], "WAAQ01000001.1 Microbacterium maritypicum strain DSM 12512 contig00001, whole genome shotgun sequence": ["7b312fffa3fb35440ba40203ba826c05", 1.0002729257641922, 3664, 1000], "VMDM01000010.1 
Nitrosopumilus sp. isolate 32_1 c_000000000023, whole genome shotgun sequence": ["b691deddf179ead0a006527330d86dde", 1.001974333662389, 1013, 1000], "SHMU01000001.1 Candidatus Lokiarchaeota archaeon isolate BC1 1189600001, whole genome shotgun sequence": ["11fe9a00287c7ad086ebbc463724cf10", 1.0227748691099476, 3820, 1000], "VIKJ01000003.1 Chitinophagaceae bacterium isolate X1_MetaBAT.39 scaffold_1008, whole genome shotgun sequence": ["96cb85214535b0f9723a6abc17097821", 1.0, 1984, 1000], "VKGY01000191.1 Spirochaetes bacterium isolate X1_MetaBAT.41 scaffold_10187, whole genome shotgun sequence": ["39ea7fd48ee7003587c9c763946d5d6e", 1.000777604976672, 2572, 1000], "CP032507.1 Ectothiorhodospiraceae bacterium BW-2 chromosome, complete genome": ["16c6c1d37259d83088ab3a4f5b691631", 1.0962309542902968, 3741, 1000]} \ No newline at end of file
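For context on `containment_thresh = 0.9*(ani_thresh ** ksize)` in `run_multisearch` above: under a simple independent-mutation model, two genomes at ANI `a` are expected to share roughly `a**k` of their k-mers, so the user-facing `--ani_thresh` is converted into a containment cutoff for `sourmash scripts multisearch` and loosened by a constant factor of 0.9. Below is a minimal standalone sketch of that conversion; the helper name and example values are illustrative only and are not part of the patch.

```python
def ani_to_containment_threshold(ani_thresh: float, ksize: int, slack: float = 0.9) -> float:
    """Convert an ANI threshold into a k-mer containment threshold.

    Assuming independent point mutations, a pair of genomes at ANI `ani_thresh`
    shares about ani_thresh**ksize of their k-mers, so that value is used as the
    multisearch containment cutoff; `slack` loosens it slightly so borderline
    pairs just above the ANI threshold are not missed.
    """
    return slack * (ani_thresh ** ksize)

# Example: ani_thresh=0.95, ksize=31 -> 0.95**31 ~= 0.20, times 0.9 ~= 0.18
print(ani_to_containment_threshold(0.95, 31))
```

This is the inverse of the usual containment-to-ANI approximation (ANI roughly equal to containment**(1/k)) used when interpreting sourmash containment scores.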