Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Anyscale OpenAI models #285

Merged
merged 10 commits into from
Jun 14, 2024
10 changes: 9 additions & 1 deletion paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,15 @@ def set_client(
embedding_client: Any | None = None,
):
if client is None and isinstance(self.llm_model, OpenAILLMModel):
client = AsyncOpenAI()
if (api_key := os.environ.get("ANYSCALE_API_KEY")) and (
mskarlin marked this conversation as resolved.
Show resolved Hide resolved
base_url := os.environ.get("ANYSCALE_BASE_URL")
):
client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
)
else:
client = AsyncOpenAI()
self._client = client
if embedding_client is None:
# check if we have an openai embedding model in use
Expand Down
24 changes: 23 additions & 1 deletion paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import datetime
import os
import re
from abc import ABC, abstractmethod
from enum import Enum
Expand Down Expand Up @@ -54,6 +55,18 @@ def guess_model_type(model_name: str) -> str: # noqa: PLR0911
return "completion"
if model_name.startswith("davinci"):
return "completion"
if (
os.environ.get("ANYSCALE_API_KEY")
and os.environ.get("ANYSCALE_BASE_URL")
and model_name.startswith("meta-llama/Meta-Llama-3-")
):
return "chat"
if (
os.environ.get("ANYSCALE_API_KEY")
and os.environ.get("ANYSCALE_BASE_URL")
and (model_name.startswith(("mistralai/Mistral-", "mistralai/Mixtral-")))
):
return "chat"
if "instruct" in model_name:
return "completion"
if model_name.startswith("gpt-4"):
Expand All @@ -66,7 +79,16 @@ def guess_model_type(model_name: str) -> str: # noqa: PLR0911


def is_openai_model(model_name) -> bool:
    """Return True if *model_name* should be served through an OpenAI client.

    Besides native OpenAI models, Anyscale-hosted open models (Llama 3,
    Mistral, Mixtral) are treated as OpenAI-compatible, but only when both
    ANYSCALE_API_KEY and ANYSCALE_BASE_URL are present in the environment.
    https://docs.anyscale.com/endpoints/text-generation/query-a-model/
    """
    prefixes = ["gpt-", "babbage", "davinci", "ft:gpt-"]
    anyscale_configured = bool(
        os.environ.get("ANYSCALE_API_KEY") and os.environ.get("ANYSCALE_BASE_URL")
    )
    if anyscale_configured:
        # Anyscale exposes these open models behind an OpenAI-compatible API.
        prefixes += [
            "meta-llama/Meta-Llama-3-",
            "mistralai/Mistral-",
            "mistralai/Mixtral-",
        ]
    return model_name.startswith(tuple(prefixes))


def process_llm_config(
Expand Down
57 changes: 57 additions & 0 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def test_is_openai_model():
assert not is_openai_model("llama")
assert not is_openai_model("labgpt")
assert not is_openai_model("mixtral-7B")
os.environ["ANYSCALE_API_KEY"] = "abc123"
os.environ["ANYSCALE_BASE_URL"] = "https://example.com"
assert is_openai_model("meta-llama/Meta-Llama-3-70B-Instruct")
assert is_openai_model("mistralai/Mixtral-8x22B-Instruct-v0.1")
os.environ.pop("ANYSCALE_API_KEY")
os.environ.pop("ANYSCALE_BASE_URL")
assert not is_openai_model("meta-llama/Meta-Llama-3-70B-Instruct")
assert not is_openai_model("mistralai/Mixtral-8x22B-Instruct-v0.1")


def test_guess_model_type():
Expand All @@ -65,6 +73,12 @@ def test_guess_model_type():
assert guess_model_type("gpt-4-1106-preview") == "chat"
assert guess_model_type("gpt-3.5-turbo-instruct") == "completion"
assert guess_model_type("davinci-002") == "completion"
os.environ["ANYSCALE_API_KEY"] = "abc123"
os.environ["ANYSCALE_BASE_URL"] = "https://example.com"
assert guess_model_type("meta-llama/Meta-Llama-3-70B-Instruct") == "chat"
assert guess_model_type("mistralai/Mixtral-8x22B-Instruct-v0.1") == "chat"
os.environ.pop("ANYSCALE_API_KEY")
os.environ.pop("ANYSCALE_BASE_URL")


def test_get_citations():
Expand Down Expand Up @@ -544,6 +558,49 @@ def accum(x):
await docs.aquery("What is the national flag of Canada?", answer=answer)


@pytest.mark.skipif(
    not (os.environ.get("ANYSCALE_BASE_URL") and os.environ.get("ANYSCALE_API_KEY")),
    reason="Anyscale URL and keys are not set",
)
@pytest.mark.asyncio()
async def test_anyscale_chain():
    """Exercise OpenAILLMModel.make_chain against an Anyscale-hosted Llama 3 model.

    Checks token accounting (prompt/completion counts), streaming-callback
    accumulation, and mixed sync/async callbacks. Skipped unless Anyscale
    credentials are configured in the environment.
    """
    # NOTE: the scraped review-UI lines ("mskarlin marked this conversation as
    # resolved...") that had leaked into this block were removed; they were not code.
    client = AsyncOpenAI(
        base_url=os.environ["ANYSCALE_BASE_URL"], api_key=os.environ["ANYSCALE_API_KEY"]
    )
    llm = OpenAILLMModel(
        config={
            "temperature": 0,
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "max_tokens": 56,
        }
    )
    call = llm.make_chain(
        client,
        "The {animal} says",
        skip_system=True,
    )
    outputs = []

    def accum(x):
        outputs.append(x)

    completion = await call({"animal": "duck"}, callbacks=[accum])  # type: ignore[call-arg]
    # With a callback, tokens stream, so first-token latency is measured (> 0).
    assert completion.seconds_to_first_token > 0
    assert completion.prompt_count > 0
    assert completion.completion_count > 0
    assert str(completion) == "".join(outputs)

    completion = await call({"animal": "duck"})  # type: ignore[call-arg]
    # Without callbacks, first-token latency stays 0 — presumably the
    # non-streaming path is taken; only total latency is recorded.
    assert completion.seconds_to_first_token == 0
    assert completion.seconds_to_last_token > 0

    # check with mixed callbacks (sync + async)
    async def ac(x):  # noqa: ARG001
        pass

    completion = await call({"animal": "duck"}, callbacks=[accum, ac])  # type: ignore[call-arg]


def test_docs():
docs = Docs(llm="babbage-002")
docs.add_url(
Expand Down