Skip to content

Commit

Permalink
Merge branch 'main' into YAC-96
Browse files Browse the repository at this point in the history
  • Loading branch information
mfl15 authored Feb 14, 2024
2 parents d26f951 + 93d50f5 commit 02fb4f6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ curl --cookie zenodo-cookies.txt "https://zenodo.org/records/<zendo_id>/files/<f
# curl --cookie zenodo-cookies.txt "https://zenodo.org/records/10113534/files/genbank-2022.03-archaea-k31_0.80_pretrained.zip?download=1" --output genbank-2022.03-archaea-k31_0.80_pretrained.zip
```

**Please note: if you plan to use one of these pre-trained reference databases, then after you download and unzip it you need to change the paths within the config JSON file (e.g., gtdb-rs214-reps.k31_0.9995_config.json) to the correct paths on your machine.**

</br>

Expand Down
2 changes: 1 addition & 1 deletion tests/test_download_pretrained_ref_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def test_main(self, mock_unzip, mock_fetch_records, mock_download_file, mock_get
mock_get.return_value.json.return_value = {'hits': {'hits': [{'title': 'test-db'}]}}
mock_download_file.return_value = True
download_script.main(args)
mock_unzip.assert_called_with(os.path.join('.', self.PRETRAINED_ZIP), '.')
mock_unzip.assert_called_with(self.PRETRAINED_ZIP, '.')
23 changes: 20 additions & 3 deletions yacht/download_pretrained_ref_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from loguru import logger
import sys
import os
import json
import zipfile
from .utils import create_output_folder, check_download_args
# Import global variables
Expand Down Expand Up @@ -84,10 +85,25 @@ def download_file(url, output_path):
logger.error(f"Failed to download {url}: {e}")
return False


def update_config_file(file_path):
    """Point the paths in an extracted database's config JSON at this machine.

    The directory is derived from ``file_path`` by dropping a trailing
    ``.zip`` and making the result absolute. The first file in that
    directory whose name contains ``_config.json`` is loaded, every
    top-level string value starting with ``/`` is replaced by
    ``<directory>/<basename of the old value>``, and the file is written
    back with 4-space indentation.

    This is best-effort: any failure is logged via ``logger.error`` and the
    function returns without raising.

    Args:
        file_path: Path of the downloaded archive (normally ending in
            ``.zip``) whose extracted directory sits next to it.

    Returns:
        None.
    """
    # Strip only the trailing extension; str.replace would also eat a ".zip"
    # that happens to appear in the middle of the path.
    suffix = ".zip"
    extracted = file_path[:-len(suffix)] if file_path.endswith(suffix) else file_path
    absolute_path = os.path.abspath(extracted)

    try:
        # Sorted so the choice is deterministic if several files match.
        candidates = sorted(
            name for name in os.listdir(absolute_path) if "_config.json" in name
        )
    except OSError as err:
        logger.error(f"Could not find config file at {absolute_path}: {err}")
        return

    if not candidates:
        logger.error(f"Could not find config file at {absolute_path}")
        return

    config_file = os.path.join(absolute_path, candidates[0])

    try:
        with open(config_file, encoding="utf-8") as fp:
            config = json.load(fp)
    except (OSError, ValueError) as err:
        # json.JSONDecodeError is a ValueError subclass.
        logger.error(f"Could not read config file {config_file}: {err}")
        return

    if not isinstance(config, dict):
        logger.error(f"Config file {config_file} does not contain a JSON object")
        return

    # NOTE(review): the '/' test is kept as-is rather than os.path.isabs;
    # presumably the shipped configs carry POSIX-style absolute paths
    # regardless of the platform this runs on -- confirm.
    for key, value in config.items():
        if isinstance(value, str) and value.startswith("/"):
            config[key] = os.path.join(absolute_path, os.path.basename(value))

    try:
        with open(config_file, "w", encoding="utf-8") as fp:
            json.dump(config, fp, indent=4)
    except OSError as err:
        logger.error(f"Could not write config file {config_file}: {err}")
def unzip_file(file_path, extract_to):
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
with zipfile.ZipFile(os.path.join(extract_to, file_path), 'r') as zip_ref:
zip_ref.extractall(extract_to)
logger.success(f"Extracted {file_path} to {extract_to}")
except zipfile.BadZipFile:
Expand Down Expand Up @@ -155,7 +171,8 @@ def main(args):
)

if file_url and download_file(file_url, output_path):
unzip_file(output_path, args.outfolder)
unzip_file(file_name_to_search, args.outfolder)
update_config_file(output_path)
else:
logger.warning(f"File '{file_name_to_search}' not found in Zenodo records.")

Expand Down

0 comments on commit 02fb4f6

Please sign in to comment.