Merge pull request #4 from silvanmelchior/llama
Llama
Showing 20 changed files with 673 additions and 277 deletions.
@@ -1,4 +1,3 @@
from .base import BaseLLM, LLMException
from .types import Message, Response
from .selector import get_llm
from .gpt_openai import GPTOpenAI
@@ -0,0 +1,2 @@
from .gpt_openai import GPTOpenAI
from .gpt_azure import GPTAzure
@@ -0,0 +1,53 @@
from typing import Generator

import openai
from openai import OpenAIError

from llm.base import BaseLLM, LLMException
from llm.types import Message, Response
from .parsing import msg_to_gpt_msg, lazy_parse_args, fill_dict
from .prompt import FUNCTIONS


class GPT(BaseLLM):
    def __init__(self, model_selection: dict):
        self._model_selection = model_selection

    def chat(self, history: list[Message]) -> Generator[Response, None, None]:
        messages = [msg_to_gpt_msg(msg) for msg in history]

        try:
            chunk_generator = openai.ChatCompletion.create(
                **self._model_selection,
                messages=messages,
                temperature=0,
                functions=FUNCTIONS,
                function_call="auto",
                stream=True,
            )

            response = {}
            previous_code = None
            for chunk_all in chunk_generator:
                chunk = chunk_all["choices"][0]["delta"]
                fill_dict(response, chunk)

                text = None
                if "content" in response:
                    text = response["content"]

                code = None
                if (
                    "function_call" in response
                    and "arguments" in response["function_call"]
                ):
                    args = response["function_call"]["arguments"]
                    code = lazy_parse_args(args)
                    if code is None:
                        code = previous_code
                    previous_code = code

                yield Response(text=text, code=code)

        except OpenAIError as e:
            raise LLMException(str(e))
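
For orientation, a minimal usage sketch of this streaming interface follows. It is not part of the commit; the model name and the Message constructor fields are assumptions inferred from how the types are used above.

# Hypothetical usage sketch, not part of this diff.
# Assumes OPENAI_API_KEY is set and that Message accepts these fields.
from llm.gpt import GPTOpenAI
from llm.types import Message

llm = GPTOpenAI("gpt-4")  # model name is illustrative
history = [Message(role="user", text="Print the first 10 squares")]
for response in llm.chat(history):
    # Each Response carries the full text and code accumulated so far,
    # so a consumer can re-render the partial answer on every chunk.
    print(response.text, response.code)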
@@ -0,0 +1,13 @@
import openai

from utils import get_env_var
from .gpt import GPT


class GPTAzure(GPT):
    def __init__(self, engine_name: str):
        openai.api_type = "azure"
        openai.api_base = get_env_var("AZURE_API_BASE")
        openai.api_version = "2023-07-01-preview"
        openai.api_key = get_env_var("AZURE_API_KEY")
        super().__init__({"engine": engine_name})
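
For comparison, a hedged sketch of constructing the Azure variant; the deployment name is hypothetical and both environment variables must be set beforehand:

# Hypothetical setup sketch, not part of this diff.
# Requires AZURE_API_BASE and AZURE_API_KEY in the environment.
from llm.gpt import GPTAzure

llm = GPTAzure("my-gpt4-deployment")  # hypothetical Azure deployment name
# Note the asymmetry visible in the diff: GPTAzure passes the deployment
# as "engine", while GPTOpenAI (below) passes a "model" name.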
@@ -0,0 +1,10 @@
import openai

from utils import get_env_var
from .gpt import GPT


class GPTOpenAI(GPT):
    def __init__(self, model_name: str):
        openai.api_key = get_env_var("OPENAI_API_KEY")
        super().__init__({"model": model_name})
@@ -0,0 +1,58 @@
import re
import json

from llm.types import Message


def msg_to_gpt_msg(msg: Message) -> dict:
    if msg.role == "user":
        return {"role": "user", "content": msg.text}
    if msg.role == "model":
        response = {
            "role": "assistant",
            "content": msg.text or None,
        }
        if msg.code:
            response["function_call"] = {
                "name": "run_python_code",
                "arguments": json.dumps({"code": msg.code}),
            }
        return response
    if msg.role == "interpreter":
        return {
            "role": "function",
            "name": "run_python_code",
            "content": msg.code_result,
        }
    raise ValueError(f"Invalid message role {msg.role}")


def lazy_parse_args(args_partial):
    args = args_partial
    if not re.sub(r"\s+", "", args).endswith('"}'):
        args += '"}'

    try:
        args = json.loads(args)
        if "code" not in args:
            return None
    except json.JSONDecodeError:
        return None

    return args["code"]


def fill_dict(dst: dict, chunk: dict):
    for key in chunk:
        if chunk[key] is None:
            dst[key] = None
        elif isinstance(chunk[key], dict):
            if key not in dst:
                dst[key] = {}
            fill_dict(dst[key], chunk[key])
        elif isinstance(chunk[key], str):
            if key not in dst:
                dst[key] = ""
            dst[key] += chunk[key]
        else:
            raise ValueError(f"Unsupported type {type(chunk[key])}")
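
The two streaming helpers are the subtle part of this file, so a small demonstration may help. It is not part of the diff; the module path llm.gpt.parsing is inferred from the relative imports in the class above, and the values are made up.

# Demonstration sketch, not part of this diff.
from llm.gpt.parsing import lazy_parse_args, fill_dict

# lazy_parse_args closes a still-streaming JSON argument string so the
# partial code can already be surfaced to the user.
print(lazy_parse_args('{"code": "print(1)'))  # -> print(1)
print(lazy_parse_args('{"foo": "bar"}'))      # -> None (no "code" key)

# fill_dict merges streamed deltas into one accumulated response dict,
# concatenating string fragments and recursing into nested dicts.
response = {}
fill_dict(response, {"function_call": {"arguments": '{"code": "pri'}})
fill_dict(response, {"function_call": {"arguments": 'nt(1)"}'}})
print(response["function_call"]["arguments"])  # -> {"code": "print(1)"}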
@@ -0,0 +1,22 @@
FUNCTIONS = [
    {
        "name": "run_python_code",
        "description": "Runs arbitrary Python code and returns stdout and stderr. "
        + "The code is executed in an interactive shell, imports and variables are preserved between calls. "
        + "The environment has internet and file system access. "
        + "The current working directory is shared with the user, so files can be exchanged. "
        + "There are many libraries pre-installed, including numpy, pandas, matplotlib, and scikit-learn. "
        + "You cannot show rich outputs like plots or images, but you can store them in the working directory and point the user to them. "
        + "If the code runs too long, there will be a timeout.",
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The Python code to run",
                },
            },
            "required": ["code"],
        },
    },
]
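
This schema is what the model's function calls conform to; a completed call carries a JSON-encoded arguments string with a single "code" key. An illustrative sketch of that shape (the code value is made up):

# Illustrative shape of a completed function call under this schema;
# not part of this diff.
import json

call = {
    "name": "run_python_code",
    "arguments": json.dumps({"code": "print('hello')"}),
}
assert json.loads(call["arguments"])["code"] == "print('hello')"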
This file was deleted.
@@ -0,0 +1,2 @@
from .llama_replicate import LlamaReplicate
from .llama_tgi import LlamaTGI
@@ -0,0 +1,43 @@
from typing import Generator, Optional

import replicate
from replicate.exceptions import ReplicateException

from llm.base import BaseLLM, LLMException
from llm.types import Message, Response
from utils import get_env_var

from .prompt import SYSTEM_PROMPT
from .parsing import msg_to_llama_msg, split_output


class LlamaReplicate(BaseLLM):
    def __init__(self, model_name: str):
        self._model_name = model_name
        self._client = replicate.Client(api_token=get_env_var("REPLICATE_API_KEY"))

    def chat(self, history: list[Message]) -> Generator[Response, None, None]:
        messages = [msg_to_llama_msg(msg) for msg in history]
        try:
            output = self._client.run(
                self._model_name,
                input={
                    "prompt": " ".join(messages),
                    "system_prompt": SYSTEM_PROMPT,
                    "temperature": 0.01,
                },
            )

            full_text = ""
            for item in output:
                full_text += item

                text, code, finished = split_output(full_text)
                if text is not None or code is not None:
                    yield Response(text=text, code=code)

                if finished:
                    break

        except ReplicateException as e:
            raise LLMException(str(e))
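
A usage sketch mirroring the GPT one above; it is not part of the diff, the Replicate model identifier is illustrative, and msg_to_llama_msg and split_output come from a parsing module added elsewhere in this commit.

# Hypothetical usage sketch, not part of this diff.
# Assumes REPLICATE_API_KEY is set; the model id is illustrative.
from llm.llama import LlamaReplicate
from llm.types import Message

llm = LlamaReplicate("meta/llama-2-70b-chat")  # illustrative model id
for response in llm.chat([Message(role="user", text="List the cwd files")]):
    # Responses stream in as split_output recovers text/code from the
    # growing Replicate output.
    print(response.text, response.code)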