Skip to content

Commit

Permalink
Support Anyscale OpenAI models (#285)
Browse files Browse the repository at this point in the history
* support anyscale openai models

* set up OpenAI client for Anyscale

* add completion test for anyscale

* correct input order

* replace async client with walrus-assigned variables

* pop env vars to maintain other tests

* Update tests/test_paperqa.py

Co-authored-by: James Braza <[email protected]>

* comment on max length size

* use outputs.append correctly after bad merge

---------

Co-authored-by: Michael Skarlinski <[email protected]>
Co-authored-by: James Braza <[email protected]>
  • Loading branch information
3 people authored Jun 14, 2024
1 parent b7a3d68 commit 51b1b3d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
10 changes: 9 additions & 1 deletion paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,15 @@ def set_client(
embedding_client: Any | None = None,
):
if client is None and isinstance(self.llm_model, OpenAILLMModel):
client = AsyncOpenAI()
if (api_key := os.environ.get("ANYSCALE_API_KEY")) and (
base_url := os.environ.get("ANYSCALE_BASE_URL")
):
client = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
)
else:
client = AsyncOpenAI()
self._client = client
if embedding_client is None:
# check if we have an openai embedding model in use
Expand Down
24 changes: 23 additions & 1 deletion paperqa/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import datetime
import os
import re
from abc import ABC, abstractmethod
from enum import Enum
Expand Down Expand Up @@ -54,6 +55,18 @@ def guess_model_type(model_name: str) -> str: # noqa: PLR0911
return "completion"
if model_name.startswith("davinci"):
return "completion"
if (
os.environ.get("ANYSCALE_API_KEY")
and os.environ.get("ANYSCALE_BASE_URL")
and model_name.startswith("meta-llama/Meta-Llama-3-")
):
return "chat"
if (
os.environ.get("ANYSCALE_API_KEY")
and os.environ.get("ANYSCALE_BASE_URL")
and (model_name.startswith(("mistralai/Mistral-", "mistralai/Mixtral-")))
):
return "chat"
if "instruct" in model_name:
return "completion"
if model_name.startswith("gpt-4"):
Expand All @@ -66,7 +79,16 @@ def guess_model_type(model_name: str) -> str: # noqa: PLR0911


def is_openai_model(model_name) -> bool:
    """Return True if ``model_name`` is served via an OpenAI-compatible API.

    First-party OpenAI prefixes (``gpt-``, ``babbage``, ``davinci``,
    ``ft:gpt-``) are always recognized. When both ``ANYSCALE_API_KEY`` and
    ``ANYSCALE_BASE_URL`` are set, the Anyscale-hosted open-weight model
    prefixes are accepted as well, since Anyscale Endpoints speak the OpenAI
    wire protocol.

    Args:
        model_name: Model identifier to classify.

    Returns:
        True when the name starts with a recognized prefix.
    """
    open_ai_model_prefixes = {"gpt-", "babbage", "davinci", "ft:gpt-"}
    # add special prefixes if the user has anyscale models
    # https://docs.anyscale.com/endpoints/text-generation/query-a-model/
    if os.environ.get("ANYSCALE_API_KEY") and os.environ.get("ANYSCALE_BASE_URL"):
        open_ai_model_prefixes |= {
            "meta-llama/Meta-Llama-3-",
            "mistralai/Mistral-",
            "mistralai/Mixtral-",
        }
    return model_name.startswith(tuple(open_ai_model_prefixes))


def process_llm_config(
Expand Down
49 changes: 49 additions & 0 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def test_is_openai_model():
assert not is_openai_model("llama")
assert not is_openai_model("labgpt")
assert not is_openai_model("mixtral-7B")
os.environ["ANYSCALE_API_KEY"] = "abc123"
os.environ["ANYSCALE_BASE_URL"] = "https://example.com"
assert is_openai_model("meta-llama/Meta-Llama-3-70B-Instruct")
assert is_openai_model("mistralai/Mixtral-8x22B-Instruct-v0.1")
os.environ.pop("ANYSCALE_API_KEY")
os.environ.pop("ANYSCALE_BASE_URL")
assert not is_openai_model("meta-llama/Meta-Llama-3-70B-Instruct")
assert not is_openai_model("mistralai/Mixtral-8x22B-Instruct-v0.1")


def test_guess_model_type():
Expand All @@ -65,6 +73,12 @@ def test_guess_model_type():
assert guess_model_type("gpt-4-1106-preview") == "chat"
assert guess_model_type("gpt-3.5-turbo-instruct") == "completion"
assert guess_model_type("davinci-002") == "completion"
os.environ["ANYSCALE_API_KEY"] = "abc123"
os.environ["ANYSCALE_BASE_URL"] = "https://example.com"
assert guess_model_type("meta-llama/Meta-Llama-3-70B-Instruct") == "chat"
assert guess_model_type("mistralai/Mixtral-8x22B-Instruct-v0.1") == "chat"
os.environ.pop("ANYSCALE_API_KEY")
os.environ.pop("ANYSCALE_BASE_URL")


def test_get_citations():
Expand Down Expand Up @@ -544,6 +558,41 @@ def accum(x):
await docs.aquery("What is the national flag of Canada?", answer=answer)


@pytest.mark.skipif(
    not (os.environ.get("ANYSCALE_BASE_URL") and os.environ.get("ANYSCALE_API_KEY")),
    reason="Anyscale URL and keys are not set",
)
@pytest.mark.asyncio()
async def test_anyscale_chain():
    """Exercise an LLM chain against Anyscale's OpenAI-compatible endpoint.

    Mirrors the OpenAI chat test: checks token accounting and that streamed
    callback chunks concatenate to the final completion text. Skipped unless
    ANYSCALE_BASE_URL and ANYSCALE_API_KEY are configured.
    """
    client = AsyncOpenAI(
        base_url=os.environ["ANYSCALE_BASE_URL"], api_key=os.environ["ANYSCALE_API_KEY"]
    )
    llm = OpenAILLMModel(
        config={
            "temperature": 0,
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "max_tokens": 56,  # matches openAI chat test
        }
    )
    call = llm.make_chain(
        client,
        "The {animal} says",
        skip_system=True,
    )
    # Streaming path: the joined callback chunks must equal the completion.
    outputs = []  # type: ignore[var-annotated]
    completion = await call({"animal": "duck"}, callbacks=[outputs.append])  # type: ignore[call-arg]
    assert completion.seconds_to_first_token > 0
    assert completion.prompt_count > 0
    assert completion.completion_count > 0
    assert str(completion) == "".join(outputs)

    # Non-streaming path: no callbacks, so no first-token timestamp.
    completion = await call({"animal": "duck"})  # type: ignore[call-arg]
    assert completion.seconds_to_first_token == 0
    assert completion.seconds_to_last_token > 0

    # Stream once more into a fresh buffer; previously this trailing call had
    # no assertions (leftover from a merge) -- verify its output as well.
    outputs = []
    completion = await call({"animal": "duck"}, callbacks=[outputs.append])  # type: ignore[call-arg]
    assert str(completion) == "".join(outputs)


def test_docs():
docs = Docs(llm="babbage-002")
docs.add_url(
Expand Down

0 comments on commit 51b1b3d

Please sign in to comment.