Commit: feat: response validation
tpietruszka committed Aug 17, 2023
1 parent ba4daf3 commit 6871857
Showing 4 changed files with 94 additions and 8 deletions.
26 changes: 24 additions & 2 deletions README.md
@@ -10,6 +10,7 @@ those of Large Language Models (LLMs).
## Features
- parallel execution - be as quick as possible within the rate limit you have
- retry failed requests
- validate the response against your criteria and retry if not valid (for non-deterministic APIs)
- if interrupted (`KeyboardInterrupt`) e.g. in a notebook:
  - returns partial results
  - if run again, continues from where it left off - and returns full results
@@ -68,6 +69,29 @@ for topic in topics:
results, exceptions = runner.run()
```

### Validating the response
We can provide custom validation logic to the Runner, so that a request is retried if the response
does not meet our criteria - for example, if it does not conform to the schema we expect. This
assumes that the API is non-deterministic - otherwise retrying would just return the same invalid response.

Example above continued:
```python
def character_number_is_even(response):
    poem = response["choices"][0]["message"]["content"]
    return len([ch for ch in poem if ch.isalpha()]) % 2 == 0

validating_runner = Runner(
    openai.ChatCompletion.create,
    resources,
    max_concurrent=32,
    validation_function=character_number_is_even,
)
for topic in topics:
    messages = [{"role": "user", "content": f"Please write a short poem about {topic}, containing an even number of letters"}]
    validating_runner.schedule(model=model, messages=messages, max_tokens=256, request_timeout=60)
results, exceptions = validating_runner.run()
```
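
Failed validation attempts are retried up to `max_retries` times, and each rejected attempt is recorded as a `ValidationError` in the per-request exception lists returned by `run()`. As a minimal sketch (not part of this commit's README), the rejected responses could be inspected afterwards like this:

```python
from rate_limited.exceptions import ValidationError

# `exceptions` holds one list per scheduled request, collected across retries
for request_exceptions in exceptions:
    for exc in request_exceptions:
        if isinstance(exc, ValidationError):
            print(exc.value)  # the response that failed `character_number_is_even`
```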

### Custom server with a "requests per minute" limit
Proceed as above - replace `openai.ChatCompletion.create` with your own function, and
describe resources as follows:
@@ -139,8 +163,6 @@ flake8 && black --check . && mypy .
- more ready-made API descriptions - incl. batched ones?
- fix the "interrupt and resume" test in Python 3.11
### Nice to have:
- add an optional "result verification" mechanism, for when the server might return, but
the results might be incorrect (e.g. LM not conforming with the given format) - so we retry
- (optional) slow start feature - pace the initial requests, instead of sending them all at once
- text-based logging if tqdm is not installed
- if/where possible, detect RateLimitExceeded - notify the user, slow down
15 changes: 15 additions & 0 deletions rate_limited/exceptions.py
@@ -0,0 +1,15 @@
class BaseException(Exception):
    pass


class ValidationError(BaseException):
    """
    API response failed the user-provided validation
    Store the invalid value for post-hoc inspection/debugging
    """

    def __init__(self, message, call, value, *args, **kwargs):
        super().__init__(message, *args, **kwargs)
        self.call = call
        self.value = value
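
As a small illustration (not part of the commit itself) of the post-hoc inspection the docstring mentions - the keyword arguments mirror how the runner raises the exception, and `call=None` is only a placeholder here:

```python
from rate_limited.exceptions import ValidationError

try:
    # in the runner, `call` is the Call object that produced the response;
    # None is used here purely as a stand-in
    raise ValidationError(message="Validation failed", call=None, value={"output": "xxx"})
except ValidationError as exc:
    print(exc.value)  # the invalid API response, kept for debugging
    print(exc.call)   # the call that produced it
```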
21 changes: 15 additions & 6 deletions rate_limited/runner.py
@@ -5,9 +5,10 @@
from concurrent.futures import ThreadPoolExecutor
from inspect import signature
from logging import getLogger
from typing import Callable, Collection, List, Optional, Tuple
from typing import Any, Callable, Collection, List, Optional, Tuple

from rate_limited.calls import Call
from rate_limited.exceptions import ValidationError
from rate_limited.progress_bar import ProgressBar
from rate_limited.queue import CompletionTrackingQueue
from rate_limited.resource_manager import ResourceManager
@@ -22,6 +23,7 @@ def __init__(
        resources: Collection[Resource],
        max_concurrent: int,
        max_retries: int = 5,
        validation_function: Optional[Callable[[Any], bool]] = None,
        progress_interval: float = 1.0,
        long_wait_warning_seconds: Optional[float] = 2.0,
    ):
@@ -30,10 +32,9 @@
        self.max_concurrent = max_concurrent
        self.requests_executor_pool = ThreadPoolExecutor(max_workers=max_concurrent)
        self.max_retries = max_retries
        self.validation_function = validation_function
        self.progress_interval = progress_interval
        self.long_wait_warning_seconds = long_wait_warning_seconds
        # TODO: add verification functions?
        # (checking if response meets criteria, retrying otherwise)

        self.logger = getLogger(f"rate_limited.Runner.{function.__name__}")

@@ -172,12 +173,20 @@ async def worker(self):
            self.resource_manager.pre_allocate(call)
            try:
                # TODO: add a timeout mechanism?
                call.result = await to_thread_in_pool(
                result = await to_thread_in_pool(
                    self.requests_executor_pool, self.function, *call.args, **call.kwargs
                )
                # TODO: are there cases where we need to register result-based usage on error?
                # (one case: if we have user-defined verification functions)
                self.resource_manager.register_result(call.result)
                self.resource_manager.register_result(result)
                if self.validation_function is not None:
                    if not self.validation_function(result):
                        raise ValidationError(
                            message="Validation failed",
                            call=call,
                            value=result,
                        )
                call.result = result

            except Exception as e:
                will_retry = call.num_retries < self.max_retries
                self.logger.warning(
40 changes: 40 additions & 0 deletions tests/test_runner.py
@@ -1,9 +1,11 @@
import asyncio
import random
from typing import List

import pytest
import requests

from rate_limited.exceptions import ValidationError
from rate_limited.resources import Resource
from rate_limited.runner import Runner

@@ -226,3 +228,41 @@ def scenario():
        assert exceptions == [[]] * num_requests

    test_executor(scenario)


def test_result_validation(running_dummy_server):
    """
    Check that the results are validated using the validation function and retried if necessary.
    """
    rng = random.Random(42)

    def random_client(url: str, how_many=2, failure_proba: float = 0.2) -> dict:
        """Request between 1 and `how_many` calculations from the server, with a `failure_proba`"""
        how_many = rng.randint(1, how_many)
        result = requests.get(f"{url}/calculate_things/{how_many}?failure_proba={failure_proba}")
        # this imitates the behavior of an API client, raising e.g. on a timeout error (or some
        # other kind of error)
        result.raise_for_status()
        parsed = result.json()
        return parsed

    def validate(result: dict) -> bool:
        return result["output"].count("x") == 2

    runner = Runner(
        random_client,
        resources=dummy_resources(num_requests=5),
        validation_function=validate,
        max_concurrent=5,
        max_retries=10,
    )
    num_requests = 5
    for _ in range(num_requests):
        runner.schedule(running_dummy_server)

    results, exceptions = runner.run()
    outputs = [result["output"] for result in results]
    assert outputs == ["xx"] * num_requests

    exceptions_flat = [e for sublist in exceptions for e in sublist]
    assert any(isinstance(e, ValidationError) for e in exceptions_flat)
