From 6871857e7f9301b5129e2e9af05fc44bba14373e Mon Sep 17 00:00:00 2001
From: Tomasz Pietruszka
Date: Thu, 17 Aug 2023 16:49:24 +0100
Subject: [PATCH] feat: response validation

---
 README.md                  | 26 +++++++++++++++++++++++--
 rate_limited/exceptions.py | 15 ++++++++++++++
 rate_limited/runner.py     | 21 ++++++++++++++------
 tests/test_runner.py       | 40 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 94 insertions(+), 8 deletions(-)
 create mode 100644 rate_limited/exceptions.py

diff --git a/README.md b/README.md
index e472470..756e901 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ those of Large Language Models (LLMs).
 ## Features
 - parallel execution - be as quick as possible within the rate limit you have
 - retry failed requests
+- validate the response against your criteria and retry if not valid (for non-deterministic APIs)
 - if interrupted (`KeyboardInterrupt`) e.g. in a notebook:
   - returns partial results
   - if ran again, continues from where it left off - and returns full results
@@ -68,6 +69,29 @@ for topic in topics:
 results, exceptions = runner.run()
 ```
 
+### Validating the response
+We can provide custom validation logic to the Runner, so that a request is retried if its
+response does not meet our criteria - for example, if it does not conform to the schema we
+expect. This assumes the API is non-deterministic - otherwise a retry would just fail again.
+
+Continuing the example above:
+```python
+def character_number_is_even(response):
+    poem = response["choices"][0]["message"]["content"]
+    return len([ch for ch in poem if ch.isalpha()]) % 2 == 0
+
+validating_runner = Runner(
+    openai.ChatCompletion.create,
+    resources,
+    max_concurrent=32,
+    validation_function=character_number_is_even,
+)
+for topic in topics:
+    messages = [{"role": "user", "content": f"Please write a short poem about {topic}, containing an even number of letters"}]
+    validating_runner.schedule(model=model, messages=messages, max_tokens=256, request_timeout=60)
+results, exceptions = validating_runner.run()
+```
+
 ### Custom server with a "requests per minute" limit
 proceed as above - replace `openai.ChatCompletion.create` with your own function, and describe
 resources as follows:
@@ -139,8 +163,6 @@ flake8 && black --check . && mypy .
 - more ready-made API descriptions - incl. batched ones?
 - fix the "interrupt and resume" test in Python 3.11
 ### Nice to have:
-- add an optional "result verification" mechanism, for when the server might return, but
-  the results might be incorrect (e.g. LM not conforming with the given format) - so we retry
 - (optional) slow start feature - pace the initial requests, instead of sending them all at once
 - text-based logging if tqdm is not installed
 - if/where possible, detect RateLimitExceeded - notify the user, slow down

diff --git a/rate_limited/exceptions.py b/rate_limited/exceptions.py
new file mode 100644
index 0000000..b0ccdaa
--- /dev/null
+++ b/rate_limited/exceptions.py
@@ -0,0 +1,15 @@
+class BaseException(Exception):
+    pass
+
+
+class ValidationError(BaseException):
+    """
+    API response failed the user-provided validation
+
+    Store the invalid value for post-hoc inspection/debugging
+    """
+
+    def __init__(self, message, call, value, *args, **kwargs):
+        super().__init__(message, *args, **kwargs)
+        self.call = call
+        self.value = value

diff --git a/rate_limited/runner.py b/rate_limited/runner.py
index c3a7501..f6d706e 100644
--- a/rate_limited/runner.py
+++ b/rate_limited/runner.py
@@ -5,9 +5,10 @@
 from concurrent.futures import ThreadPoolExecutor
 from inspect import signature
 from logging import getLogger
-from typing import Callable, Collection, List, Optional, Tuple
+from typing import Any, Callable, Collection, List, Optional, Tuple
 
 from rate_limited.calls import Call
+from rate_limited.exceptions import ValidationError
 from rate_limited.progress_bar import ProgressBar
 from rate_limited.queue import CompletionTrackingQueue
 from rate_limited.resource_manager import ResourceManager
@@ -22,6 +23,7 @@ def __init__(
         resources: Collection[Resource],
         max_concurrent: int,
         max_retries: int = 5,
+        validation_function: Optional[Callable[[Any], bool]] = None,
         progress_interval: float = 1.0,
         long_wait_warning_seconds: Optional[float] = 2.0,
     ):
@@ -30,10 +32,9 @@ def __init__(
         self.max_concurrent = max_concurrent
         self.requests_executor_pool = ThreadPoolExecutor(max_workers=max_concurrent)
         self.max_retries = max_retries
+        self.validation_function = validation_function
         self.progress_interval = progress_interval
         self.long_wait_warning_seconds = long_wait_warning_seconds
-        # TODO: add verification functions?
-        # (checking if response meets criteria, retrying otherwise)
 
         self.logger = getLogger(f"rate_limited.Runner.{function.__name__}")
 
@@ -172,12 +173,20 @@ async def worker(self):
             self.resource_manager.pre_allocate(call)
             try:
                 # TODO: add a timeout mechanism?
-                call.result = await to_thread_in_pool(
+                result = await to_thread_in_pool(
                     self.requests_executor_pool, self.function, *call.args, **call.kwargs
                 )
                 # TODO: are there cases where we need to register result-based usage on error?
-                # (one case: if we have user-defined verification functions)
-                self.resource_manager.register_result(call.result)
+                self.resource_manager.register_result(result)
+                if self.validation_function is not None:
+                    if not self.validation_function(result):
+                        raise ValidationError(
+                            message="Validation failed",
+                            call=call,
+                            value=result,
+                        )
+                call.result = result
+
             except Exception as e:
                 will_retry = call.num_retries < self.max_retries
                 self.logger.warning(

diff --git a/tests/test_runner.py b/tests/test_runner.py
index 0dd9940..58f2b76 100644
--- a/tests/test_runner.py
+++ b/tests/test_runner.py
@@ -1,9 +1,11 @@
 import asyncio
+import random
 from typing import List
 
 import pytest
 import requests
 
+from rate_limited.exceptions import ValidationError
 from rate_limited.resources import Resource
 from rate_limited.runner import Runner
 
@@ -226,3 +228,41 @@ def scenario():
     assert exceptions == [[]] * num_requests
 
     test_executor(scenario)
+
+
+def test_result_validation(running_dummy_server):
+    """
+    Check that the results are validated using the validation function and retried if necessary.
+    """
+    rng = random.Random(42)
+
+    def random_client(url: str, how_many=2, failure_proba: float = 0.2) -> dict:
+        """Request between 1 and `how_many` calculations from the server, with a `failure_proba`"""
+        how_many = rng.randint(1, how_many)
+        result = requests.get(f"{url}/calculate_things/{how_many}?failure_proba={failure_proba}")
+        # this imitates the behavior of an API client, raising e.g. on a timeout error (or some
+        # other kind of error)
+        result.raise_for_status()
+        parsed = result.json()
+        return parsed
+
+    def validate(result: dict) -> bool:
+        return result["output"].count("x") == 2
+
+    runner = Runner(
+        random_client,
+        resources=dummy_resources(num_requests=5),
+        validation_function=validate,
+        max_concurrent=5,
+        max_retries=10,
+    )
+    num_requests = 5
+    for _ in range(num_requests):
+        runner.schedule(running_dummy_server)
+
+    results, exceptions = runner.run()
+    outputs = [result["output"] for result in results]
+    assert outputs == ["xx"] * num_requests
+
+    exceptions_flat = [e for sublist in exceptions for e in sublist]
+    assert any(isinstance(e, ValidationError) for e in exceptions_flat)
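
The new `validation_function` parameter takes the raw API response and returns a `bool`, so any check can be plugged in - e.g. requiring that an LLM reply parses as JSON with some expected keys. A minimal sketch of such a validator; the helper name and key list are illustrative, and the OpenAI-style response shape mirrors the README example rather than anything mandated by the patch:

```python
import json


def make_json_validator(required_keys):
    # Build a validation_function for Runner: returns True only if the
    # model's reply parses as JSON and contains every required key.
    # (OpenAI-style response shape assumed, as in the README example.)
    def validate(response) -> bool:
        try:
            content = response["choices"][0]["message"]["content"]
            parsed = json.loads(content)
        except (KeyError, IndexError, TypeError, json.JSONDecodeError):
            return False
        return isinstance(parsed, dict) and all(key in parsed for key in required_keys)

    return validate


# e.g.: Runner(openai.ChatCompletion.create, resources, max_concurrent=32,
#              validation_function=make_json_validator(["title", "summary"]))
```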
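Worth noting from the `worker` change above: `ValidationError` is raised inside the same `try` block as transport errors, after `register_result` - so an invalid response still counts against the rate limit, and validation failures consume the same `max_retries` budget as any other failure. Because the error stores `call` and `value`, rejected responses can be inspected after `run()` returns; a small sketch, assuming the `(results, exceptions)` shape exercised by the new test (a list of per-call exception lists):

```python
from rate_limited.exceptions import ValidationError


def rejected_responses(exceptions):
    # Flatten the per-call exception lists and collect the invalid values
    # that ValidationError stored for post-hoc inspection/debugging.
    return [
        error.value
        for per_call in exceptions
        for error in per_call
        if isinstance(error, ValidationError)
    ]


# results, exceptions = runner.run()
# for bad in rejected_responses(exceptions):
#     print(bad)
```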