Skip to content

Commit

Permalink
feat: verify downloaded image size is expected
Browse files Browse the repository at this point in the history
  • Loading branch information
y-young committed Apr 5, 2024
1 parent d18804f commit 9ccc9d3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 25 deletions.
10 changes: 7 additions & 3 deletions nazurin/models/file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pathlib
from dataclasses import dataclass
from typing import Optional

import aiofiles
import aiofiles.os
Expand Down Expand Up @@ -45,14 +46,15 @@ def destination(self) -> pathlib.Path:
def destination(self, value: str):
self._destination = sanitize_path(value)

async def size(self):
async def size(self) -> Optional[int]:
"""
Get file size in bytes
"""

if os.path.exists(self.path):
stat = await aiofiles.os.stat(self.path)
return stat.st_size
return None

async def exists(self) -> bool:
if (
Expand All @@ -63,11 +65,13 @@ async def exists(self) -> bool:
return False

@network_retry
async def download(self, session: NazurinRequestSession):
async def download(self, session: NazurinRequestSession) -> Optional[int]:
if await self.exists():
logger.info("File {} already exists", self.path)
return True
await ensure_existence_async(TEMP_DIR)
logger.info("Downloading {} to {}...", self.url, self.path)
await session.download(self.url, self.path)
logger.info("Downloaded to {}", self.path)
size = await self.size()
logger.info("Downloaded to {}, size = {}", self.path, size)
return size
34 changes: 21 additions & 13 deletions nazurin/models/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from dataclasses import dataclass
from typing import Optional

import aiohttp
from humanize import naturalsize
Expand All @@ -14,7 +15,7 @@
@dataclass
class Image(File):
thumbnail: str = None
_size: int = None
_size: Optional[int] = None
"""
File size in bytes
"""
Expand Down Expand Up @@ -88,20 +89,27 @@ def set_size(self, value: int):
async def download(self, session: aiohttp.ClientSession):
RETRIES = 3
for i in range(RETRIES):
await super().download(session)
downloaded_size = await super().download(session)
is_valid = await check_image(self.path)
attempt_count = f"{i + 1} / {RETRIES}"
if is_valid:
break
logger.warning(
"Downloaded image {} is not valid, retry {} / {}",
self.path,
i + 1,
RETRIES,
)
if self._size is None or self._size == downloaded_size:
return
logger.warning(
"Downloaded file size {} does not match image size {}, attempt {}",
downloaded_size,
self._size,
attempt_count,
)
else:
logger.warning(
"Downloaded image {} is not valid, attempt {}",
self.path,
attempt_count,
)
if i < RETRIES - 1:
# Keep the last one for debugging
os.remove(self.path)
if not is_valid:
raise NazurinError(
"Download failed with invalid image, please check logs for details"
)
raise NazurinError(
"Download failed with invalid image, please check logs for details"
)
18 changes: 9 additions & 9 deletions nazurin/utils/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import os
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import AsyncContextManager, Generator, Optional, Union
from typing import AsyncContextManager, AsyncGenerator, Optional, Union

import aiofiles
import cloudscraper
Expand All @@ -21,7 +21,7 @@ def __init__(
cookies: Optional[dict] = None,
headers: Optional[dict] = None,
timeout: int = TIMEOUT,
**kwargs
**kwargs,
):
raise NotImplementedError

Expand All @@ -45,7 +45,7 @@ def __init__(
cookies: Optional[dict] = None,
headers: Optional[dict] = None,
timeout: int = TIMEOUT,
**kwargs
**kwargs,
):
headers = headers or {}
headers.update({"User-Agent": UA})
Expand All @@ -60,7 +60,7 @@ def __init__(
headers=headers,
trust_env=True,
timeout=timeout,
**kwargs
**kwargs,
)

async def download(self, url: str, destination: Union[str, os.PathLike]):
Expand All @@ -85,7 +85,7 @@ def __init__(
cookies: Optional[dict] = None,
headers: Optional[dict] = None,
timeout: int = TIMEOUT,
**kwargs
**kwargs,
):
self.cookies = cookies
self.headers = headers
Expand All @@ -96,7 +96,7 @@ def __init__(
@asynccontextmanager
async def get(
self, *args, impersonate: str = "chrome110", **kwargs
) -> Generator[CurlResponse, None, None]:
) -> AsyncGenerator[CurlResponse, None]:
yield await super().request(
"GET",
*args,
Expand All @@ -105,7 +105,7 @@ async def get(
timeout=self.timeout,
impersonate=impersonate,
proxies=self.proxies,
**kwargs
**kwargs,
)

async def download(self, url: str, destination: Union[str, os.PathLike]):
Expand All @@ -132,7 +132,7 @@ def __init__(
cookies: Optional[dict] = None,
headers: Optional[dict] = None,
timeout: int = TIMEOUT,
**kwargs
**kwargs,
):
proxies = {"https": PROXY, "http": PROXY} if PROXY else {}
session = Session()
Expand All @@ -145,7 +145,7 @@ def __init__(
@asynccontextmanager
async def get(
self, *args, **kwargs
) -> Generator[cloudscraper.requests.Response, None, None]:
) -> AsyncGenerator[cloudscraper.requests.Response, None]:
yield await async_wrap(self.scraper.get)(*args, timeout=self.timeout, **kwargs)

async def download(self, url: str, destination: Union[str, os.PathLike]):
Expand Down

0 comments on commit 9ccc9d3

Please sign in to comment.