forked from openai/evals
-
Notifications
You must be signed in to change notification settings - Fork 0
/
langchain_llm.py
34 lines (25 loc) · 1.22 KB
/
langchain_llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import importlib
from typing import Optional
from evals.api import CompletionFn, CompletionResult
from langchain.llms import BaseLLM
from evals.prompt.base import CompletionPrompt
from evals.record import record_sampling
class LangChainLLMCompletionResult(CompletionResult):
    """Adapts a raw LangChain LLM response string to the evals CompletionResult API."""

    def __init__(self, response) -> None:
        # Raw text returned by the LLM call; stored as-is, cleaned on read.
        self.response = response

    def get_completions(self) -> list[str]:
        # A LangChain LLM produces one string per call, so the completions
        # list always has exactly one entry, with surrounding whitespace removed.
        cleaned = self.response.strip()
        return [cleaned]
class LangChainLLMCompletionFn(CompletionFn):
    """Completion function backed by a LangChain LLM resolved by class name.

    Args:
        llm: Name of a class exported by ``langchain.llms``; it must be a
            subclass of ``BaseLLM`` or a ``ValueError`` is raised.
        llm_kwargs: Keyword arguments forwarded to the LLM class constructor.
        **kwargs: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, llm: str, llm_kwargs: Optional[dict] = None, **kwargs) -> None:
        # FIX: the original signature used a mutable default argument
        # (llm_kwargs={}), which is shared across all calls/instances.
        # Default to None and substitute a fresh empty dict per call.
        if llm_kwargs is None:
            llm_kwargs = {}
        # Resolve the named class dynamically from langchain.llms, assuming
        # it's always a subclass of BaseLLM (guard clause enforces this).
        module = importlib.import_module("langchain.llms")
        LLMClass = getattr(module, llm)
        if not issubclass(LLMClass, BaseLLM):
            raise ValueError(f"{llm} is not a subclass of BaseLLM")
        self.llm = LLMClass(**llm_kwargs)

    def __call__(self, prompt, **kwargs) -> LangChainLLMCompletionResult:
        """Format the prompt, invoke the LLM, record the sample, and wrap the reply."""
        prompt = CompletionPrompt(prompt).to_formatted_prompt()
        response = self.llm(prompt)
        # Record the (prompt, response) pair for eval bookkeeping before returning.
        record_sampling(prompt=prompt, sampled=response)
        return LangChainLLMCompletionResult(response)