From 2a35fa18f45f850f71d334a5e6d9acc6540a0539 Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Thu, 9 Jan 2025 14:44:56 -0800 Subject: [PATCH] Add caching to Nova client --- src/helm/clients/bedrock_client.py | 14 +++++++++++--- src/helm/common/request.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/helm/clients/bedrock_client.py b/src/helm/clients/bedrock_client.py index bd22ff3834d..f9b345ad864 100644 --- a/src/helm/clients/bedrock_client.py +++ b/src/helm/clients/bedrock_client.py @@ -133,14 +133,22 @@ def convert_request_to_raw_request(self, request: Request) -> Dict: def make_request(self, request: Request) -> RequestResult: raw_request = self.convert_request_to_raw_request(request) - response = self.bedrock_client.converse(**raw_request) + cache_key = CachingClient.make_cache_key(raw_request, request) + + def do_it() -> Dict[Any, Any]: + return self.bedrock_client.converse(**raw_request) + + response, cached = self.cache.get(cache_key, do_it) + completions = self.convert_raw_response_to_completions(response, request) dt = datetime.strptime(response["ResponseMetadata"]["HTTPHeaders"]["date"], "%a, %d %b %Y %H:%M:%S GMT") + # Use API reported latency rather than client measured latency + request_time = response["metrics"]["latencyMs"] / 1000 return RequestResult( success=True, - cached=False, - request_time=(response["metrics"]["latencyMs"] / 1000), + cached=cached, + request_time=request_time, request_datetime=int(dt.timestamp()), completions=completions, embedding=[], diff --git a/src/helm/common/request.py b/src/helm/common/request.py index 1c1bd288bdd..1838cdaab8d 100644 --- a/src/helm/common/request.py +++ b/src/helm/common/request.py @@ -193,7 +193,7 @@ class RequestResult: """Whether the request was actually cached""" request_time: Optional[float] = None - """How long did the request take?""" + """How long the request took in seconds""" request_datetime: Optional[int] = None """When was the request sent?