diff --git a/mpcontribs-client/mpcontribs/client/__init__.py b/mpcontribs-client/mpcontribs/client/__init__.py index 9f3666364..3f16c49ca 100644 --- a/mpcontribs-client/mpcontribs/client/__init__.py +++ b/mpcontribs-client/mpcontribs/client/__init__.py @@ -953,7 +953,11 @@ def _get_future( setattr(future, "track_id", track_id) return future - def available_query_params(self, resource="contributions"): + def available_query_params( + self, + startswith: tuple = None, + resource: str = "contributions" + ) -> list: resources = self.swagger_spec.resources resource_obj = resources.get(resource) if not resource_obj: @@ -962,7 +966,14 @@ def available_query_params(self, resource="contributions"): op_key = f"query{resource.capitalize()}" operation = resource_obj.operations[op_key] - return [param.name for param in operation.params.values()] + params = [param.name for param in operation.params.values()] + if not startswith: + return params + + return [ + param for param in params + if param.startswith(startswith) + ] def get_project(self, name: str = None, fields: list = None) -> Type[Dict]: """Retrieve a project entry @@ -1093,14 +1104,17 @@ def update_project(self, update: dict, name: str = None): raise MPContribsClientError("initialize client with project or set `name` argument!") disallowed = ["is_approved", "stats", "columns", "is_public", "owner"] - for k in disallowed: - if k in update: + for k in list(update.keys()): + if k in disallowed: logger.warning(f"removing `{k}` from update - not allowed.") update.pop(k) if k == "columns": logger.info("use `client.init_columns()` to update project columns.") elif k == "is_public": logger.info("use `client.make_public/private()` to set `is_public`.") + elif not isinstance(update[k], bool) and not update[k]: + logger.warning(f"removing `{k}` from update - no update requested.") + update.pop(k) if not update: logger.warning("nothing to update") @@ -1121,7 +1135,7 @@ def update_project(self, update: dict, name: str = None): payload = { k: v for k, v in update.items() - if v and k in fields and project.get(k, None) != v + if k in fields and project.get(k, None) != v } if not payload: logger.warning("nothing to update") @@ -1469,6 +1483,10 @@ def get_totals( return result["total_count"], result["total_pages"] + def count(self, query: dict = None) -> int: + """shortcut for get_totals()""" + return self.get_totals(query=query)[0] + def get_unique_identifiers_flags(self, query: dict = None) -> dict: """Retrieve values for `unique_identifiers` flags. @@ -2060,6 +2078,7 @@ def submit_contributions( # submit contributions if contribs: total, total_processed = 0, 0 + nmax = 1000 # TODO this should be set dynamically from `bulk_update_limit` def post_future(track_id, payload): future = self.session.post( @@ -2101,8 +2120,9 @@ def put_future(pk, payload): else: logger.error(f"SKIPPED: update of {project_name}/{pk} too large.") else: - next_payload = ujson.dumps(post_chunk + [c]).encode("utf-8") - if len(next_payload) >= MAX_PAYLOAD: + next_post_chunk = post_chunk + [c] + next_payload = ujson.dumps(next_post_chunk).encode("utf-8") + if len(next_post_chunk) > nmax or len(next_payload) >= MAX_PAYLOAD: if post_chunk: payload = ujson.dumps(post_chunk).encode("utf-8") futures.append(post_future(idx, payload))