Skip to content

Commit

Permalink
Refactor authentication with requests BaseAuth
Browse files Browse the repository at this point in the history
  • Loading branch information
J535D165 committed Oct 19, 2023
1 parent 2fc6914 commit 380afdb
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions pyalex/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import quote_plus

import requests
from requests.auth import AuthBase
from urllib3.util import Retry

try:
Expand All @@ -21,7 +22,7 @@ def __setattr__(self, key, value):

config = AlexConfig(
email=None,
api_key=None,
api_key=None,
openalex_url="https://api.openalex.org",
max_retries=0,
retry_backoff_factor=0.1,
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_requests_session():
'https://',
requests.adapters.HTTPAdapter(max_retries=retries)
)

return requests_session


Expand Down Expand Up @@ -179,6 +180,18 @@ def Venue(*args, **kwargs):
return Source(*args, **kwargs)


class OpenAlexAuth(AuthBase):
"""OpenAlex auth class based on requests auth"""

def __init__(self, api_key):

self.api_key = api_key

def __call__(self, r):
r.params["api_key"] = self.api_key
return r


class CursorPaginator:
def __init__(self, alex_class=None, per_page=None, cursor="*", n_max=None):

Expand Down Expand Up @@ -248,11 +261,10 @@ def __getitem__(self, record_id):
return self._get_multi_items(record_id)

url = self._full_collection_name() + "/" + record_id
params = {"api_key": config.api_key} if config.api_key else {}
res = get_requests_session().get(
res = requests.get(
url,
headers={"User-Agent": "pyalex/" + __version__, "email": config.email},
params=params,
auth=OpenAlexAuth(config.api_key) if config.api_key else None,
)
res.raise_for_status()
res_json = res.json()
Expand Down Expand Up @@ -297,11 +309,10 @@ def get(self, return_meta=False, page=None, per_page=None, cursor=None):
self._add_params("page", page)
self._add_params("cursor", cursor)

params = {"api_key": config.api_key} if config.api_key else {}
res = get_requests_session().get(
res = requests.get(
self.url,
headers={"User-Agent": "pyalex/" + __version__, "email": config.email},
params=params,
auth=OpenAlexAuth(config.api_key) if config.api_key else None,
)

# handle query errors
Expand Down

0 comments on commit 380afdb

Please sign in to comment.