From 48a04680b675fd094a518b24336a7ac97c4fd580 Mon Sep 17 00:00:00 2001 From: Alex Dunn Date: Wed, 30 Aug 2023 10:48:24 -0700 Subject: [PATCH] refactor utils --- docker/sweepers_driver.py | 3 +- src/pds/registrysweepers/ancestry/__init__.py | 6 +- .../registrysweepers/ancestry/generation.py | 4 +- src/pds/registrysweepers/ancestry/queries.py | 4 +- src/pds/registrysweepers/provenance.py | 8 +- .../registrysweepers/repairkit/__init__.py | 6 +- src/pds/registrysweepers/utils/__init__.py | 283 +----------------- src/pds/registrysweepers/utils/db/__init__.py | 252 ++++++++++++++++ src/pds/registrysweepers/utils/db/host.py | 3 + src/pds/registrysweepers/utils/misc.py | 33 ++ tests/pds/registrysweepers/test_ancestry.py | 2 +- 11 files changed, 307 insertions(+), 297 deletions(-) create mode 100644 src/pds/registrysweepers/utils/db/host.py create mode 100644 src/pds/registrysweepers/utils/misc.py diff --git a/docker/sweepers_driver.py b/docker/sweepers_driver.py index 16e3e47..b96c4f1 100755 --- a/docker/sweepers_driver.py +++ b/docker/sweepers_driver.py @@ -63,7 +63,8 @@ from typing import Callable from pds.registrysweepers import provenance, ancestry, repairkit -from pds.registrysweepers.utils import configure_logging, get_human_readable_elapsed_since, parse_log_level +from pds.registrysweepers.utils import configure_logging, parse_log_level +from pds.registrysweepers.utils.misc import get_human_readable_elapsed_since configure_logging(filepath=None, log_level=logging.INFO) log = logging.getLogger(__name__) diff --git a/src/pds/registrysweepers/ancestry/__init__.py b/src/pds/registrysweepers/ancestry/__init__.py index 983cdc7..c7e7090 100644 --- a/src/pds/registrysweepers/ancestry/__init__.py +++ b/src/pds/registrysweepers/ancestry/__init__.py @@ -14,10 +14,10 @@ from pds.registrysweepers.ancestry.generation import get_collection_ancestry_records from pds.registrysweepers.ancestry.generation import get_nonaggregate_ancestry_records from pds.registrysweepers.utils import configure_logging -from pds.registrysweepers.utils import Host from pds.registrysweepers.utils import parse_args -from pds.registrysweepers.utils import Update -from pds.registrysweepers.utils import write_updated_docs +from pds.registrysweepers.utils.db import write_updated_docs +from pds.registrysweepers.utils.db.host import Host +from pds.registrysweepers.utils.db.update import Update log = logging.getLogger(__name__) diff --git a/src/pds/registrysweepers/ancestry/generation.py b/src/pds/registrysweepers/ancestry/generation.py index 346736c..0194712 100644 --- a/src/pds/registrysweepers/ancestry/generation.py +++ b/src/pds/registrysweepers/ancestry/generation.py @@ -11,8 +11,8 @@ from pds.registrysweepers.ancestry.queries import get_collection_ancestry_records_bundles_query from pds.registrysweepers.ancestry.queries import get_collection_ancestry_records_collections_query from pds.registrysweepers.ancestry.queries import get_nonaggregate_ancestry_records_query -from pds.registrysweepers.utils import coerce_list_type -from pds.registrysweepers.utils import Host +from pds.registrysweepers.utils.db.host import Host +from pds.registrysweepers.utils.misc import coerce_list_type from pds.registrysweepers.utils.productidentifiers.factory import PdsProductIdentifierFactory from pds.registrysweepers.utils.productidentifiers.pdslid import PdsLid from pds.registrysweepers.utils.productidentifiers.pdslidvid import PdsLidVid diff --git a/src/pds/registrysweepers/ancestry/queries.py b/src/pds/registrysweepers/ancestry/queries.py index 
3537ec6..30e08db 100644 --- a/src/pds/registrysweepers/ancestry/queries.py +++ b/src/pds/registrysweepers/ancestry/queries.py @@ -6,8 +6,8 @@ from typing import Iterable from typing import Optional -from pds.registrysweepers.utils import Host -from pds.registrysweepers.utils import query_registry_db_or_mock +from pds.registrysweepers.utils.db import query_registry_db_or_mock +from pds.registrysweepers.utils.db.host import Host log = logging.getLogger(__name__) diff --git a/src/pds/registrysweepers/provenance.py b/src/pds/registrysweepers/provenance.py index ba029fc..9468c1f 100755 --- a/src/pds/registrysweepers/provenance.py +++ b/src/pds/registrysweepers/provenance.py @@ -48,11 +48,11 @@ from pds.registrysweepers.utils import _vid_as_tuple_of_int from pds.registrysweepers.utils import configure_logging -from pds.registrysweepers.utils import get_extant_lidvids -from pds.registrysweepers.utils import Host from pds.registrysweepers.utils import parse_args -from pds.registrysweepers.utils import Update -from pds.registrysweepers.utils import write_updated_docs +from pds.registrysweepers.utils.db import get_extant_lidvids +from pds.registrysweepers.utils.db import write_updated_docs +from pds.registrysweepers.utils.db.host import Host +from pds.registrysweepers.utils.db.update import Update log = logging.getLogger(__name__) diff --git a/src/pds/registrysweepers/repairkit/__init__.py b/src/pds/registrysweepers/repairkit/__init__.py index e8059ec..28e3cce 100644 --- a/src/pds/registrysweepers/repairkit/__init__.py +++ b/src/pds/registrysweepers/repairkit/__init__.py @@ -12,12 +12,12 @@ from typing import Union from pds.registrysweepers.utils import configure_logging -from pds.registrysweepers.utils import Host from pds.registrysweepers.utils import query_registry_db -from pds.registrysweepers.utils import Update -from pds.registrysweepers.utils import write_updated_docs +from pds.registrysweepers.utils.db.host import Host +from pds.registrysweepers.utils.db.update import Update from . 
import allarrays +from ..utils.db import write_updated_docs """ dictionary repair tools is {field_name:[funcs]} where field_name can be: diff --git a/src/pds/registrysweepers/utils/__init__.py b/src/pds/registrysweepers/utils/__init__.py index 1e00cd0..17f0b5f 100644 --- a/src/pds/registrysweepers/utils/__init__.py +++ b/src/pds/registrysweepers/utils/__init__.py @@ -1,29 +1,11 @@ import argparse -import collections -import functools -import json import logging -import random -import sys -import urllib.parse from argparse import Namespace -from datetime import datetime -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable from typing import List -from typing import Mapping -from typing import Optional from typing import Union -import requests -from pds.registrysweepers.utils.db.update import Update -from requests.exceptions import HTTPError -from retry import retry -from retry.api import retry_call - -Host = collections.namedtuple("Host", ["password", "url", "username", "verify"]) +from pds.registrysweepers.utils.db import query_registry_db +from pds.registrysweepers.utils.db.host import Host log = logging.getLogger(__name__) @@ -90,264 +72,3 @@ def configure_logging(filepath: Union[str, None], log_level: int): handlers.append(logging.FileHandler(filepath)) logging.basicConfig(level=log_level, format="%(asctime)s::%(name)s::%(levelname)s::%(message)s", handlers=handlers) - - -def query_registry_db( - host: Host, - query: Dict, - _source: Dict, - index_name: str = "registry", - page_size: int = 10000, - scroll_keepalive_minutes: int = 10, -) -> Iterable[Dict]: - """ - Given an OpenSearch host and query/_source, return an iterable collection of hits - - Example query: {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}} - Example _source: {"includes": ["lidvid"]} - """ - - req_content = { - "query": query, - "_source": _source, - "size": page_size, - } - - query_id = get_random_hex_id() # This is just used to differentiate queries during logging - log.info(f"Initiating query with id {query_id}: {req_content}") - - path = f"{index_name}/_search?scroll={scroll_keepalive_minutes}m" - - served_hits = 0 - - last_info_log_at_percentage = 0 - log.info(f"Query {query_id} progress: 0%") - - more_data_exists = True - while more_data_exists: - resp = retry_call( - requests.get, - fargs=[urllib.parse.urljoin(host.url, path)], - fkwargs={"auth": (host.username, host.password), "verify": host.verify, "json": req_content}, - tries=6, - delay=2, - backoff=2, - logger=log, - ) - resp.raise_for_status() - - data = resp.json() - path = "_search/scroll" - req_content = {"scroll": f"{scroll_keepalive_minutes}m", "scroll_id": data["_scroll_id"]} - - total_hits = data["hits"]["total"]["value"] - log.debug( - f" paging query {query_id} ({served_hits} to {min(served_hits + page_size, total_hits)} of {total_hits})" - ) - - response_hits = data["hits"]["hits"] - for hit in response_hits: - served_hits += 1 - - percentage_of_hits_served = int(served_hits / total_hits * 100) - if last_info_log_at_percentage is None or percentage_of_hits_served >= (last_info_log_at_percentage + 5): - last_info_log_at_percentage = percentage_of_hits_served - log.info(f"Query {query_id} progress: {percentage_of_hits_served}%") - - yield hit - - # This is a temporary, ad-hoc guard against empty/erroneous responses which do not return non-200 status codes. 
- # Previously, this has cause infinite loops in production due to served_hits sticking and never reaching the - # expected total hits value. - # TODO: Remove this upon implementation of https://github.com/NASA-PDS/registry-sweepers/issues/42 - hits_data_present_in_response = len(response_hits) > 0 - if not hits_data_present_in_response: - log.error( - f"Response for query {query_id} contained no hits when hits were expected. Returned data is incomplete. Response was: {data}" - ) - break - - more_data_exists = served_hits < data["hits"]["total"]["value"] - - # TODO: Determine if the following block is actually necessary - if "scroll_id" in req_content: - path = f'_search/scroll/{req_content["scroll_id"]}' - retry_call( - requests.delete, - fargs=[urllib.parse.urljoin(host.url, path)], - fkwargs={"auth": (host.username, host.password), "verify": host.verify}, - tries=6, - delay=2, - backoff=2, - logger=log, - ) - - log.info(f"Query {query_id} complete!") - - -def query_registry_db_or_mock(mock_f: Optional[Callable[[str], Iterable[Dict]]], mock_query_id: str): - if mock_f is not None: - - def mock_wrapper( - host: Host, - query: Dict, - _source: Dict, - index_name: str = "registry", - page_size: int = 10000, - scroll_validity_duration_minutes: int = 10, - ) -> Iterable[Dict]: - return mock_f(mock_query_id) # type: ignore # see None-check above - - return mock_wrapper - else: - return query_registry_db - - -def get_extant_lidvids(host: Host) -> Iterable[str]: - """ - Given an OpenSearch host, return all extant LIDVIDs - """ - - log.info("Retrieving extant LIDVIDs") - - query = {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}} - _source = {"includes": ["lidvid"]} - - results = query_registry_db(host, query, _source, scroll_keepalive_minutes=1) - - return map(lambda doc: doc["_source"]["lidvid"], results) - - -def write_updated_docs(host: Host, updates: Iterable[Update], index_name: str = "registry"): - log.info("Updating a lazily-generated collection of product documents...") - updated_doc_count = 0 - - bulk_buffer_max_size_mb = 30.0 - bulk_buffer_size_mb = 0.0 - bulk_updates_buffer: List[str] = [] - for update in updates: - if bulk_buffer_size_mb > bulk_buffer_max_size_mb: - pending_product_count = int(len(bulk_updates_buffer) / 2) - log.info( - f"Bulk update buffer has reached {bulk_buffer_max_size_mb}MB threshold - writing {pending_product_count} document updates to db..." 
- ) - _write_bulk_updates_chunk(host, index_name, bulk_updates_buffer) - bulk_updates_buffer = [] - bulk_buffer_size_mb = 0.0 - - update_statement_strs = update_as_statements(update) - - for s in update_statement_strs: - bulk_buffer_size_mb += sys.getsizeof(s) / 1024**2 - - bulk_updates_buffer.extend(update_statement_strs) - updated_doc_count += 1 - - remaining_products_to_write_count = int(len(bulk_updates_buffer) / 2) - updated_doc_count += remaining_products_to_write_count - - log.info(f"Writing documents updates for {remaining_products_to_write_count} remaining products to db...") - _write_bulk_updates_chunk(host, index_name, bulk_updates_buffer) - - log.info(f"Updated documents for {updated_doc_count} total products!") - - -def update_as_statements(update: Update) -> Iterable[str]: - """Given an Update, convert it to an ElasticSearch-style set of request body content strings""" - update_objs = [{"update": {"_id": update.id}}, {"doc": update.content}] - updates_strs = [json.dumps(obj) for obj in update_objs] - return updates_strs - - -@retry(exceptions=(HTTPError, RuntimeError), tries=6, delay=2, backoff=2, logger=log) -def _write_bulk_updates_chunk(host: Host, index_name: str, bulk_updates: Iterable[str]): - headers = {"Content-Type": "application/x-ndjson"} - path = f"{index_name}/_bulk" - - bulk_data = "\n".join(bulk_updates) + "\n" - - response = requests.put( - urllib.parse.urljoin(host.url, path), - auth=(host.username, host.password), - data=bulk_data, - headers=headers, - verify=host.verify, - ) - - # N.B. HTTP status 200 is insufficient as a success check for _bulk API. - # See: https://github.com/elastic/elasticsearch/issues/41434 - response.raise_for_status() - response_content = response.json() - if response_content.get("errors"): - warn_types = {"document_missing_exception"} # these types represent bad data, not bad sweepers behaviour - items_with_problems = [item for item in response_content["items"] if "error" in item["update"]] - - if log.isEnabledFor(logging.WARNING): - items_with_warnings = [ - item for item in items_with_problems if item["update"]["error"]["type"] in warn_types - ] - warning_aggregates = aggregate_update_error_types(items_with_warnings) - for error_type, reason_aggregate in warning_aggregates.items(): - for error_reason, ids in reason_aggregate.items(): - log.warning( - f"Attempt to update the following documents failed due to {error_type} ({error_reason}): {ids}" - ) - - if log.isEnabledFor(logging.ERROR): - items_with_errors = [ - item for item in items_with_problems if item["update"]["error"]["type"] not in warn_types - ] - error_aggregates = aggregate_update_error_types(items_with_errors) - for error_type, reason_aggregate in error_aggregates.items(): - for error_reason, ids in reason_aggregate.items(): - log.error( - f"Attempt to update the following documents failed unexpectedly due to {error_type} ({error_reason}): {ids}" - ) - - -def aggregate_update_error_types(items: Iterable[Dict]) -> Mapping[str, Dict[str, List[str]]]: - """Return a nested aggregation of ids, aggregated first by error type, then by reason""" - agg: Dict[str, Dict[str, List[str]]] = {} - for item in items: - id = item["update"]["_id"] - error = item["update"]["error"] - error_type = error["type"] - error_reason = error["reason"] - if error_type not in agg: - agg[error_type] = {} - - if error_reason not in agg[error_type]: - agg[error_type][error_reason] = [] - - agg[error_type][error_reason].append(id) - - return agg - - -def coerce_list_type(db_value: Any) -> List[Any]: - 
""" - Coerce a non-array-typed legacy db record into a list containing itself as the only element, or return the - original argument if it is already an array (list). This is sometimes necessary to support legacy db records which - did not wrap singleton properties in an enclosing array. - """ - - return ( - db_value - if type(db_value) is list - else [ - db_value, - ] - ) - - -def get_human_readable_elapsed_since(begin: datetime) -> str: - elapsed_seconds = (datetime.now() - begin).total_seconds() - h = int(elapsed_seconds / 3600) - m = int(elapsed_seconds % 3600 / 60) - s = int(elapsed_seconds % 60) - return (f"{h}h" if h else "") + (f"{m}m" if m else "") + f"{s}s" - - -def get_random_hex_id(id_len: int = 6) -> str: - val = random.randint(0, 16**id_len) - return hex(val)[2:] diff --git a/src/pds/registrysweepers/utils/db/__init__.py b/src/pds/registrysweepers/utils/db/__init__.py index e69de29..520606a 100644 --- a/src/pds/registrysweepers/utils/db/__init__.py +++ b/src/pds/registrysweepers/utils/db/__init__.py @@ -0,0 +1,252 @@ +import json +import logging +import sys +import urllib.parse +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import Optional + +import requests +from pds.registrysweepers.utils.db.host import Host +from pds.registrysweepers.utils.db.update import Update +from pds.registrysweepers.utils.misc import get_random_hex_id +from requests import HTTPError +from retry import retry +from retry.api import retry_call + +log = logging.getLogger(__name__) + + +def query_registry_db( + host: Host, + query: Dict, + _source: Dict, + index_name: str = "registry", + page_size: int = 10000, + scroll_keepalive_minutes: int = 10, +) -> Iterable[Dict]: + """ + Given an OpenSearch host and query/_source, return an iterable collection of hits + + Example query: {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}} + Example _source: {"includes": ["lidvid"]} + """ + + req_content = { + "query": query, + "_source": _source, + "size": page_size, + } + + query_id = get_random_hex_id() # This is just used to differentiate queries during logging + log.info(f"Initiating query with id {query_id}: {req_content}") + + path = f"{index_name}/_search?scroll={scroll_keepalive_minutes}m" + + served_hits = 0 + + last_info_log_at_percentage = 0 + log.info(f"Query {query_id} progress: 0%") + + more_data_exists = True + while more_data_exists: + resp = retry_call( + requests.get, + fargs=[urllib.parse.urljoin(host.url, path)], + fkwargs={"auth": (host.username, host.password), "verify": host.verify, "json": req_content}, + tries=6, + delay=2, + backoff=2, + logger=log, + ) + resp.raise_for_status() + + data = resp.json() + path = "_search/scroll" + req_content = {"scroll": f"{scroll_keepalive_minutes}m", "scroll_id": data["_scroll_id"]} + + total_hits = data["hits"]["total"]["value"] + log.debug( + f" paging query {query_id} ({served_hits} to {min(served_hits + page_size, total_hits)} of {total_hits})" + ) + + response_hits = data["hits"]["hits"] + for hit in response_hits: + served_hits += 1 + + percentage_of_hits_served = int(served_hits / total_hits * 100) + if last_info_log_at_percentage is None or percentage_of_hits_served >= (last_info_log_at_percentage + 5): + last_info_log_at_percentage = percentage_of_hits_served + log.info(f"Query {query_id} progress: {percentage_of_hits_served}%") + + yield hit + + # This is a temporary, ad-hoc guard against 
empty/erroneous responses which do not return non-200 status codes. + # Previously, this has cause infinite loops in production due to served_hits sticking and never reaching the + # expected total hits value. + # TODO: Remove this upon implementation of https://github.com/NASA-PDS/registry-sweepers/issues/42 + hits_data_present_in_response = len(response_hits) > 0 + if not hits_data_present_in_response: + log.error( + f"Response for query {query_id} contained no hits when hits were expected. Returned data is incomplete. Response was: {data}" + ) + break + + more_data_exists = served_hits < data["hits"]["total"]["value"] + + # TODO: Determine if the following block is actually necessary + if "scroll_id" in req_content: + path = f'_search/scroll/{req_content["scroll_id"]}' + retry_call( + requests.delete, + fargs=[urllib.parse.urljoin(host.url, path)], + fkwargs={"auth": (host.username, host.password), "verify": host.verify}, + tries=6, + delay=2, + backoff=2, + logger=log, + ) + + log.info(f"Query {query_id} complete!") + + +def query_registry_db_or_mock(mock_f: Optional[Callable[[str], Iterable[Dict]]], mock_query_id: str): + if mock_f is not None: + + def mock_wrapper( + host: Host, + query: Dict, + _source: Dict, + index_name: str = "registry", + page_size: int = 10000, + scroll_validity_duration_minutes: int = 10, + ) -> Iterable[Dict]: + return mock_f(mock_query_id) # type: ignore # see None-check above + + return mock_wrapper + else: + return query_registry_db + + +def write_updated_docs(host: Host, updates: Iterable[Update], index_name: str = "registry"): + log.info("Updating a lazily-generated collection of product documents...") + updated_doc_count = 0 + + bulk_buffer_max_size_mb = 30.0 + bulk_buffer_size_mb = 0.0 + bulk_updates_buffer: List[str] = [] + for update in updates: + if bulk_buffer_size_mb > bulk_buffer_max_size_mb: + pending_product_count = int(len(bulk_updates_buffer) / 2) + log.info( + f"Bulk update buffer has reached {bulk_buffer_max_size_mb}MB threshold - writing {pending_product_count} document updates to db..." 
+ ) + _write_bulk_updates_chunk(host, index_name, bulk_updates_buffer) + bulk_updates_buffer = [] + bulk_buffer_size_mb = 0.0 + + update_statement_strs = update_as_statements(update) + + for s in update_statement_strs: + bulk_buffer_size_mb += sys.getsizeof(s) / 1024**2 + + bulk_updates_buffer.extend(update_statement_strs) + updated_doc_count += 1 + + remaining_products_to_write_count = int(len(bulk_updates_buffer) / 2) + updated_doc_count += remaining_products_to_write_count + + log.info(f"Writing documents updates for {remaining_products_to_write_count} remaining products to db...") + _write_bulk_updates_chunk(host, index_name, bulk_updates_buffer) + + log.info(f"Updated documents for {updated_doc_count} total products!") + + +def update_as_statements(update: Update) -> Iterable[str]: + """Given an Update, convert it to an ElasticSearch-style set of request body content strings""" + update_objs = [{"update": {"_id": update.id}}, {"doc": update.content}] + updates_strs = [json.dumps(obj) for obj in update_objs] + return updates_strs + + +@retry(exceptions=(HTTPError, RuntimeError), tries=6, delay=2, backoff=2, logger=log) +def _write_bulk_updates_chunk(host: Host, index_name: str, bulk_updates: Iterable[str]): + headers = {"Content-Type": "application/x-ndjson"} + path = f"{index_name}/_bulk" + + bulk_data = "\n".join(bulk_updates) + "\n" + + response = requests.put( + urllib.parse.urljoin(host.url, path), + auth=(host.username, host.password), + data=bulk_data, + headers=headers, + verify=host.verify, + ) + + # N.B. HTTP status 200 is insufficient as a success check for _bulk API. + # See: https://github.com/elastic/elasticsearch/issues/41434 + response.raise_for_status() + response_content = response.json() + if response_content.get("errors"): + warn_types = {"document_missing_exception"} # these types represent bad data, not bad sweepers behaviour + items_with_problems = [item for item in response_content["items"] if "error" in item["update"]] + + if log.isEnabledFor(logging.WARNING): + items_with_warnings = [ + item for item in items_with_problems if item["update"]["error"]["type"] in warn_types + ] + warning_aggregates = aggregate_update_error_types(items_with_warnings) + for error_type, reason_aggregate in warning_aggregates.items(): + for error_reason, ids in reason_aggregate.items(): + log.warning( + f"Attempt to update the following documents failed due to {error_type} ({error_reason}): {ids}" + ) + + if log.isEnabledFor(logging.ERROR): + items_with_errors = [ + item for item in items_with_problems if item["update"]["error"]["type"] not in warn_types + ] + error_aggregates = aggregate_update_error_types(items_with_errors) + for error_type, reason_aggregate in error_aggregates.items(): + for error_reason, ids in reason_aggregate.items(): + log.error( + f"Attempt to update the following documents failed unexpectedly due to {error_type} ({error_reason}): {ids}" + ) + + +def aggregate_update_error_types(items: Iterable[Dict]) -> Mapping[str, Dict[str, List[str]]]: + """Return a nested aggregation of ids, aggregated first by error type, then by reason""" + agg: Dict[str, Dict[str, List[str]]] = {} + for item in items: + id = item["update"]["_id"] + error = item["update"]["error"] + error_type = error["type"] + error_reason = error["reason"] + if error_type not in agg: + agg[error_type] = {} + + if error_reason not in agg[error_type]: + agg[error_type][error_reason] = [] + + agg[error_type][error_reason].append(id) + + return agg + + +def get_extant_lidvids(host: Host) -> 
Iterable[str]: + """ + Given an OpenSearch host, return all extant LIDVIDs + """ + + log.info("Retrieving extant LIDVIDs") + + query = {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}} + _source = {"includes": ["lidvid"]} + + results = query_registry_db(host, query, _source, scroll_keepalive_minutes=1) + + return map(lambda doc: doc["_source"]["lidvid"], results) diff --git a/src/pds/registrysweepers/utils/db/host.py b/src/pds/registrysweepers/utils/db/host.py new file mode 100644 index 0000000..5bd25f1 --- /dev/null +++ b/src/pds/registrysweepers/utils/db/host.py @@ -0,0 +1,3 @@ +import collections + +Host = collections.namedtuple("Host", ["password", "url", "username", "verify"]) diff --git a/src/pds/registrysweepers/utils/misc.py b/src/pds/registrysweepers/utils/misc.py new file mode 100644 index 0000000..b2df742 --- /dev/null +++ b/src/pds/registrysweepers/utils/misc.py @@ -0,0 +1,33 @@ +import random +from datetime import datetime +from typing import Any +from typing import List + + +def coerce_list_type(db_value: Any) -> List[Any]: + """ + Coerce a non-array-typed legacy db record into a list containing itself as the only element, or return the + original argument if it is already an array (list). This is sometimes necessary to support legacy db records which + did not wrap singleton properties in an enclosing array. + """ + + return ( + db_value + if type(db_value) is list + else [ + db_value, + ] + ) + + +def get_human_readable_elapsed_since(begin: datetime) -> str: + elapsed_seconds = (datetime.now() - begin).total_seconds() + h = int(elapsed_seconds / 3600) + m = int(elapsed_seconds % 3600 / 60) + s = int(elapsed_seconds % 60) + return (f"{h}h" if h else "") + (f"{m}m" if m else "") + f"{s}s" + + +def get_random_hex_id(id_len: int = 6) -> str: + val = random.randint(0, 16**id_len) + return hex(val)[2:] diff --git a/tests/pds/registrysweepers/test_ancestry.py b/tests/pds/registrysweepers/test_ancestry.py index 377264a..79eba48 100644 --- a/tests/pds/registrysweepers/test_ancestry.py +++ b/tests/pds/registrysweepers/test_ancestry.py @@ -8,7 +8,7 @@ from pds.registrysweepers import ancestry from pds.registrysweepers.ancestry import AncestryRecord from pds.registrysweepers.ancestry import get_collection_ancestry_records -from pds.registrysweepers.utils import Host +from pds.registrysweepers.utils.db.host import Host from pds.registrysweepers.utils.productidentifiers.pdslidvid import PdsLidVid from tests.mocks.registryquerymock import RegistryQueryMock
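
Usage sketch for the relocated query helpers (the endpoint URL and credentials below are placeholders, not values from this patch; the query/_source bodies are the examples given in the query_registry_db docstring):

    from pds.registrysweepers.utils.db import query_registry_db
    from pds.registrysweepers.utils.db.host import Host

    # Host is the namedtuple now defined in utils/db/host.py
    host = Host(password="admin", url="https://localhost:9200/", username="admin", verify=False)

    # Limit hits to archived/certified products and fetch only the lidvid field
    query = {"bool": {"must": [{"terms": {"ops:Tracking_Meta/ops:archive_status": ["archived", "certified"]}}]}}
    _source = {"includes": ["lidvid"]}

    for hit in query_registry_db(host, query, _source, index_name="registry", page_size=500):
        print(hit["_source"]["lidvid"])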
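
Usage sketch for the relocated bulk-write helpers. write_updated_docs consumes a lazy iterable of Update objects, serialises each into an NDJSON action/doc pair via update_as_statements, and flushes the buffer to the _bulk endpoint once it passes roughly 30 MB. Update itself lives in utils/db/update.py, which this patch does not change; the keyword constructor and the document id/field values below are assumptions for illustration only:

    from pds.registrysweepers.utils.db import write_updated_docs
    from pds.registrysweepers.utils.db.host import Host
    from pds.registrysweepers.utils.db.update import Update

    host = Host(password="admin", url="https://localhost:9200/", username="admin", verify=False)

    def generate_updates():
        # Yield one Update per product document; the id and content here are invented,
        # and Update(id=..., content=...) is assumed from how update_as_statements reads it.
        for lidvid in ["urn:nasa:pds:example_bundle::1.0"]:
            yield Update(id=lidvid, content={"some_metadata_field": "some_value"})

    write_updated_docs(host, generate_updates(), index_name="registry")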
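
Usage sketch for query_registry_db_or_mock, which is how the tests (see RegistryQueryMock in test_ancestry.py) swap canned hits in for live OpenSearch calls: when a mock function is supplied, the returned wrapper keeps query_registry_db's signature but ignores its arguments and simply replays mock_f(mock_query_id). The mock function and query id below are illustrative, not taken from the test fixtures:

    from pds.registrysweepers.utils.db import query_registry_db_or_mock

    def fake_registry_query(query_id):
        # Stand-in for the real mock harness: return canned hits for the given query id
        return iter([{"_id": "urn:nasa:pds:example::1.0", "_source": {"lidvid": "urn:nasa:pds:example::1.0"}}])

    query_f = query_registry_db_or_mock(fake_registry_query, "example-query-id")

    # host/query/_source are ignored when a mock is in play, so stub values suffice here
    hits = query_f(None, query={}, _source={})
    assert next(hits)["_source"]["lidvid"] == "urn:nasa:pds:example::1.0"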