Skip to content

Commit

Permalink
More improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mephenor committed Nov 28, 2024
1 parent baa8e8d commit ffd2c20
Show file tree
Hide file tree
Showing 8 changed files with 1,054 additions and 1,101 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ repos:
- id: no-commit-to-branch
args: [--branch, dev, --branch, int, --branch, main]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.3
rev: v0.8.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
args: [--no-warn-unused-ignores]
1,693 changes: 818 additions & 875 deletions lock/requirements-dev.txt

Large diffs are not rendered by default.

358 changes: 185 additions & 173 deletions lock/requirements.txt

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion src/ghga_connector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,16 @@ def download(
debug: bool = typer.Option(
False, help="Set this option in order to view traceback for errors."
),
overwrite: bool = typer.Option(
False,
help="Set to true to overwrite already existing files in the output directory.",
),
):
"""Wrapper for the async download function"""
asyncio.run(
async_download(output_dir, my_public_key_path, my_private_key_path, debug)
async_download(
output_dir, my_public_key_path, my_private_key_path, debug, overwrite
)
)


Expand All @@ -272,6 +278,7 @@ async def async_download(
my_public_key_path: Path,
my_private_key_path: Path,
debug: bool = False,
overwrite: bool = False,
):
"""Download files asynchronously"""
if not my_public_key_path.is_file():
Expand Down Expand Up @@ -330,6 +337,7 @@ async def async_download(
part_size=CONFIG.part_size,
message_display=message_display,
work_package_accessor=parameters.work_package_accessor,
overwrite=overwrite,
)
staged_files.clear()

Expand Down
9 changes: 3 additions & 6 deletions src/ghga_connector/core/downloading/abstract_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,13 @@ def download_file(self, *, output_path: Path, part_size: int):
"""Download file to the specified location and manage lower level details."""

@abstractmethod
def await_download_url(self) -> Coroutine[URLResponse, Any, Any]:
def fetch_download_url(self) -> Coroutine[URLResponse, Any, Any]:
        """Fetch a work order token and retrieve the download URL.
Returns a URLResponse containing two elements:
1. the download url
2. the file size in bytes
"""

@abstractmethod
def get_download_url(self) -> Coroutine[URLResponse, Any, Any]:
"""Fetch a presigned URL from which file data can be downloaded."""

@abstractmethod
def get_file_header_envelope(self) -> Coroutine[bytes, Any, Any]:
"""
Expand All @@ -54,7 +50,7 @@ def get_file_header_envelope(self) -> Coroutine[bytes, Any, Any]:
"""

@abstractmethod
async def download_to_queue(self, *, part_range: PartRange) -> None:
async def download_to_queue(self, *, url: str, part_range: PartRange) -> None:
"""
Start downloading file parts in parallel into a queue.
This should be wrapped into asyncio.task and is guarded by a semaphore to limit
Expand All @@ -65,6 +61,7 @@ async def download_to_queue(self, *, part_range: PartRange) -> None:
async def download_content_range(
self,
*,
url: str,
start: int,
end: int,
) -> None:
Expand Down
63 changes: 24 additions & 39 deletions src/ghga_connector/core/downloading/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#
"""Contains a concrete implementation of the abstract downloader"""

import asyncio
import base64
from asyncio import Queue, Semaphore, Task, create_task
from collections.abc import Coroutine
Expand Down Expand Up @@ -85,10 +84,8 @@ def __init__( # noqa: PLR0913

async def download_file(self, *, output_path: Path, part_size: int):
"""Download file to the specified location and manage lower level details."""
# stage download and get file size
url_response = await self.await_download_url()

# Split the file into parts based on the part size
url_response = await self.fetch_download_url()
part_ranges = calc_part_ranges(
part_size=part_size, total_file_size=url_response.file_size
)
Expand All @@ -97,7 +94,11 @@ async def download_file(self, *, output_path: Path, part_size: int):

# start async part download to intermediate queue
for part_range in part_ranges:
await task_handler.schedule(self.download_to_queue(part_range=part_range))
await task_handler.schedule(
self.download_to_queue(
url=url_response.download_url, part_range=part_range
)
)

# get file header envelope
try:
Expand Down Expand Up @@ -125,17 +126,16 @@ async def download_file(self, *, output_path: Path, part_size: int):
)
await write_to_file

async def await_download_url(self) -> URLResponse:
"""Wait until download URL can be generated.
async def fetch_download_url(self) -> URLResponse:
"""Fetch a work order token and retrieve the download url.
Returns a URLResponse containing two elements:
1. the download url
2. the file size in bytes
"""
# get the download_url, wait if needed

try:
self._message_display.display(
f"Fetching file authorization for {self._file_id}"
f"Fetching work order token for {self._file_id}"
)
url_and_headers = await get_file_authorization(
file_id=self._file_id,
Expand All @@ -147,31 +147,20 @@ async def await_download_url(self) -> URLResponse:
)
except exceptions.BadResponseCodeError as error:
self._message_display.failure(
"The request was invalid and returned a bad HTTP status code."
f"The request for file {self._file_id} returned an unexpected HTTP status code: {error.response_code}."
)
raise error
except exceptions.RequestFailedError as error:
self._message_display.failure("The request failed.")
self._message_display.failure(
f"The download request for file {self._file_id} failed."
)
raise error

return response # type: ignore

async def get_download_url(self) -> URLResponse:
"""Fetch a presigned URL from which file data can be downloaded."""
self._message_display.display(
f"Fetching file authorization for {self._file_id}"
)
url_and_headers = await get_file_authorization(
file_id=self._file_id, work_package_accessor=self._work_package_accessor
)
self._message_display.display(f"Fetching download URL for {self._file_id}")
url_response = await get_download_url(
client=self._client, url_and_headers=url_and_headers
)
if isinstance(url_response, RetryResponse):
if isinstance(response, RetryResponse):
# File should be staged at that point in time
raise exceptions.UnexpectedRetryResponseError()
return url_response

return response

async def get_file_header_envelope(self) -> bytes:
"""
Expand Down Expand Up @@ -219,7 +208,7 @@ async def get_file_header_envelope(self) -> bytes:
ResponseExceptionTranslator(spec=spec).handle(response=response)
raise exceptions.BadResponseCodeError(url=url, response_code=status_code)

async def download_to_queue(self, *, part_range: PartRange) -> None:
async def download_to_queue(self, *, url: str, part_range: PartRange) -> None:
"""
Start downloading file parts in parallel into a queue.
This should be wrapped into asyncio.task and is guarded by a semaphore to limit
Expand All @@ -229,34 +218,33 @@ async def download_to_queue(self, *, part_range: PartRange) -> None:
async with self._semaphore:
try:
await self.download_content_range(
start=part_range.start, end=part_range.stop
url=url, start=part_range.start, end=part_range.stop
)
except BaseException as exception:
await self._queue.put(exception)

async def download_content_range(
self,
*,
url: str,
start: int,
end: int,
) -> None:
"""Download a specific range of a file's content using a presigned download url."""
headers = httpx.Headers({"Range": f"bytes={start}-{end}"})

url_response = await self.get_download_url()
download_url = url_response.download_url
try:
response: httpx.Response = await self._retry_handler(
fn=self._client.get, url=download_url, headers=headers
fn=self._client.get, url=url, headers=headers
)
except RetryError as retry_error:
wrapped_exception = retry_error.last_attempt.exception()

if isinstance(wrapped_exception, httpx.RequestError):
exceptions.raise_if_connection_failed(
request_error=wrapped_exception, url=download_url
request_error=wrapped_exception, url=url
)
raise exceptions.RequestFailedError(url=download_url) from retry_error
raise exceptions.RequestFailedError(url=url) from retry_error
elif wrapped_exception:
raise wrapped_exception from retry_error
elif result := retry_error.last_attempt.result():
Expand All @@ -271,9 +259,7 @@ async def download_content_range(
await self._queue.put((start, response.content))
return

raise exceptions.BadResponseCodeError(
url=download_url, response_code=status_code
)
raise exceptions.BadResponseCodeError(url=url, response_code=status_code)

async def drain_queue_to_file(
self, *, file_name: str, file: BufferedWriter, file_size: int, offset: int
Expand All @@ -299,4 +285,3 @@ async def drain_queue_to_file(
downloaded_size += chunk_size
self._queue.task_done()
progress.advance(chunk_size)
await asyncio.sleep(0)
2 changes: 2 additions & 0 deletions src/ghga_connector/core/downloading/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def __exit__(self, exc_type, exc_value, traceback):
if exc_type:
self._progress.remove_task(self._task_id)
self._progress.stop()
            # add a newline so next output is always printed on a separate line
print()

def advance(self, size: int):
"""Advance progress bar by specified amount of bytes and display."""
Expand Down
16 changes: 11 additions & 5 deletions src/ghga_connector/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ async def download_files( # noqa: PLR0913
work_package_accessor: WorkPackageAccessor,
file_id: str,
file_extension: str = "",
overwrite: bool = False,
) -> None:
"""Core command to download a file. Can be called by CLI, GUI, etc."""
if not is_service_healthy(api_url):
Expand All @@ -117,8 +118,16 @@ async def download_files( # noqa: PLR0913

# check output file
output_file = output_dir / f"{file_name}.c4gh"
# if output_file.exists():
# raise exceptions.FileAlreadyExistsError(output_file=str(output_file))
if output_file.exists():
if overwrite:
message_display.display(
f"A file with name '{output_file}' already exists and will be overwritten."
)
else:
message_display.failure(
f"A file with name '{output_file}' already exists. Skipping."
)
return

# with_suffix() might overwrite existing suffixes, do this instead
output_file_ongoing = output_file.parent / (output_file.name + ".part")
Expand Down Expand Up @@ -147,9 +156,6 @@ async def download_files( # noqa: PLR0913
raise error

# rename fully downloaded file
# TODO: don't error here for now
# if output_file.exists():
# raise exceptions.RenameDownloadedFileError(file_path=output_file)
output_file_ongoing.rename(output_file)

message_display.success(
Expand Down

0 comments on commit ffd2c20

Please sign in to comment.