From 251c686acc801ffe8b99306d4359769375324bf1 Mon Sep 17 00:00:00 2001 From: Silvan Melchior Date: Wed, 9 Aug 2023 16:50:44 +0200 Subject: [PATCH 01/23] adjust architecture to streaming llms, ui complete, llm service WIP --- {interpreter => services}/.gitignore | 0 .../interpreter/__init__.py | 0 .../interpreter/ipython_interpreter.py | 0 services/llm/__init__.py | 0 .../main.py => services/main_interpreter.py | 0 services/main_llm.py | 68 +++++++++++++++++++ {interpreter => services}/poetry.lock | 0 {interpreter => services}/pyproject.toml | 7 +- .../tests/test_interpreter.py | 0 ui/app/api/chat/route.ts | 2 + ui/app/page.tsx | 20 ++++-- ui/app/session/approval/approver.tsx | 13 ++-- ui/app/session/communication/chat_round.tsx | 65 +++++++++++------- .../{api_calls.tsx => interpreter.tsx} | 22 +----- ui/app/session/communication/llm.tsx | 41 +++++++++++ ui/app/session/session.tsx | 18 +++-- ui/app/session/session_manager.tsx | 5 +- 17 files changed, 197 insertions(+), 64 deletions(-) rename {interpreter => services}/.gitignore (100%) rename {interpreter => services}/interpreter/__init__.py (100%) rename {interpreter => services}/interpreter/ipython_interpreter.py (100%) create mode 100644 services/llm/__init__.py rename interpreter/main.py => services/main_interpreter.py (100%) create mode 100644 services/main_llm.py rename {interpreter => services}/poetry.lock (100%) rename {interpreter => services}/pyproject.toml (82%) rename {interpreter => services}/tests/test_interpreter.py (100%) rename ui/app/session/communication/{api_calls.tsx => interpreter.tsx} (71%) create mode 100644 ui/app/session/communication/llm.tsx diff --git a/interpreter/.gitignore b/services/.gitignore similarity index 100% rename from interpreter/.gitignore rename to services/.gitignore diff --git a/interpreter/interpreter/__init__.py b/services/interpreter/__init__.py similarity index 100% rename from interpreter/interpreter/__init__.py rename to services/interpreter/__init__.py diff --git a/interpreter/interpreter/ipython_interpreter.py b/services/interpreter/ipython_interpreter.py similarity index 100% rename from interpreter/interpreter/ipython_interpreter.py rename to services/interpreter/ipython_interpreter.py diff --git a/services/llm/__init__.py b/services/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/interpreter/main.py b/services/main_interpreter.py similarity index 100% rename from interpreter/main.py rename to services/main_interpreter.py diff --git a/services/main_llm.py b/services/main_llm.py new file mode 100644 index 0000000..c3e0551 --- /dev/null +++ b/services/main_llm.py @@ -0,0 +1,68 @@ +import os +import sys +import time +from typing import Literal, Optional + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, TypeAdapter + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +try: + LLM_SETTING = os.environ["LLM"] +except KeyError: + print("ERROR: Missing environment variables, exiting...", file=sys.stderr) + sys.exit(1) + + +class Message(BaseModel): + role: Literal["user", "model", "interpreter"] + text: Optional[str] = None + code: Optional[str] = None + code_result: Optional[str] = None + + +Request = TypeAdapter(list[Message]) + + +class Response(BaseModel): + text: Optional[str] = None + code: Optional[str] = None + + +@app.websocket("/chat") +async def chat(websocket: WebSocket): + await websocket.accept() + + history = await websocket.receive_text() + history = Request.validate_json(history) + print(history, type(history)) + + text = "I will print hello world" + for i in range(len(text)): + response = Response(text=text[: i + 1]) + await websocket.send_text(response.model_dump_json(exclude_none=True)) + time.sleep(0.1) + + code = "print('hello world')" + for i in range(len(code)): + response = Response(text=text, code=code[: i + 1]) + await websocket.send_text(response.model_dump_json(exclude_none=True)) + time.sleep(0.1) + + await websocket.close() + + # TODO: handle disconnect (stop generator) + # try: + # ... + # except WebSocketDisconnect: + # pass + # TODO: error handling as in main_interpreter.py diff --git a/interpreter/poetry.lock b/services/poetry.lock similarity index 100% rename from interpreter/poetry.lock rename to services/poetry.lock diff --git a/interpreter/pyproject.toml b/services/pyproject.toml similarity index 82% rename from interpreter/pyproject.toml rename to services/pyproject.toml index fcd4b78..9e8fb3b 100644 --- a/interpreter/pyproject.toml +++ b/services/pyproject.toml @@ -1,10 +1,13 @@ [tool.poetry] -name = "interpreter" +name = "services" version = "0.0.0" description = "" authors = ["Silvan Melchior"] license = "MIT" -packages = [{include = "interpreter"}] +packages = [ + {include = "interpreter"}, + {include = "llm"}, +] [tool.poetry.dependencies] python = ">=3.9" diff --git a/interpreter/tests/test_interpreter.py b/services/tests/test_interpreter.py similarity index 100% rename from interpreter/tests/test_interpreter.py rename to services/tests/test_interpreter.py diff --git a/ui/app/api/chat/route.ts b/ui/app/api/chat/route.ts index 5aa8767..9706abf 100644 --- a/ui/app/api/chat/route.ts +++ b/ui/app/api/chat/route.ts @@ -27,3 +27,5 @@ export async function POST( throw e; } } + +// TODO: REMOVE ALL OF THIS, ALSO llm FOLDER with base and gpt4 diff --git a/ui/app/page.tsx b/ui/app/page.tsx index 658408f..a0e3b58 100644 --- a/ui/app/page.tsx +++ b/ui/app/page.tsx @@ -1,5 +1,5 @@ import React from "react"; -import Session_manager from "@/app/session/session_manager"; +import SessionManager from "@/app/session/session_manager"; import path from "path"; import * as fs from "fs"; @@ -13,6 +13,14 @@ function getInterpreterUrl() { return interpreterUrl; } +function getLlmUrl() { + const llmUrl = process.env.LLM_URL; + if (llmUrl === undefined) { + throw new Error("LLM_URL is undefined"); + } + return llmUrl; +} + function getVersion(): Promise { const versionDir = path.dirname( path.dirname(path.dirname(path.dirname(__dirname))), @@ -30,7 +38,11 @@ function getVersion(): Promise { } export default async function Home() { - const interpreterUrl = getInterpreterUrl(); - const version = await getVersion(); - return ; + return ( + + ); } diff --git a/ui/app/session/approval/approver.tsx b/ui/app/session/approval/approver.tsx index e65539c..b686788 100644 --- a/ui/app/session/approval/approver.tsx +++ b/ui/app/session/approval/approver.tsx @@ -4,7 +4,6 @@ export class Approver { private resolveHandler: ((value: void) => void) | null = null; constructor( - private readonly setContent: (content: string) => void, private autoApprove: boolean, private readonly _setAutoApprove: (autoApprove: boolean) => void, private readonly setAskApprove: (askApprove: boolean) => void, @@ -26,10 +25,9 @@ export class Approver { } }; - getApproval = (content: string, tmpAutoApprove: boolean = false) => { - this.setContent(content); + getApproval = () => { return new Promise((resolve, reject) => { - if (this.autoApprove || tmpAutoApprove) { + if (this.autoApprove) { resolve(); } else { this.resolveHandler = resolve; @@ -39,12 +37,11 @@ export class Approver { }; } -export function useApprover(): [Approver, string | null, boolean, boolean] { - const [content, setContent] = React.useState(null); +export function useApprover(): [Approver, boolean, boolean] { const [askApprove, setAskApprove] = React.useState(false); const [autoApprove, setAutoApprove] = React.useState(false); const approverRef = React.useRef( - new Approver(setContent, autoApprove, setAutoApprove, setAskApprove), + new Approver(autoApprove, setAutoApprove, setAskApprove), ); - return [approverRef.current, content, askApprove, autoApprove]; + return [approverRef.current, askApprove, autoApprove]; } diff --git a/ui/app/session/communication/chat_round.tsx b/ui/app/session/communication/chat_round.tsx index 21e24aa..4a761ff 100644 --- a/ui/app/session/communication/chat_round.tsx +++ b/ui/app/session/communication/chat_round.tsx @@ -1,5 +1,6 @@ import { Message } from "@/llm/base"; -import { chatCall, Interpreter } from "@/app/session/communication/api_calls"; +import Interpreter from "@/app/session/communication/interpreter"; +import LLM from "@/app/session/communication/llm"; import { Approver } from "@/app/session/approval/approver"; export type ChatRoundState = @@ -9,27 +10,42 @@ export type ChatRoundState = | "waiting for approval"; export class ChatRound { + private readonly llm: LLM; + constructor( - private _history: Message[], - private readonly _setHistory: (message: Message[]) => void, - private readonly _approverIn: Approver, - private readonly _approverOut: Approver, - private readonly _interpreter: Interpreter, - private readonly _setState: (state: ChatRoundState) => void, - ) {} + private history: Message[], + private readonly setHistory: (message: Message[]) => void, + private readonly approverIn: Approver, + private readonly approverOut: Approver, + private readonly interpreter: Interpreter, + private readonly setState: (state: ChatRoundState) => void, + private readonly setCodeResult: (result: string) => void, + llmUrl: string, + ) { + this.llm = new LLM(llmUrl); + } private extendHistory(message: Message) { - const newHistory = [...this._history, message]; - this._setHistory(newHistory); - this._history = newHistory; + const newHistory = [...this.history, message]; + this.setHistory(newHistory); + this.history = newHistory; + } + + private modifyHistory(message: Message) { + const newHistory = [...this.history.slice(0, -1), message]; + this.setHistory(newHistory); + this.history = newHistory; } private sendMessage = async (message: Message): Promise => { this.extendHistory(message); - this._setState("waiting for model"); - const response = await chatCall(this._history); + this.setState("waiting for model"); + const response: Message = { role: "model" }; this.extendHistory(response); - return response; + await this.llm.chatCall(this.history, (response) => { + this.modifyHistory(response); + }); + return this.history[this.history.length - 1]; }; run = async (message: string) => { @@ -43,7 +59,7 @@ export class ChatRound { await this.approveOut(result); response = await this.sendResult(result); } else { - this._setState("not active"); + this.setState("not active"); break; } } @@ -53,22 +69,25 @@ export class ChatRound { }; private approveIn = async (code: string) => { - this._setState("waiting for approval"); - await this._approverIn.getApproval(code); + this.setState("waiting for approval"); + await this.approverIn.getApproval(); }; private executeCode = async (code: string): Promise => { - this._setState("waiting for interpreter"); - return await this._interpreter.run(code); + this.setState("waiting for interpreter"); + return await this.interpreter.run(code); }; private approveOut = async (result: string) => { - this._setState("waiting for approval"); - const tmpAutoApprove = result === ""; - const resultText = tmpAutoApprove + this.setState("waiting for approval"); + const emptyAutoApprove = result === ""; + const resultText = emptyAutoApprove ? "(empty output was automatically approved)" : result; - await this._approverOut.getApproval(resultText, tmpAutoApprove); + this.setCodeResult(resultText); + if (!emptyAutoApprove) { + await this.approverOut.getApproval(); + } }; private sendResult = async (result: string) => { diff --git a/ui/app/session/communication/api_calls.tsx b/ui/app/session/communication/interpreter.tsx similarity index 71% rename from ui/app/session/communication/api_calls.tsx rename to ui/app/session/communication/interpreter.tsx index 9a6d892..ef29991 100644 --- a/ui/app/session/communication/api_calls.tsx +++ b/ui/app/session/communication/interpreter.tsx @@ -1,22 +1,4 @@ -import axios, { AxiosError } from "axios"; -import { Message } from "@/llm/base"; - -export async function chatCall(messages: Message[]): Promise { - try { - const response = await axios.post("/api/chat", messages); - return response.data; - } catch (e) { - if (e instanceof AxiosError) { - const msg = e.response?.data; - if (msg !== undefined && msg !== "") { - throw new Error(msg); - } - } - throw e; - } -} - -export class Interpreter { +export default class Interpreter { private ws: WebSocket | null = null; constructor(private readonly interpreterUrl: string) {} @@ -31,7 +13,7 @@ export class Interpreter { reject(Error(event.data)); } }; - this.ws!.onerror = (event) => { + this.ws!.onerror = () => { reject(Error("Could not connect to interpreter")); }; }); diff --git a/ui/app/session/communication/llm.tsx b/ui/app/session/communication/llm.tsx new file mode 100644 index 0000000..3433dab --- /dev/null +++ b/ui/app/session/communication/llm.tsx @@ -0,0 +1,41 @@ +import { Message } from "@/llm/base"; + +export default class LLM { + constructor(private readonly llmUrl: string) {} + + private connect(): Promise { + return new Promise((resolve, reject) => { + const ws = new WebSocket(`ws://${this.llmUrl}/chat`); + ws.onopen = () => { + resolve(ws); + }; + ws.onerror = () => { + reject(Error("Could not connect to LLM")); + }; + }); + } + + private waitClose(ws: WebSocket): Promise { + return new Promise((resolve, reject) => { + ws.onclose = () => { + resolve(); + }; + }); + } + + async chatCall( + messages: Message[], + onResponse: (response: Message) => void, + ): Promise { + const ws = await this.connect(); + + ws.send(JSON.stringify(messages)); + ws.onmessage = (event) => { + const message = JSON.parse(event.data) as Message; + message.role = "model"; + onResponse(message); + }; + + await this.waitClose(ws); + } +} diff --git a/ui/app/session/session.tsx b/ui/app/session/session.tsx index 676646f..1c5d7d6 100644 --- a/ui/app/session/session.tsx +++ b/ui/app/session/session.tsx @@ -3,7 +3,7 @@ import { Message } from "@/llm/base"; import ChatInput from "@/app/session/chat/chat_input"; import ChatHistory from "@/app/session/chat/chat_history"; import InterpreterIO from "@/app/session/approval/interpreter_io"; -import { Interpreter } from "@/app/session/communication/api_calls"; +import Interpreter from "@/app/session/communication/interpreter"; import { useApprover } from "@/app/session/approval/approver"; import { ChatRound, @@ -14,10 +14,12 @@ import Brand from "@/app/session/chat/brand"; export default function Session({ interpreterUrl, + llmUrl, refreshSession, version, }: { interpreterUrl: string; + llmUrl: string; refreshSession: () => void; version: string; }) { @@ -27,20 +29,22 @@ export default function Session({ const [chatRoundState, setChatRoundState] = React.useState("not active"); - const [approverIn, code, askApproveIn, autoApproveIn] = useApprover(); - const [approverOut, result, askApproveOut, autoApproveOut] = useApprover(); + const [approverIn, askApproveIn, autoApproveIn] = useApprover(); + const [approverOut, askApproveOut, autoApproveOut] = useApprover(); + const [codeResult, setCodeResult] = React.useState(null); const chatInputRef = React.useRef(null); const interpreterRef = React.useRef(null); if (interpreterRef.current === null) { interpreterRef.current = new Interpreter(interpreterUrl); } + const code = history.findLast((msg) => msg.code !== undefined)?.code ?? null; React.useEffect(() => { - if (chatRoundState === "waiting for approval") { + if (code !== null) { setShowIO(true); } - }, [chatRoundState]); + }, [code]); const focusChatInput = () => { setTimeout(() => chatInputRef.current && chatInputRef.current.focus(), 100); @@ -55,6 +59,8 @@ export default function Session({ approverOut, interpreterRef.current!, setChatRoundState, + setCodeResult, + llmUrl, ); chatRound .run(message) @@ -116,7 +122,7 @@ export default function Session({
); From 95288e075c443da9ebd5ba8cf10d04af6fdfabf2 Mon Sep 17 00:00:00 2001 From: Silvan Melchior Date: Thu, 10 Aug 2023 18:21:47 +0200 Subject: [PATCH 02/23] backend service skeletons --- services/llm/__init__.py | 4 + services/llm/base.py | 14 +++ services/llm/gpt_openai.py | 28 +++++ services/llm/selector.py | 9 ++ services/llm/types.py | 15 +++ services/main_interpreter.py | 54 +++++----- services/main_llm.py | 89 ++++++--------- services/poetry.lock | 114 ++++++++++++++++++-- services/pyproject.toml | 2 + services/utils/__init__.py | 2 + services/utils/app.py | 14 +++ services/utils/env_var.py | 13 +++ ui/app/session/communication/chat_round.tsx | 2 +- ui/app/session/communication/llm.tsx | 39 +++++-- 14 files changed, 295 insertions(+), 104 deletions(-) create mode 100644 services/llm/base.py create mode 100644 services/llm/gpt_openai.py create mode 100644 services/llm/selector.py create mode 100644 services/llm/types.py create mode 100644 services/utils/__init__.py create mode 100644 services/utils/app.py create mode 100644 services/utils/env_var.py diff --git a/services/llm/__init__.py b/services/llm/__init__.py index e69de29..b3f14e2 100644 --- a/services/llm/__init__.py +++ b/services/llm/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseLLM, LLMException +from .types import Message, Response +from .selector import get_llm +from .gpt_openai import GPTOpenAI diff --git a/services/llm/base.py b/services/llm/base.py new file mode 100644 index 0000000..b7fdc01 --- /dev/null +++ b/services/llm/base.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from typing import Generator + +from llm.types import Message, Response + + +class BaseLLM(ABC): + @abstractmethod + def chat(self, history: list[Message]) -> Generator[Response, None, None]: + """Given a chat history, return a generator which streams the response.""" + + +class LLMException(Exception): + """If an error occurs in the LLM, raise this exception, will be shown in UI.""" diff --git a/services/llm/gpt_openai.py b/services/llm/gpt_openai.py new file mode 100644 index 0000000..1f6c697 --- /dev/null +++ b/services/llm/gpt_openai.py @@ -0,0 +1,28 @@ +import time + +import openai + +from llm.base import BaseLLM, LLMException +from llm.types import Response +from utils import get_env_var + + +class GPTOpenAI(BaseLLM): + def __init__(self, model_name: str): + self._model_name = model_name + openai.api_key = get_env_var("OPENAI_API_KEY") + + def chat(self, history): + print(history, type(history)) + + text = "I will print hello world" + for i in range(len(text)): + yield Response(text=text[: i + 1]) + time.sleep(0.1) + + # raise LLMException("This is an error") + + code = "print('hello world')" + for i in range(len(code)): + yield Response(text=text, code=code[: i + 1]) + time.sleep(0.1) diff --git a/services/llm/selector.py b/services/llm/selector.py new file mode 100644 index 0000000..c65bd87 --- /dev/null +++ b/services/llm/selector.py @@ -0,0 +1,9 @@ +from llm.gpt_openai import GPTOpenAI +from llm.base import BaseLLM + + +def get_llm(llm_setting: str) -> BaseLLM: + if llm_setting.startswith("gpt-openai:"): + return GPTOpenAI(llm_setting[11:]) + + raise ValueError(f"Unknown LLM setting: {llm_setting}") diff --git a/services/llm/types.py b/services/llm/types.py new file mode 100644 index 0000000..84dee3e --- /dev/null +++ b/services/llm/types.py @@ -0,0 +1,15 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + + +class Message(BaseModel): + role: Literal["user", "model", "interpreter"] + text: Optional[str] = None + code: Optional[str] = None + code_result: Optional[str] = None + + +class Response(BaseModel): + text: Optional[str] = None + code: Optional[str] = None diff --git a/services/main_interpreter.py b/services/main_interpreter.py index d8f3932..35683b2 100644 --- a/services/main_interpreter.py +++ b/services/main_interpreter.py @@ -1,34 +1,18 @@ -import os -import sys from pathlib import Path -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.middleware.cors import CORSMiddleware +from fastapi import WebSocket, WebSocketDisconnect +from websockets.exceptions import ConnectionClosedError from interpreter import IPythonInterpreter +from utils import get_app, get_env_var -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -TIMEOUT = ( - int(os.environ["INTERPRETER_TIMEOUT"]) - if "INTERPRETER_TIMEOUT" in os.environ - else 30 -) -TIMEOUT_MESSAGE = "ERROR: TIMEOUT REACHED" +app = get_app() -try: - WORKING_DIRECTORY = Path(os.environ["WORKING_DIRECTORY"]) - IPYTHON_PATH = Path(os.environ["IPYTHON_PATH"]) -except KeyError: - print("ERROR: Missing environment variables, exiting...", file=sys.stderr) - sys.exit(1) +WORKING_DIRECTORY = Path(get_env_var("WORKING_DIRECTORY")) +IPYTHON_PATH = Path(get_env_var("IPYTHON_PATH")) +TIMEOUT = int(get_env_var("INTERPRETER_TIMEOUT", "30")) +TIMEOUT_MESSAGE = "ERROR: TIMEOUT REACHED" def get_interpreter() -> IPythonInterpreter: @@ -43,13 +27,27 @@ def get_interpreter() -> IPythonInterpreter: @app.websocket("/run") async def run(websocket: WebSocket): - await websocket.accept() + ws_exceptions = WebSocketDisconnect, ConnectionClosedError + + try: + await websocket.accept() + except ws_exceptions: + return + try: interpreter = get_interpreter() except Exception as e: - await websocket.send_text(str(e)) + try: + await websocket.send_text(str(e)) + except ws_exceptions: + return + return + + try: + await websocket.send_text("_ready_") + except ws_exceptions: + interpreter.stop() return - await websocket.send_text("_ready_") try: while True: @@ -62,7 +60,7 @@ async def run(websocket: WebSocket): except Exception as e: response = f"_error_ {e}" await websocket.send_text(response) - except WebSocketDisconnect: + except ws_exceptions: pass interpreter.stop() diff --git a/services/main_llm.py b/services/main_llm.py index c3e0551..73758d3 100644 --- a/services/main_llm.py +++ b/services/main_llm.py @@ -1,68 +1,47 @@ -import os -import sys -import time -from typing import Literal, Optional +from fastapi import WebSocket +from fastapi.websockets import WebSocketDisconnect +from pydantic import TypeAdapter +from websockets.exceptions import ConnectionClosedError -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, TypeAdapter +from llm import LLMException, Message, get_llm +from utils import get_app, get_env_var -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -try: - LLM_SETTING = os.environ["LLM"] -except KeyError: - print("ERROR: Missing environment variables, exiting...", file=sys.stderr) - sys.exit(1) - - -class Message(BaseModel): - role: Literal["user", "model", "interpreter"] - text: Optional[str] = None - code: Optional[str] = None - code_result: Optional[str] = None +app = get_app() +LLM_SETTING = get_env_var("LLM", "gpt-openai:gpt-4") +llm = get_llm(LLM_SETTING) Request = TypeAdapter(list[Message]) -class Response(BaseModel): - text: Optional[str] = None - code: Optional[str] = None - - @app.websocket("/chat") async def chat(websocket: WebSocket): - await websocket.accept() + ws_exceptions = WebSocketDisconnect, ConnectionClosedError - history = await websocket.receive_text() - history = Request.validate_json(history) - print(history, type(history)) - - text = "I will print hello world" - for i in range(len(text)): - response = Response(text=text[: i + 1]) - await websocket.send_text(response.model_dump_json(exclude_none=True)) - time.sleep(0.1) + try: + await websocket.accept() + history = await websocket.receive_text() + except ws_exceptions: + return - code = "print('hello world')" - for i in range(len(code)): - response = Response(text=text, code=code[: i + 1]) - await websocket.send_text(response.model_dump_json(exclude_none=True)) - time.sleep(0.1) - - await websocket.close() + history = Request.validate_json(history) - # TODO: handle disconnect (stop generator) - # try: - # ... - # except WebSocketDisconnect: - # pass - # TODO: error handling as in main_interpreter.py + try: + response_generator = llm.chat(history) + try: + for response in response_generator: + msg = "_success_ " + response.model_dump_json(exclude_none=True) + await websocket.send_text(msg) + await websocket.close() + + except ws_exceptions: + response_generator.close() + return + + except LLMException as e: + try: + await websocket.send_text("_error_ " + str(e)) + await websocket.close() + except ws_exceptions: + return diff --git a/services/poetry.lock b/services/poetry.lock index dd174ed..747100d 100644 --- a/services/poetry.lock +++ b/services/poetry.lock @@ -2,7 +2,7 @@ name = "aiohttp" version = "3.8.5" description = "Async http client/server framework (asyncio)" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -22,7 +22,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -81,7 +81,7 @@ test = ["astroid", "pytest"] name = "async-timeout" version = "4.0.2" description = "Timeout context manager for asyncio programs" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -89,7 +89,7 @@ python-versions = ">=3.6" name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -132,11 +132,19 @@ d = ["aiohttp (>=3.7.4)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "certifi" +version = "2023.7.22" +description = "Python package for providing Mozilla's CA Bundle." +category = "main" +optional = false +python-versions = ">=3.6" + [[package]] name = "charset-normalizer" version = "3.2.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "dev" +category = "main" optional = false python-versions = ">=3.7.0" @@ -209,7 +217,7 @@ all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)" name = "frozenlist" version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" -category = "dev" +category = "main" optional = false python-versions = ">=3.8" @@ -304,7 +312,7 @@ traitlets = "*" name = "multidict" version = "6.0.4" description = "multidict implementation" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -316,6 +324,25 @@ category = "dev" optional = false python-versions = ">=3.5" +[[package]] +name = "openai" +version = "0.27.8" +description = "Python client library for the OpenAI API" +category = "main" +optional = false +python-versions = ">=3.7.1" + +[package.dependencies] +aiohttp = "*" +requests = ">=2.20" +tqdm = "*" + +[package.extras] +datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +dev = ["black (>=21.6b0,<22.0)", "pytest (>=6.0.0,<7.0.0)", "pytest-asyncio", "pytest-mock"] +embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] +wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] + [[package]] name = "packaging" version = "23.1" @@ -474,6 +501,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "requests" +version = "2.31.0" +description = "Python HTTP for Humans." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +certifi = ">=2017.4.17" +charset-normalizer = ">=2,<4" +idna = ">=2.5,<4" +urllib3 = ">=1.21.1,<3" + +[package.extras] +socks = ["PySocks (>=1.5.6,!=1.5.7)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] + [[package]] name = "six" version = "1.16.0" @@ -529,6 +574,23 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "tqdm" +version = "4.66.1" +description = "Fast, Extensible Progress Meter" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "traitlets" version = "5.9.0" @@ -549,6 +611,20 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "urllib3" +version = "2.0.4" +description = "HTTP library with thread-safe connection pooling, file post, and more." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.extras] +brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "uvicorn" version = "0.23.1" @@ -585,7 +661,7 @@ python-versions = ">=3.7" name = "yarl" version = "1.9.2" description = "Yet another URL library" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -596,7 +672,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = ">=3.9" -content-hash = "d0c9a9ea8374b4d8ca3d1b96220b626d496ddbc173f29f613527eeaf36e3995b" +content-hash = "8b755d341aae82835ddcf1481f8338866b2f2228f786e0c84d8d5d2d48c82c1e" [metadata.files] aiohttp = [ @@ -744,6 +820,10 @@ black = [ {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, ] +certifi = [ + {file = "certifi-2023.7.22-py3-none-any.whl", hash = "sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9"}, + {file = "certifi-2023.7.22.tar.gz", hash = "sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082"}, +] charset-normalizer = [ {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, @@ -1012,6 +1092,10 @@ mypy-extensions = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +openai = [ + {file = "openai-0.27.8-py3-none-any.whl", hash = "sha256:e0a7c2f7da26bdbe5354b03c6d4b82a2f34bd4458c7a17ae1a7092c3e397e03c"}, + {file = "openai-0.27.8.tar.gz", hash = "sha256:2483095c7db1eee274cebac79e315a986c4e55207bb4fa7b82d185b3a2ed9536"}, +] packaging = [ {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, @@ -1167,6 +1251,10 @@ pytest = [ {file = "pytest-7.4.0-py3-none-any.whl", hash = "sha256:78bf16451a2eb8c7a2ea98e32dc119fd2aa758f1d5d66dbf0a59d69a3969df32"}, {file = "pytest-7.4.0.tar.gz", hash = "sha256:b4bf8c45bd59934ed84001ad51e11b4ee40d40a1229d2c79f9c592b0a3f6bd8a"}, ] +requests = [ + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, +] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -1187,6 +1275,10 @@ tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +tqdm = [ + {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, + {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, +] traitlets = [ {file = "traitlets-5.9.0-py3-none-any.whl", hash = "sha256:9e6ec080259b9a5940c797d58b613b5e31441c2257b87c2e795c5228ae80d2d8"}, {file = "traitlets-5.9.0.tar.gz", hash = "sha256:f6cde21a9c68cf756af02035f72d5a723bf607e862e7be33ece505abf4a3bad9"}, @@ -1195,6 +1287,10 @@ typing-extensions = [ {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, ] +urllib3 = [ + {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, + {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, +] uvicorn = [ {file = "uvicorn-0.23.1-py3-none-any.whl", hash = "sha256:1d55d46b83ee4ce82b4e82f621f2050adb3eb7b5481c13f9af1744951cae2f1f"}, {file = "uvicorn-0.23.1.tar.gz", hash = "sha256:da9b0c8443b2d7ee9db00a345f1eee6db7317432c9d4400f5049cc8d358383be"}, diff --git a/services/pyproject.toml b/services/pyproject.toml index 9e8fb3b..3a7c60d 100644 --- a/services/pyproject.toml +++ b/services/pyproject.toml @@ -7,6 +7,7 @@ license = "MIT" packages = [ {include = "interpreter"}, {include = "llm"}, + {include = "utils"}, ] [tool.poetry.dependencies] @@ -14,6 +15,7 @@ python = ">=3.9" fastapi = "^0.100.0" uvicorn = "^0.23.0" websockets = "^11.0.3" +openai = "^0.27.8" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" diff --git a/services/utils/__init__.py b/services/utils/__init__.py new file mode 100644 index 0000000..a4594a3 --- /dev/null +++ b/services/utils/__init__.py @@ -0,0 +1,2 @@ +from .app import get_app +from .env_var import get_env_var diff --git a/services/utils/app.py b/services/utils/app.py new file mode 100644 index 0000000..8e6460c --- /dev/null +++ b/services/utils/app.py @@ -0,0 +1,14 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + + +def get_app() -> FastAPI: + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + return app diff --git a/services/utils/env_var.py b/services/utils/env_var.py new file mode 100644 index 0000000..58b94f6 --- /dev/null +++ b/services/utils/env_var.py @@ -0,0 +1,13 @@ +import os +import sys + + +def get_env_var(key: str, default: str = None) -> str: + if key in os.environ: + return os.environ[key] + + if default is not None: + return default + + print(f"ERROR: Missing environment variables {key}, exiting...", file=sys.stderr) + sys.exit(1) diff --git a/ui/app/session/communication/chat_round.tsx b/ui/app/session/communication/chat_round.tsx index 4a761ff..23f3b18 100644 --- a/ui/app/session/communication/chat_round.tsx +++ b/ui/app/session/communication/chat_round.tsx @@ -42,7 +42,7 @@ export class ChatRound { this.setState("waiting for model"); const response: Message = { role: "model" }; this.extendHistory(response); - await this.llm.chatCall(this.history, (response) => { + await this.llm.chatCall(this.history.slice(0, -1), (response) => { this.modifyHistory(response); }); return this.history[this.history.length - 1]; diff --git a/ui/app/session/communication/llm.tsx b/ui/app/session/communication/llm.tsx index 3433dab..94c1915 100644 --- a/ui/app/session/communication/llm.tsx +++ b/ui/app/session/communication/llm.tsx @@ -23,19 +23,36 @@ export default class LLM { }); } - async chatCall( + chatCall( messages: Message[], onResponse: (response: Message) => void, ): Promise { - const ws = await this.connect(); - - ws.send(JSON.stringify(messages)); - ws.onmessage = (event) => { - const message = JSON.parse(event.data) as Message; - message.role = "model"; - onResponse(message); - }; - - await this.waitClose(ws); + return new Promise((resolve, reject) => { + this.connect() + .then((ws) => { + ws.send(JSON.stringify(messages)); + ws.onmessage = (event) => { + if (event.data.startsWith("_success_")) { + const message = JSON.parse(event.data.substring(10)) as Message; + message.role = "model"; + onResponse(message); + } else if (event.data.startsWith("_error_")) { + reject(Error(event.data.substring(8))); + } else { + reject(Error("Invalid response")); + } + }; + this.waitClose(ws) + .then(() => { + resolve(); + }) + .catch((e) => { + reject(e); + }); + }) + .catch((e) => { + reject(e); + }); + }); } } From e36b3342a7ca411ee18769664e4df02d7d692240 Mon Sep 17 00:00:00 2001 From: Silvan Melchior Date: Thu, 10 Aug 2023 19:09:59 +0200 Subject: [PATCH 03/23] gpt model in new python llm service --- services/llm/gpt_openai.py | 127 ++++++++++++++-- services/main_llm.py | 12 +- ui/app/api/chat/route.ts | 31 ---- ui/app/session/chat/chat_history.tsx | 2 +- ui/app/session/chat/header.tsx | 2 +- ui/app/session/communication/chat_round.tsx | 2 +- ui/app/session/communication/llm.tsx | 2 +- .../session/communication/message.tsx} | 3 - ui/app/session/session.tsx | 2 +- ui/llm/gpt.tsx | 142 ------------------ 10 files changed, 127 insertions(+), 198 deletions(-) delete mode 100644 ui/app/api/chat/route.ts rename ui/{llm/base.tsx => app/session/communication/message.tsx} (69%) delete mode 100644 ui/llm/gpt.tsx diff --git a/services/llm/gpt_openai.py b/services/llm/gpt_openai.py index 1f6c697..0c811c7 100644 --- a/services/llm/gpt_openai.py +++ b/services/llm/gpt_openai.py @@ -1,28 +1,129 @@ -import time +import re +import json +from typing import Generator import openai +from openai import OpenAIError from llm.base import BaseLLM, LLMException -from llm.types import Response +from llm.types import Message, Response from utils import get_env_var +FUNCTIONS = [ + { + "name": "run_python_code", + "description": "Runs arbitrary Python code and returns stdout and stderr. " + + "The code is executed in an interactive shell, imports and variables are preserved between calls. " + + "The environment has internet and file system access. " + + "The current working directory is shared with the user, so files can be exchanged. " + + "There are many libraries pre-installed, including numpy, pandas, matplotlib, and scikit-learn. " + + "You cannot show rich outputs like plots or images, but you can store them in the working directory and point the user to them. " + + "If the code runs too long, there will be a timeout.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The Python code to run", + }, + }, + "required": ["code"], + }, + }, +] + + +def msg_to_gpt_msg(msg: Message) -> dict: + if msg.role == "user": + return {"role": "user", "content": msg.text} + if msg.role == "model": + response = { + "role": "assistant", + "content": msg.text or None, + } + if msg.code: + response["function_call"] = { + "name": "run_python_code", + "arguments": json.dumps({"code": msg.code}), + } + return response + if msg.role == "interpreter": + return { + "role": "function", + "name": "run_python_code", + "content": msg.code_result, + } + raise ValueError(f"Invalid message role {msg.role}") + + +def fill_dict(dst: dict, chunk: dict): + for key in chunk: + if chunk[key] is None: + dst[key] = None + elif isinstance(chunk[key], dict): + if key not in dst: + dst[key] = {} + fill_dict(dst[key], chunk[key]) + elif isinstance(chunk[key], str): + if key not in dst: + dst[key] = "" + dst[key] += chunk[key] + else: + raise ValueError(f"Unsupported type {type(chunk[key])}") + + +def lazy_parse_args(args_partial): + args = args_partial + if not re.sub(r"\s+", "", args).endswith('"}'): + args += '"}' + + try: + args = json.loads(args) + if "code" not in args: + return None + except json.JSONDecodeError: + return None + + return args["code"] + + class GPTOpenAI(BaseLLM): def __init__(self, model_name: str): self._model_name = model_name openai.api_key = get_env_var("OPENAI_API_KEY") - def chat(self, history): - print(history, type(history)) + def chat(self, history: list[Message]) -> Generator[Response, None, None]: + messages = [msg_to_gpt_msg(msg) for msg in history] + + try: + chunk_generator = openai.ChatCompletion.create( + model=self._model_name, + messages=messages, + temperature=0, + functions=FUNCTIONS, + function_call="auto", + stream=True, + ) + except OpenAIError as e: + raise LLMException(str(e)) + + response = {} + previous_code = None + for chunk_all in chunk_generator: + chunk = chunk_all["choices"][0]["delta"] + fill_dict(response, chunk) - text = "I will print hello world" - for i in range(len(text)): - yield Response(text=text[: i + 1]) - time.sleep(0.1) + text = None + if "content" in response: + text = response["content"] - # raise LLMException("This is an error") + code = None + if "function_call" in response and "arguments" in response["function_call"]: + args = response["function_call"]["arguments"] + code = lazy_parse_args(args) + if code is None: + code = previous_code + previous_code = code - code = "print('hello world')" - for i in range(len(code)): - yield Response(text=text, code=code[: i + 1]) - time.sleep(0.1) + yield Response(text=text, code=code) diff --git a/services/main_llm.py b/services/main_llm.py index 73758d3..132697a 100644 --- a/services/main_llm.py +++ b/services/main_llm.py @@ -25,9 +25,8 @@ async def chat(websocket: WebSocket): except ws_exceptions: return - history = Request.validate_json(history) - try: + history = Request.validate_json(history) response_generator = llm.chat(history) try: for response in response_generator: @@ -39,9 +38,14 @@ async def chat(websocket: WebSocket): response_generator.close() return - except LLMException as e: + except Exception as e: try: - await websocket.send_text("_error_ " + str(e)) + if isinstance(e, LLMException): + error = str(e) + else: + print(e, type(e)) + error = "Internal error" + await websocket.send_text("_error_ " + error) await websocket.close() except ws_exceptions: return diff --git a/ui/app/api/chat/route.ts b/ui/app/api/chat/route.ts deleted file mode 100644 index 9706abf..0000000 --- a/ui/app/api/chat/route.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { NextRequest, NextResponse } from "next/server"; -import { LLMException, Message } from "@/llm/base"; -import { chat as gptChat } from "@/llm/gpt"; - -const DEFAULT_MODEL = "gpt:gpt-4"; - -function getLLM() { - const setting = process.env.LLM ?? DEFAULT_MODEL; - if (setting.startsWith("gpt:")) { - return gptChat(setting.slice(4)); - } - throw new LLMException("Invalid LLM setting"); -} - -export async function POST( - request: NextRequest, -): Promise | Response> { - try { - const history = (await request.json()) as Message[]; - const chat = getLLM(); - const response = await chat(history); - return NextResponse.json(response); - } catch (e) { - if (e instanceof LLMException) { - return new Response(e.message, { status: 500 }); - } - throw e; - } -} - -// TODO: REMOVE ALL OF THIS, ALSO llm FOLDER with base and gpt4 diff --git a/ui/app/session/chat/chat_history.tsx b/ui/app/session/chat/chat_history.tsx index b7d5b58..488ccd8 100644 --- a/ui/app/session/chat/chat_history.tsx +++ b/ui/app/session/chat/chat_history.tsx @@ -1,4 +1,4 @@ -import { Message } from "@/llm/base"; +import { Message } from "@/app/session/communication/message"; import { TbUser } from "react-icons/tb"; import Image from "next/image"; import robotIcon from "../../icon.png"; diff --git a/ui/app/session/chat/header.tsx b/ui/app/session/chat/header.tsx index 502f8de..cb1f9ef 100644 --- a/ui/app/session/chat/header.tsx +++ b/ui/app/session/chat/header.tsx @@ -14,7 +14,7 @@ export function Header({ {error !== null && (
-
Error: {error}
+
Error: {error}
-
+ +
); diff --git a/ui/app/session/communication/chat_round.tsx b/ui/app/session/communication/chat_round.tsx index 8850584..b94b1df 100644 --- a/ui/app/session/communication/chat_round.tsx +++ b/ui/app/session/communication/chat_round.tsx @@ -62,9 +62,15 @@ export class ChatRound { for (; round < 10; round++) { const code = response.code; if (code !== undefined) { - await this.approveIn(code); - const result = await this.executeCode(code); - await this.approveOut(result); + const approvedIn = await this.approveIn(code); + let result = "ERROR: User did not approve code execution!"; + if (approvedIn) { + const resultCode = await this.executeCode(code); + const approvedOut = await this.approveOut(resultCode); + if (approvedOut) { + result = resultCode; + } + } response = await this.sendResult(result); } else { this.setState("not active"); @@ -78,7 +84,7 @@ export class ChatRound { private approveIn = async (code: string) => { this.setState("waiting for approval"); - await this.approverIn.getApproval(); + return await this.approverIn.getApproval(); }; private executeCode = async (code: string): Promise => { @@ -94,8 +100,9 @@ export class ChatRound { : result; this.setCodeResult(resultText); if (!emptyAutoApprove) { - await this.approverOut.getApproval(); + return await this.approverOut.getApproval(); } + return true; }; private sendResult = async (result: string) => { From b534ac25b0ffeea93440c19d6b3d14447766cbf2 Mon Sep 17 00:00:00 2001 From: Silvan Melchior Date: Sun, 13 Aug 2023 15:28:21 +0200 Subject: [PATCH 12/23] new thinking animation --- ui/app/session/chat/chat_history.tsx | 76 ++++++++++++++++----------- ui/app/session/chat/chat_input.tsx | 11 +--- ui/app/session/session.tsx | 10 +++- ui/public/thinking.gif | Bin 5539 -> 9529 bytes 4 files changed, 55 insertions(+), 42 deletions(-) diff --git a/ui/app/session/chat/chat_history.tsx b/ui/app/session/chat/chat_history.tsx index 9f14d13..26fad5a 100644 --- a/ui/app/session/chat/chat_history.tsx +++ b/ui/app/session/chat/chat_history.tsx @@ -2,7 +2,13 @@ import { Message } from "@/app/session/communication/message"; import { TbUser } from "react-icons/tb"; import React from "react"; -export default function ChatHistory({ history }: { history: Message[] }) { +export default function ChatHistory({ + history, + thinking, +}: { + history: Message[]; + thinking: boolean; +}) { const bottomRef = React.useRef(null); React.useEffect(() => { @@ -11,40 +17,48 @@ export default function ChatHistory({ history }: { history: Message[] }) { }, 100); }, [history]); + const historyFiltered = history.filter( + (msg, idx) => + msg.role === "user" || + (msg.role === "model" && + (msg.text !== undefined || (thinking && idx == history.length - 1))), + ); + return (
- {history - .filter( - (msg) => - msg.role === "user" || - (msg.role === "model" && msg.text !== undefined), - ) - .map((msg, idx) => ( -
- {msg.role === "model" ? ( -
- robot -
- ) : ( -
- )} -
- {msg.text} + {historyFiltered.map((msg, idx) => ( +
+ {msg.role === "model" ? ( +
+ robot + {thinking && idx === historyFiltered.length - 1 && ( + thinking + )}
- {msg.role === "user" ? ( -
- -
- ) : ( -
- )} + ) : ( +
+ )} +
+ {msg.text === "" || msg.text === undefined ? "..." : msg.text}
- ))} + {msg.role === "user" ? ( +
+ +
+ ) : ( +
+ )} +
+ ))}
); diff --git a/ui/app/session/chat/chat_input.tsx b/ui/app/session/chat/chat_input.tsx index b539b84..18de2fa 100644 --- a/ui/app/session/chat/chat_input.tsx +++ b/ui/app/session/chat/chat_input.tsx @@ -4,12 +4,10 @@ import { BiSend } from "react-icons/bi"; export default function ChatInput({ innerRef, disabled, - llmAnimation, onMessage, }: { innerRef: React.MutableRefObject; disabled: boolean; - llmAnimation: boolean; onMessage: (message: string) => void; }) { const [message, setMessage] = React.useState(""); @@ -34,8 +32,8 @@ export default function ChatInput({ }; return ( -
-
+
+