Skip to content

Commit

Permalink
Test multiple providers. (#165)
Browse files Browse the repository at this point in the history
This will be invoked by the GitHub workflow
each time a release is cut. This is part of
the pre-work for creating release automation.
Ignore the integration test using mark.
  • Loading branch information
rohitprasad15 authored Dec 26, 2024
1 parent 271af0d commit 5b83ec0
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ jobs:
pip install poetry
poetry install --all-extras --with test
- name: Test with pytest
run: poetry run pytest
run: poetry run pytest -m "not integration"

5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
testpaths="tests"
addopts=[
"--cov=aisuite",
"--cov-report=term-missing"
markers = [
"integration: marks tests as integration tests that interact with external services",
]
74 changes: 74 additions & 0 deletions tests/client/test_prerelease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Run this test before releasing a new version.
# It will test all the models in the client.

import pytest
import aisuite as ai
from typing import List, Dict
from dotenv import load_dotenv, find_dotenv


def setup_client() -> ai.Client:
    """Create an aisuite client, loading provider credentials from .env first."""
    env_file = find_dotenv()
    load_dotenv(env_file)
    return ai.Client()


def get_test_models() -> List[str]:
    """Return the provider:model identifiers exercised by the pre-release test."""
    models = [
        "anthropic:claude-3-5-sonnet-20240620",
        "aws:meta.llama3-1-8b-instruct-v1:0",
        "huggingface:mistralai/Mistral-7B-Instruct-v0.3",
        "groq:llama3-8b-8192",
        "mistral:open-mistral-7b",
        "openai:gpt-3.5-turbo",
        "cohere:command-r-plus-08-2024",
    ]
    return models


def get_test_messages() -> List[Dict[str, str]]:
    """Return the fixed chat transcript sent to every model under test."""
    system_message = {
        "role": "system",
        "content": "Respond in Pirate English. Always try to include the phrase - No rum No fun.",
    }
    user_message = {
        "role": "user",
        "content": "Tell me a joke about Captain Jack Sparrow",
    }
    return [system_message, user_message]


@pytest.mark.integration
@pytest.mark.parametrize("model_id", get_test_models())
def test_model_pirate_response(model_id: str):
    """
    Test that each model responds appropriately to the pirate prompt.

    Args:
        model_id: The provider:model identifier to test
    """
    client = setup_client()
    messages = get_test_messages()

    # Keep only the provider call inside the try. The original wrapped the
    # assertions too, so a failing assert (AssertionError is an Exception)
    # was swallowed and re-reported as a generic provider error, hiding the
    # assertion's own diagnostic message.
    try:
        response = client.chat.completions.create(
            model=model_id, messages=messages, temperature=0.75
        )
    except Exception as e:
        pytest.fail(f"Error testing model {model_id}: {str(e)}")

    content = response.choices[0].message.content

    # Validate the type/emptiness before calling str methods; in the original,
    # .lower() ran first, so the isinstance check could never fire.
    assert isinstance(
        content, str
    ), f"Model {model_id} returned non-string response"
    assert len(content) > 0, f"Model {model_id} returned empty response"

    # Check if either version of the required phrase is present
    lowered = content.lower()
    assert any(
        phrase in lowered for phrase in ["no rum no fun", "no rum, no fun"]
    ), f"Model {model_id} did not include required phrase 'No rum No fun'"


# Allow running this file directly (e.g. during a manual pre-release check)
# instead of going through the repo-wide pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

0 comments on commit 5b83ec0

Please sign in to comment.