Skip to content

Commit

Permalink
Merge pull request #1652 from materialsproject/dev
Browse files Browse the repository at this point in the history
client improvs and bugfixes
  • Loading branch information
tschaume authored Oct 18, 2023
2 parents 5a4d12b + 1dcc56e commit 06d6ed8
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions mpcontribs-client/mpcontribs/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,11 @@ def _get_future(
setattr(future, "track_id", track_id)
return future

def available_query_params(self, resource="contributions"):
def available_query_params(
self,
startswith: tuple = None,
resource: str = "contributions"
) -> list:
resources = self.swagger_spec.resources
resource_obj = resources.get(resource)
if not resource_obj:
Expand All @@ -962,7 +966,14 @@ def available_query_params(self, resource="contributions"):

op_key = f"query{resource.capitalize()}"
operation = resource_obj.operations[op_key]
return [param.name for param in operation.params.values()]
params = [param.name for param in operation.params.values()]
if not startswith:
return params

return [
param for param in params
if param.startswith(startswith)
]

def get_project(self, name: str = None, fields: list = None) -> Type[Dict]:
"""Retrieve a project entry
Expand Down Expand Up @@ -1093,14 +1104,17 @@ def update_project(self, update: dict, name: str = None):
raise MPContribsClientError("initialize client with project or set `name` argument!")

disallowed = ["is_approved", "stats", "columns", "is_public", "owner"]
for k in disallowed:
if k in update:
for k in list(update.keys()):
if k in disallowed:
logger.warning(f"removing `{k}` from update - not allowed.")
update.pop(k)
if k == "columns":
logger.info("use `client.init_columns()` to update project columns.")
elif k == "is_public":
logger.info("use `client.make_public/private()` to set `is_public`.")
elif not isinstance(update[k], bool) and not update[k]:
logger.warning(f"removing `{k}` from update - no update requested.")
update.pop(k)

if not update:
logger.warning("nothing to update")
Expand All @@ -1121,7 +1135,7 @@ def update_project(self, update: dict, name: str = None):

payload = {
k: v for k, v in update.items()
if v and k in fields and project.get(k, None) != v
if k in fields and project.get(k, None) != v
}
if not payload:
logger.warning("nothing to update")
Expand Down Expand Up @@ -1469,6 +1483,10 @@ def get_totals(

return result["total_count"], result["total_pages"]

def count(self, query: dict = None) -> int:
"""shortcut for get_totals()"""
return self.get_totals(query=query)[0]

def get_unique_identifiers_flags(self, query: dict = None) -> dict:
"""Retrieve values for `unique_identifiers` flags.
Expand Down Expand Up @@ -2060,6 +2078,7 @@ def submit_contributions(
# submit contributions
if contribs:
total, total_processed = 0, 0
nmax = 1000 # TODO this should be set dynamically from `bulk_update_limit`

def post_future(track_id, payload):
future = self.session.post(
Expand Down Expand Up @@ -2101,8 +2120,9 @@ def put_future(pk, payload):
else:
logger.error(f"SKIPPED: update of {project_name}/{pk} too large.")
else:
next_payload = ujson.dumps(post_chunk + [c]).encode("utf-8")
if len(next_payload) >= MAX_PAYLOAD:
next_post_chunk = post_chunk + [c]
next_payload = ujson.dumps(next_post_chunk).encode("utf-8")
if len(next_post_chunk) > nmax or len(next_payload) >= MAX_PAYLOAD:
if post_chunk:
payload = ujson.dumps(post_chunk).encode("utf-8")
futures.append(post_future(idx, payload))
Expand Down

0 comments on commit 06d6ed8

Please sign in to comment.