From 81c8b24e691bed2c9b4ef6a110e68fb5f1a84a9a Mon Sep 17 00:00:00 2001 From: yyoung Date: Fri, 5 Apr 2024 15:41:50 +0800 Subject: [PATCH] feat: verify downloaded image size is expected --- nazurin/models/file.py | 10 +++++++--- nazurin/models/image.py | 34 +++++++++++++++++++++------------- nazurin/utils/network.py | 18 +++++++++--------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/nazurin/models/file.py b/nazurin/models/file.py index 1134065e..1e5658ee 100644 --- a/nazurin/models/file.py +++ b/nazurin/models/file.py @@ -1,6 +1,7 @@ import os import pathlib from dataclasses import dataclass +from typing import Optional import aiofiles import aiofiles.os @@ -45,7 +46,7 @@ def destination(self) -> pathlib.Path: def destination(self, value: str): self._destination = sanitize_path(value) - async def size(self): + async def size(self) -> Optional[int]: """ Get file size in bytes """ @@ -53,6 +54,7 @@ async def size(self): if os.path.exists(self.path): stat = await aiofiles.os.stat(self.path) return stat.st_size + return None async def exists(self) -> bool: if ( @@ -63,11 +65,13 @@ async def exists(self) -> bool: return False @network_retry - async def download(self, session: NazurinRequestSession): + async def download(self, session: NazurinRequestSession) -> Optional[int]: if await self.exists(): logger.info("File {} already exists", self.path) return True await ensure_existence_async(TEMP_DIR) logger.info("Downloading {} to {}...", self.url, self.path) await session.download(self.url, self.path) - logger.info("Downloaded to {}", self.path) + size = await self.size() + logger.info("Downloaded to {}, size = {}", self.path, size) + return size diff --git a/nazurin/models/image.py b/nazurin/models/image.py index 1d3340a7..85c84ef3 100644 --- a/nazurin/models/image.py +++ b/nazurin/models/image.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from typing import Optional import aiohttp from humanize import naturalsize @@ -14,7 +15,7 @@ @dataclass class Image(File): thumbnail: str = None - _size: int = None + _size: Optional[int] = None """ File size in bytes """ @@ -88,20 +89,27 @@ def set_size(self, value: int): async def download(self, session: aiohttp.ClientSession): RETRIES = 3 for i in range(RETRIES): - await super().download(session) + downloaded_size = await super().download(session) is_valid = await check_image(self.path) + attempt_count = f"{i + 1} / {RETRIES}" if is_valid: - break - logger.warning( - "Downloaded image {} is not valid, retry {} / {}", - self.path, - i + 1, - RETRIES, - ) + if self._size is None or self._size == downloaded_size: + return + logger.warning( + "Downloaded file size {} does not match image size {}, attempt {}", + downloaded_size, + self._size, + attempt_count, + ) + else: + logger.warning( + "Downloaded image {} is not valid, attempt {}", + self.path, + attempt_count, + ) if i < RETRIES - 1: # Keep the last one for debugging os.remove(self.path) - if not is_valid: - raise NazurinError( - "Download failed with invalid image, please check logs for details" - ) + raise NazurinError( + "Download failed with invalid image, please check logs for details" + ) diff --git a/nazurin/utils/network.py b/nazurin/utils/network.py index e42147bb..6e7d8d1e 100644 --- a/nazurin/utils/network.py +++ b/nazurin/utils/network.py @@ -1,7 +1,7 @@ import abc import os from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import AsyncContextManager, Generator, Optional, Union +from typing import AsyncContextManager, AsyncGenerator, Optional, Union import aiofiles import cloudscraper @@ -21,7 +21,7 @@ def __init__( cookies: Optional[dict] = None, headers: Optional[dict] = None, timeout: int = TIMEOUT, - **kwargs + **kwargs, ): raise NotImplementedError @@ -45,7 +45,7 @@ def __init__( cookies: Optional[dict] = None, headers: Optional[dict] = None, timeout: int = TIMEOUT, - **kwargs + **kwargs, ): headers = headers or {} headers.update({"User-Agent": UA}) @@ -60,7 +60,7 @@ def __init__( headers=headers, trust_env=True, timeout=timeout, - **kwargs + **kwargs, ) async def download(self, url: str, destination: Union[str, os.PathLike]): @@ -85,7 +85,7 @@ def __init__( cookies: Optional[dict] = None, headers: Optional[dict] = None, timeout: int = TIMEOUT, - **kwargs + **kwargs, ): self.cookies = cookies self.headers = headers @@ -96,7 +96,7 @@ def __init__( @asynccontextmanager async def get( self, *args, impersonate: str = "chrome110", **kwargs - ) -> Generator[CurlResponse, None, None]: + ) -> AsyncGenerator[CurlResponse, None]: yield await super().request( "GET", *args, @@ -105,7 +105,7 @@ async def get( timeout=self.timeout, impersonate=impersonate, proxies=self.proxies, - **kwargs + **kwargs, ) async def download(self, url: str, destination: Union[str, os.PathLike]): @@ -132,7 +132,7 @@ def __init__( cookies: Optional[dict] = None, headers: Optional[dict] = None, timeout: int = TIMEOUT, - **kwargs + **kwargs, ): proxies = {"https": PROXY, "http": PROXY} if PROXY else {} session = Session() @@ -145,7 +145,7 @@ def __init__( @asynccontextmanager async def get( self, *args, **kwargs - ) -> Generator[cloudscraper.requests.Response, None, None]: + ) -> AsyncGenerator[cloudscraper.requests.Response, None]: yield await async_wrap(self.scraper.get)(*args, timeout=self.timeout, **kwargs) async def download(self, url: str, destination: Union[str, os.PathLike]):