diff --git a/units/aiohttp_client.py b/units/aiohttp_client.py new file mode 100644 index 0000000000..465494c3df --- /dev/null +++ b/units/aiohttp_client.py @@ -0,0 +1,24 @@ + +from __future__ import annotations + +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +from aiohttp import ClientSession + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@asynccontextmanager +async def ensure_session( + session: ClientSession | None +) -> Iterator[ClientSession]: + if session_not_passed := (session is None): + session = ClientSession() + try: + yield session + finally: + if session_not_passed: + await session.close() + diff --git a/units/location.py b/units/location.py index f757fedd92..c71cb5b832 100644 --- a/units/location.py +++ b/units/location.py @@ -1,17 +1,21 @@ +from __future__ import annotations + import datetime import os +from typing import TYPE_CHECKING + +from .aiohttp_client import ensure_session -import aiohttp +if TYPE_CHECKING: + import aiohttp async def get_geocode_data( location: str, *, aiohttp_session: aiohttp.ClientSession | None = None ) -> dict: # TODO: Add reverse option - if aiohttp_session_not_passed := (aiohttp_session is None): - aiohttp_session = aiohttp.ClientSession() - try: + async with ensure_session(aiohttp_session) as aiohttp_session: async with aiohttp_session.get( "https://maps.googleapis.com/maps/api/geocode/json", params = {"address": location, "key": os.getenv("GOOGLE_API_KEY")} @@ -28,9 +32,6 @@ async def get_geocode_data( raise RuntimeError(f"Error: {error_message}") return geocode_data["results"][0] - finally: - if aiohttp_session_not_passed: - await aiohttp_session.close() async def get_timezone_data( @@ -40,9 +41,7 @@ async def get_timezone_data( longitude: float | int | str | None = None, aiohttp_session: aiohttp.ClientSession | None = None ) -> dict: - if aiohttp_session_not_passed := (aiohttp_session is None): - aiohttp_session = aiohttp.ClientSession() - try: + async with ensure_session(aiohttp_session) as aiohttp_session: if latitude is None and longitude is None: if not location: raise TypeError("location or latitude and longitude required") @@ -73,9 +72,6 @@ async def get_timezone_data( raise RuntimeError(f"Error: {error_message}") return timezone_data - finally: - if aiohttp_session_not_passed: - await aiohttp_session.close() DEGREES_RANGES_TO_DIRECTIONS = { diff --git a/units/runescape.py b/units/runescape.py index 38ad15c52b..bf151c1dfd 100644 --- a/units/runescape.py +++ b/units/runescape.py @@ -1,13 +1,18 @@ -import aiohttp +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .aiohttp_client import ensure_session + +if TYPE_CHECKING: + import aiohttp async def get_item_id( item: str, *, aiohttp_session: aiohttp.ClientSession | None = None ) -> int: - if aiohttp_session_not_passed := (aiohttp_session is None): - aiohttp_session = aiohttp.ClientSession() - try: + async with ensure_session(aiohttp_session) as aiohttp_session: # https://runescape.wiki/w/Application_programming_interface#Grand_Exchange_Database_API # https://www.mediawiki.org/wiki/API:Opensearch # TODO: Handle redirects? @@ -41,9 +46,6 @@ async def get_item_id( return item_id[0] raise ValueError(f"{item} is not an item") - finally: - if aiohttp_session_not_passed: - await aiohttp_session.close() async def get_ge_data( @@ -52,9 +54,7 @@ async def get_ge_data( item_id: int | str | None = None, aiohttp_session: aiohttp.ClientSession | None = None ) -> dict: - if aiohttp_session_not_passed := (aiohttp_session is None): - aiohttp_session = aiohttp.ClientSession() - try: + async with ensure_session(aiohttp_session) as aiohttp_session: if item_id is None: item_id = await get_item_id(item, aiohttp_session = aiohttp_session) async with aiohttp_session.get( @@ -65,17 +65,12 @@ async def get_ge_data( raise ValueError(f"{item} not found on the Grand Exchange") data = await resp.json(content_type = "text/html") return data["item"] - finally: - if aiohttp_session_not_passed: - await aiohttp_session.close() async def get_monster_data( monster: str, *, aiohttp_session: aiohttp.ClientSession | None = None ) -> dict: - if aiohttp_session_not_passed := (aiohttp_session is None): - aiohttp_session = aiohttp.ClientSession() - try: + async with ensure_session(aiohttp_session) as aiohttp_session: async with aiohttp_session.get( "http://services.runescape.com/m=itemdb_rs/bestiary/beastSearch.json", params = {"term": monster} @@ -89,7 +84,4 @@ async def get_monster_data( ) as resp: data = await resp.json(content_type = "text/html") return data - finally: - if aiohttp_session_not_passed: - await aiohttp_session.close() diff --git a/units/wikis.py b/units/wikis.py index fb1a7be682..7a4fe4c546 100644 --- a/units/wikis.py +++ b/units/wikis.py @@ -4,11 +4,13 @@ import re from typing import TYPE_CHECKING -import aiohttp from bs4 import BeautifulSoup from pydantic import BaseModel +from .aiohttp_client import ensure_session + if TYPE_CHECKING: + import aiohttp from collections.abc import Iterable @@ -34,9 +36,7 @@ async def search_wiki( ) -> WikiArticle: # TODO: Add User-Agent # TODO: Use textwrap - if aiohttp_session_not_passed := (aiohttp_session is None): - aiohttp_session = aiohttp.ClientSession() - try: + async with ensure_session(aiohttp_session) as aiohttp_session: if random: if not isinstance(random_namespaces, int | str): random_namespaces = '|'.join( @@ -153,7 +153,4 @@ async def search_wiki( ) if thumbnail else None ) ) - finally: - if aiohttp_session_not_passed: - await aiohttp_session.close()