Skip to content

Commit

Permalink
[Units] Add and use aiohttp_client.ensure_session
Browse files Browse the repository at this point in the history
  • Loading branch information
Harmon758 committed Jul 26, 2023
1 parent 0f2ae2c commit e66b38c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 39 deletions.
24 changes: 24 additions & 0 deletions units/aiohttp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

from __future__ import annotations

from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

from aiohttp import ClientSession

if TYPE_CHECKING:
from collections.abc import Iterator


@asynccontextmanager
async def ensure_session(
session: ClientSession | None
) -> Iterator[ClientSession]:
if session_not_passed := (session is None):
session = ClientSession()
try:
yield session
finally:
if session_not_passed:
await session.close()

22 changes: 9 additions & 13 deletions units/location.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@

from __future__ import annotations

import datetime
import os
from typing import TYPE_CHECKING

from .aiohttp_client import ensure_session

import aiohttp
if TYPE_CHECKING:
import aiohttp


async def get_geocode_data(
location: str, *, aiohttp_session: aiohttp.ClientSession | None = None
) -> dict:
# TODO: Add reverse option
if aiohttp_session_not_passed := (aiohttp_session is None):
aiohttp_session = aiohttp.ClientSession()
try:
async with ensure_session(aiohttp_session) as aiohttp_session:
async with aiohttp_session.get(
"https://maps.googleapis.com/maps/api/geocode/json",
params = {"address": location, "key": os.getenv("GOOGLE_API_KEY")}
Expand All @@ -28,9 +32,6 @@ async def get_geocode_data(
raise RuntimeError(f"Error: {error_message}")

return geocode_data["results"][0]
finally:
if aiohttp_session_not_passed:
await aiohttp_session.close()


async def get_timezone_data(
Expand All @@ -40,9 +41,7 @@ async def get_timezone_data(
longitude: float | int | str | None = None,
aiohttp_session: aiohttp.ClientSession | None = None
) -> dict:
if aiohttp_session_not_passed := (aiohttp_session is None):
aiohttp_session = aiohttp.ClientSession()
try:
async with ensure_session(aiohttp_session) as aiohttp_session:
if latitude is None and longitude is None:
if not location:
raise TypeError("location or latitude and longitude required")
Expand Down Expand Up @@ -73,9 +72,6 @@ async def get_timezone_data(
raise RuntimeError(f"Error: {error_message}")

return timezone_data
finally:
if aiohttp_session_not_passed:
await aiohttp_session.close()


DEGREES_RANGES_TO_DIRECTIONS = {
Expand Down
30 changes: 11 additions & 19 deletions units/runescape.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@

import aiohttp
from __future__ import annotations

from typing import TYPE_CHECKING

from .aiohttp_client import ensure_session

if TYPE_CHECKING:
import aiohttp


async def get_item_id(
item: str, *, aiohttp_session: aiohttp.ClientSession | None = None
) -> int:
if aiohttp_session_not_passed := (aiohttp_session is None):
aiohttp_session = aiohttp.ClientSession()
try:
async with ensure_session(aiohttp_session) as aiohttp_session:
# https://runescape.wiki/w/Application_programming_interface#Grand_Exchange_Database_API
# https://www.mediawiki.org/wiki/API:Opensearch
# TODO: Handle redirects?
Expand Down Expand Up @@ -41,9 +46,6 @@ async def get_item_id(
return item_id[0]

raise ValueError(f"{item} is not an item")
finally:
if aiohttp_session_not_passed:
await aiohttp_session.close()


async def get_ge_data(
Expand All @@ -52,9 +54,7 @@ async def get_ge_data(
item_id: int | str | None = None,
aiohttp_session: aiohttp.ClientSession | None = None
) -> dict:
if aiohttp_session_not_passed := (aiohttp_session is None):
aiohttp_session = aiohttp.ClientSession()
try:
async with ensure_session(aiohttp_session) as aiohttp_session:
if item_id is None:
item_id = await get_item_id(item, aiohttp_session = aiohttp_session)
async with aiohttp_session.get(
Expand All @@ -65,17 +65,12 @@ async def get_ge_data(
raise ValueError(f"{item} not found on the Grand Exchange")
data = await resp.json(content_type = "text/html")
return data["item"]
finally:
if aiohttp_session_not_passed:
await aiohttp_session.close()


async def get_monster_data(
monster: str, *, aiohttp_session: aiohttp.ClientSession | None = None
) -> dict:
if aiohttp_session_not_passed := (aiohttp_session is None):
aiohttp_session = aiohttp.ClientSession()
try:
async with ensure_session(aiohttp_session) as aiohttp_session:
async with aiohttp_session.get(
"http://services.runescape.com/m=itemdb_rs/bestiary/beastSearch.json",
params = {"term": monster}
Expand All @@ -89,7 +84,4 @@ async def get_monster_data(
) as resp:
data = await resp.json(content_type = "text/html")
return data
finally:
if aiohttp_session_not_passed:
await aiohttp_session.close()

11 changes: 4 additions & 7 deletions units/wikis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import re
from typing import TYPE_CHECKING

import aiohttp
from bs4 import BeautifulSoup
from pydantic import BaseModel

from .aiohttp_client import ensure_session

if TYPE_CHECKING:
import aiohttp
from collections.abc import Iterable


Expand All @@ -34,9 +36,7 @@ async def search_wiki(
) -> WikiArticle:
# TODO: Add User-Agent
# TODO: Use textwrap
if aiohttp_session_not_passed := (aiohttp_session is None):
aiohttp_session = aiohttp.ClientSession()
try:
async with ensure_session(aiohttp_session) as aiohttp_session:
if random:
if not isinstance(random_namespaces, int | str):
random_namespaces = '|'.join(
Expand Down Expand Up @@ -153,7 +153,4 @@ async def search_wiki(
) if thumbnail else None
)
)
finally:
if aiohttp_session_not_passed:
await aiohttp_session.close()

0 comments on commit e66b38c

Please sign in to comment.