diff --git a/minichain/backend.py b/minichain/backend.py index ebae501..2292588 100644 --- a/minichain/backend.py +++ b/minichain/backend.py @@ -4,7 +4,8 @@ from dataclasses import dataclass from types import TracebackType from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple - +import litellm +from litellm import completion, embedding from eliot import start_action, to_file if TYPE_CHECKING: @@ -190,29 +191,13 @@ def run(self, request: str) -> str: import manifest chat = {"gpt-4", "gpt-3.5-turbo"} - manifest = manifest.Manifest( - client_name="openaichat" if self.model in chat else "openai", - max_tokens=self.options["max_tokens"], - cache_name="sqlite", - cache_connection=f"{MinichainContext.name}", - ) + messages=[{"role": "user", "content": request}] + ans = completion(model=self.model, messages=messages, stop=self.stop) - ans = manifest.run( - request, - stop_sequences=self.stop, - ) - return str(ans) + return ans["choices"][0]["messages"]["content"] def run_stream(self, prompt: str) -> Iterator[str]: - import openai - - self.api_key = os.environ.get("OPENAI_API_KEY") - assert ( - self.api_key - ), "Need an OPENAI_API_KEY. Get one here https://openai.com/api/" - openai.api_key = self.api_key - - for chunk in openai.ChatCompletion.create( + for chunk in completion( model=self.model, messages=[{"role": "user", "content": prompt}], stream=True, @@ -233,14 +218,13 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None def run(self, request: str) -> str: import openai - self.api_key = os.environ.get("OPENAI_API_KEY") + litellm.openai_key = os.environ.get("OPENAI_API_KEY") assert ( - self.api_key + litellm.openai_key ), "Need an OPENAI_API_KEY. Get one here https://openai.com/api/" - openai.api_key = self.api_key - ans = openai.Embedding.create( - engine=self.model, + ans = embedding( + model=self.model, input=request, ) return ans["data"][0]["embedding"] # type: ignore