From c912658b4b9fee2ee5ae9476ca8b4327a80d93cf Mon Sep 17 00:00:00 2001 From: Jonathan de Bruin Date: Sat, 22 Jul 2023 14:08:18 +0200 Subject: [PATCH] Implement query --- pydatacite/api.py | 22 +++++++++++++++++-- tests/test_datacite.py | 49 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/pydatacite/api.py b/pydatacite/api.py index 3426e6f..693b52a 100644 --- a/pydatacite/api.py +++ b/pydatacite/api.py @@ -41,7 +41,10 @@ def _flatten_kv(d, prefix=""): # workaround for bug https://groups.google.com/u/1/g/datacite-users/c/t46RWnzZaXc d = str(d).lower() if isinstance(d, bool) else d - return f"{prefix}:{d}" + if prefix: + return f"{prefix}:{d}" + else: + return str(d) def _params_merge(params, add_params): @@ -204,7 +207,8 @@ def url(self): elif isinstance(v, list): v_quote = [quote_plus(q) for q in v] l_params.append(k + "=" + ",".join(v_quote)) - elif k in ["filter", "sort"]: + elif k in ["filter", "query", "sort"]: + print(_flatten_kv(v)) l_params.append(k + "=" + _flatten_kv(v)) else: l_params.append(k + "=" + quote_plus(str(v))) @@ -249,6 +253,7 @@ def get(self, return_meta=False, page=None, per_page=None, cursor=None): self._add_params("page[number]", page) self._add_params("page[cursor]", cursor) + print(self.params) res_json = self._get_raw(self.url) results = [self.resource_class(ent) for ent in res_json["data"]] @@ -309,6 +314,19 @@ def filter(self, **kwargs): return self + def query(self, *args, **kwargs): + + if len(args) > 1: + raise ValueError("Maximal 1 positional argument possible") + + if len(args) == 1: + self._add_params("query", args[0]) + else: + self._add_params("query", kwargs) + + return self + + def sort(self, **kwargs): self._add_params("sort", kwargs) diff --git a/tests/test_datacite.py b/tests/test_datacite.py index 77c65dc..2b49fb3 100644 --- a/tests/test_datacite.py +++ b/tests/test_datacite.py @@ -2,6 +2,7 @@ from pathlib import Path import pytest +import requests from requests import HTTPError from pydatacite import DOI @@ -71,7 +72,35 @@ def test_work_error(): def test_random_dois(): - assert isinstance(DOIs().random(), dict) + assert isinstance(DOIs().random(), DOI) + + +def test_query_single(): + + r = requests.get( + "https://api.datacite.org/dois?query=climate%20change" + ).json() + + n = DOIs().query("climate change").count() + + assert n == r["meta"]["total"] + +def test_query_count(): + + r = requests.get( + "https://api.datacite.org/dois?query=titles.title:(climate+change)" + ).json() + + n = DOIs().query(titles={"title": "(climate+change)"}).count() + + assert n == r["meta"]["total"] + + +def test_query_get(): + + r = DOIs().query(titles={"title": "(climate+change)"}).get() + + assert isinstance(r, list) # def test_multi_dois(): @@ -196,7 +225,7 @@ def test_query_error(): # assert len(results) > 200 -def test_basic_paging(): +def test_manual_paging(): # get the number of records n = DOIs().filter(prefix="10.5438").count() @@ -223,6 +252,22 @@ def test_basic_paging(): assert len(results) == n +def test_number_paging(): + + # get the number of records + n = DOIs().filter(prefix="10.5438").count() + + # example query + pager = DOIs().filter(prefix="10.5438").paginate(method="number", per_page=100) + + n_paging = 0 + for page in pager: + + n_paging += len(page) + + assert n_paging == n + + def test_cursor_paging(): # get the number of records