From 8edbb8f3dd8cdd8eee742359937da81bd99f3162 Mon Sep 17 00:00:00 2001
From: Mark Liffiton
Date: Tue, 30 Jul 2024 11:39:40 -0500
Subject: [PATCH] Switch LLM config objects from TypedDict to dataclass.

---
 src/codehelp/helper.py    | 36 ++++++++++++++++++------------------
 src/codehelp/tutor.py     | 24 ++++++++++++------------
 src/gened/class_config.py |  8 ++++----
 src/gened/openai.py       | 34 ++++++++++++++++++----------------
 src/starburst/helper.py   | 16 ++++++++--------
 5 files changed, 60 insertions(+), 58 deletions(-)

diff --git a/src/codehelp/helper.py b/src/codehelp/helper.py
index 1dda6dc..c942875 100644
--- a/src/codehelp/helper.py
+++ b/src/codehelp/helper.py
@@ -25,7 +25,7 @@
 )
 from gened.classes import switch_class
 from gened.db import get_db
-from gened.openai import LLMDict, get_completion, with_llm
+from gened.openai import LLMConfig, get_completion, with_llm
 from gened.queries import get_history, get_query
 from gened.testing.mocks import mock_async_completion
 from werkzeug.wrappers.response import Response
@@ -119,15 +119,15 @@ def help_view(query_id: int) -> str | Response:
     return render_template("help_view.html", query=query_row, responses=responses, history=history, topics=topics)


-async def run_query_prompts(llm_dict: LLMDict, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]:
+async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]:
     ''' Run the given query against the coding help system of prompts.

     Returns a tuple containing:
       1) A list of response objects from the OpenAI completion (to be stored in the database)
       2) A dictionary of response text, potentially including keys 'insufficient' and 'main'.
     '''
-    client = llm_dict['client']
-    model = llm_dict['model']
+    client = llm.client
+    model = llm.model

     context_str = context.prompt_str() if context is not None else None

@@ -176,10 +176,10 @@ async def run_query_prompts(llm_dict: LLMDict, context: ContextConfig | None, co
     return responses, {'insufficient': response_sufficient_txt, 'main': response_txt}


-def run_query(llm_dict: LLMDict, context: ContextConfig | None, code: str, error: str, issue: str) -> int:
+def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> int:
     query_id = record_query(context, code, error, issue)

-    responses, texts = asyncio.run(run_query_prompts(llm_dict, context, code, error, issue))
+    responses, texts = asyncio.run(run_query_prompts(llm, context, code, error, issue))

     record_response(query_id, responses, texts)

@@ -225,7 +225,7 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[
 @login_required
 @class_enabled_required
 @with_llm()
-def help_request(llm_dict: LLMDict) -> Response:
+def help_request(llm: LLMConfig) -> Response:
     if 'context' in request.form:
         context = get_context_by_name(request.form['context'])
         if context is None:
@@ -239,15 +239,15 @@ def help_request(llm_dict: LLMDict) -> Response:

     # TODO: limit length of code/error/issue

-    query_id = run_query(llm_dict, context, code, error, issue)
+    query_id = run_query(llm, context, code, error, issue)

     return redirect(url_for(".help_view", query_id=query_id))


 @bp.route("/load_test", methods=["POST"])
 @admin_required
-@with_llm(use_system_key=True)  # get a populated LLMDict
-def load_test(llm_dict: LLMDict) -> Response:
+@with_llm(use_system_key=True)  # get a populated LLMConfig
+def load_test(llm: LLMConfig) -> Response:
     # Require that we're logged in as the load_test admin user
     auth = get_auth()
     if auth['display_name'] != 'load_test':
@@ -263,7 +263,7 @@ def load_test(llm_dict: LLMDict) -> Response:
         # simulate a 2 second delay for a network request
         mocked.side_effect = mock_async_completion(delay=2.0)

-        query_id = run_query(llm_dict, context, code, error, issue)
+        query_id = run_query(llm, context, code, error, issue)

     return redirect(url_for(".help_view", query_id=query_id))

@@ -285,8 +285,8 @@ def post_helpful() -> str:
 @login_required
 @tester_required
 @with_llm()
-def get_topics_html(llm_dict: LLMDict, query_id: int) -> str:
-    topics = get_topics(llm_dict, query_id)
+def get_topics_html(llm: LLMConfig, query_id: int) -> str:
+    topics = get_topics(llm, query_id)
     if not topics:
         return render_template("topics_fragment.html", error=True)
     else:
@@ -297,12 +297,12 @@ def get_topics_html(llm_dict: LLMDict, query_id: int) -> str:
 @login_required
 @tester_required
 @with_llm()
-def get_topics_raw(llm_dict: LLMDict, query_id: int) -> list[str]:
-    topics = get_topics(llm_dict, query_id)
+def get_topics_raw(llm: LLMConfig, query_id: int) -> list[str]:
+    topics = get_topics(llm, query_id)
     return topics


-def get_topics(llm_dict: LLMDict, query_id: int) -> list[str]:
+def get_topics(llm: LLMConfig, query_id: int) -> list[str]:
     query_row, responses = get_query(query_id)

     if not query_row or not responses or 'main' not in responses:
@@ -317,8 +317,8 @@ def get_topics(llm_dict: LLMDict, query_id: int) -> list[str]:
     )

     response, response_txt = asyncio.run(get_completion(
-        client=llm_dict['client'],
-        model=llm_dict['model'],
+        client=llm.client,
+        model=llm.model,
         messages=messages,
     ))

diff --git a/src/codehelp/tutor.py b/src/codehelp/tutor.py
index 79d390b..5aaeebe 100644
--- a/src/codehelp/tutor.py
+++ b/src/codehelp/tutor.py
@@ -21,7 +21,7 @@
 from gened.classes import switch_class
 from gened.db import get_db
 from gened.experiments import experiment_required
-from gened.openai import LLMDict, get_completion, with_llm
+from gened.openai import LLMConfig, get_completion, with_llm
 from gened.queries import get_query
 from openai.types.chat import ChatCompletionMessageParam
 from werkzeug.wrappers.response import Response
@@ -92,7 +92,7 @@ def tutor_form(class_id: int | None = None, ctx_name: str | None = None) -> str

 @bp.route("/chat/create", methods=["POST"])
 @with_llm()
-def start_chat(llm_dict: LLMDict) -> Response:
+def start_chat(llm: LLMConfig) -> Response:
     topic = request.form['topic']

     if 'context' in request.form:
@@ -105,14 +105,14 @@ def start_chat(llm_dict: LLMDict) -> Response:

     chat_id = create_chat(topic, context)

-    run_chat_round(llm_dict, chat_id)
+    run_chat_round(llm, chat_id)

     return redirect(url_for("tutor.chat_interface", chat_id=chat_id))


 @bp.route("/chat/create_from_query", methods=["POST"])
 @with_llm()
-def start_chat_from_query(llm_dict: LLMDict) -> Response:
+def start_chat_from_query(llm: LLMConfig) -> Response:
     topic = request.form['topic']

     # build context from the specified query
@@ -123,7 +123,7 @@ def start_chat_from_query(llm_dict: LLMDict) -> Response:

     chat_id = create_chat(topic, context)

-    run_chat_round(llm_dict, chat_id)
+    run_chat_round(llm, chat_id)

     return redirect(url_for("tutor.chat_interface", chat_id=chat_id))

@@ -209,7 +209,7 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str,
     return chat, topic, context_name, context_string


-def get_response(llm_dict: LLMDict, chat: list[ChatCompletionMessageParam]) -> tuple[dict[str, str], str]:
+def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tuple[dict[str, str], str]:
     ''' Get a new 'assistant' completion for the specified chat.

     Parameters:
@@ -221,8 +221,8 @@ def get_response(llm_dict: LLMDict, chat: list[ChatCompletionMessageParam]) -> t
       2) The response text.
     '''
     response, text = asyncio.run(get_completion(
-        client=llm_dict['client'],
-        model=llm_dict['model'],
+        client=llm.client,
+        model=llm.model,
         messages=chat,
     ))

@@ -238,7 +238,7 @@ def save_chat(chat_id: int, chat: list[ChatCompletionMessageParam]) -> None:
     db.commit()


-def run_chat_round(llm_dict: LLMDict, chat_id: int, message: str|None = None) -> None:
+def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> None:
     # Get the specified chat
     try:
         chat, topic, context_name, context_string = get_chat(chat_id)
@@ -262,7 +262,7 @@ def run_chat_round(llm_dict: LLMDict, chat_id: int, message: str|None = None) ->
         {'role': 'assistant', 'content': prompts.tutor_monologue},
     ]

-    response_obj, response_txt = get_response(llm_dict, expanded_chat)
+    response_obj, response_txt = get_response(llm, expanded_chat)

     # Update the chat w/ the response
     chat.append({
@@ -274,14 +274,14 @@ def run_chat_round(llm_dict: LLMDict, chat_id: int, message: str|None = None) ->

 @bp.route("/message", methods=["POST"])
 @with_llm()
-def new_message(llm_dict: LLMDict) -> Response:
+def new_message(llm: LLMConfig) -> Response:
     chat_id = int(request.form["id"])
     new_msg = request.form["message"]

     # TODO: limit length

     # Run a round of the chat with the given message.
-    run_chat_round(llm_dict, chat_id, new_msg)
+    run_chat_round(llm, chat_id, new_msg)

     # Send the user back to the now-updated chat view
     return redirect(url_for("tutor.chat_interface", chat_id=chat_id))
diff --git a/src/gened/class_config.py b/src/gened/class_config.py
index 101b7e1..139d0a7 100644
--- a/src/gened/class_config.py
+++ b/src/gened/class_config.py
@@ -14,7 +14,7 @@

 from .auth import get_auth, instructor_required
 from .db import get_db
-from .openai import LLMDict, get_completion, get_models, with_llm
+from .openai import LLMConfig, get_completion, get_models, with_llm
 from .tz import date_is_past

 bp = Blueprint('class_config', __name__, url_prefix="/instructor/config", template_folder='templates')
@@ -73,10 +73,10 @@ def config_form() -> str:
 @bp.route("/test_llm")
 @instructor_required
 @with_llm()
-def test_llm(llm_dict: LLMDict) -> str:
+def test_llm(llm: LLMConfig) -> str:
     response, response_txt = asyncio.run(get_completion(
-        client=llm_dict['client'],
-        model=llm_dict['model'],
+        client=llm.client,
+        model=llm.model,
         prompt="Please write 'OK'"
     ))

diff --git a/src/gened/openai.py b/src/gened/openai.py
index a8bb28b..fad9a53 100644
--- a/src/gened/openai.py
+++ b/src/gened/openai.py
@@ -3,9 +3,10 @@
 # SPDX-License-Identifier: AGPL-3.0-only

 from collections.abc import Callable
+from dataclasses import dataclass
 from functools import wraps
 from sqlite3 import Row
-from typing import ParamSpec, TypedDict, TypeVar
+from typing import ParamSpec, TypeVar

 import openai
 from flask import current_app, flash, render_template
@@ -28,12 +29,13 @@ class NoTokensError(Exception):
     pass


-class LLMDict(TypedDict):
+@dataclass(frozen=True)
+class LLMConfig:
     client: AsyncOpenAI
     model: str


-def _get_llm(*, use_system_key: bool) -> LLMDict:
+def _get_llm(*, use_system_key: bool) -> LLMConfig:
     ''' Get model details and an initialized OpenAI client based on the arguments
         and the current user and class.

@@ -51,13 +53,13 @@ def _get_llm(*, use_system_key: bool) -> LLMDict:
       key is used with GPT-3.5.

     Returns:
-      LLMDict dictionary with an OpenAI client and model name.
+      LLMConfig with an OpenAI client and model name.

     Raises various exceptions in cases where a key and model are not available.
     '''
     db = get_db()

-    def make_system_client() -> LLMDict:
+    def make_system_client() -> LLMConfig:
         """ Factory function to initialize a default client (using the system key)
             only if/when needed.
         """
@@ -65,10 +67,10 @@ def make_system_client() -> LLMDict:
         model_row = db.execute("SELECT models.model FROM models WHERE models.id=2").fetchone()
         system_model = model_row['model']
         system_key = current_app.config["OPENAI_API_KEY"]
-        return {
-            'client': AsyncOpenAI(api_key=system_key),
-            'model': system_model,
-        }
+        return LLMConfig(
+            client=AsyncOpenAI(api_key=system_key),
+            model=system_model,
+        )

     if use_system_key:
         return make_system_client()
@@ -102,10 +104,10 @@ def make_system_client() -> LLMDict:
             raise NoKeyFoundError

         api_key = class_row['openai_key']
-        return {
-            'client': AsyncOpenAI(api_key=api_key),
-            'model': class_row['model'],
-        }
+        return LLMConfig(
+            client=AsyncOpenAI(api_key=api_key),
+            model=class_row['model'],
+        )

     # Get user data for tokens, auth_provider
     user_row = db.execute("""
@@ -139,7 +141,7 @@ def make_system_client() -> LLMDict:
 def with_llm(*, use_system_key: bool = False) -> Callable[[Callable[P, R]], Callable[P, str | R]]:
     '''Decorate a view function that requires an LLM and API key.

-    Assigns an 'llm_dict' named argument.
+    Assigns an 'llm' named argument.

     Checks that the current user has access to an LLM and API key (configured
     in an LTI consumer or user-created class), then passes the appropriate
@@ -153,7 +155,7 @@ def decorator(f: Callable[P, R]) -> Callable[P, str | R]:
         @wraps(f)
         def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:
             try:
-                llm_dict = _get_llm(use_system_key=use_system_key)
+                llm = _get_llm(use_system_key=use_system_key)
             except ClassDisabledError:
                 flash("Error: The current class is archived or disabled. Request cannot be submitted.")
                 return render_template("error.html")
@@ -164,7 +166,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:
                 flash("You have used all of your free tokens. If you are using this application in a class, please connect using the link from your class. Otherwise, you can create a class and add an OpenAI API key or contact us if you want to continue using this application.", "warning")
                 return render_template("error.html")

-            kwargs['llm_dict'] = llm_dict
+            kwargs['llm'] = llm
             return f(*args, **kwargs)
         return decorated_function
     return decorator
diff --git a/src/starburst/helper.py b/src/starburst/helper.py
index c82921b..f953d52 100644
--- a/src/starburst/helper.py
+++ b/src/starburst/helper.py
@@ -8,7 +8,7 @@
 from flask import Blueprint, redirect, render_template, request, url_for
 from gened.auth import class_enabled_required, get_auth, login_required
 from gened.db import get_db
-from gened.openai import LLMDict, get_completion, with_llm
+from gened.openai import LLMConfig, get_completion, with_llm
 from gened.queries import get_history, get_query
 from werkzeug.wrappers.response import Response

@@ -42,7 +42,7 @@ def help_view(query_id: int) -> str:
     return render_template("help_view.html", query=query_row, responses=responses, history=history)


-async def run_query_prompts(llm_dict: LLMDict, assignment: str, topics: str) -> tuple[list[dict[str, str]], dict[str, str]]:
+async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tuple[list[dict[str, str]], dict[str, str]]:
     ''' Run the given query against the coding help system of prompts.

     Returns a tuple containing:
@@ -51,8 +51,8 @@ async def run_query_prompts(llm_dict: LLMDict, assignment: str, topics: str) ->
     '''
     task_main = asyncio.create_task(
         get_completion(
-            client=llm_dict['client'],
-            model=llm_dict['model'],
+            client=llm.client,
+            model=llm.model,
             prompt=prompts.make_main_prompt(assignment, topics),
         )
     )
@@ -70,10 +70,10 @@ async def run_query_prompts(llm_dict: LLMDict, assignment: str, topics: str) ->
     return responses, {'main': response_txt}


-def run_query(llm_dict: LLMDict, assignment: str, topics: str) -> int:
+def run_query(llm: LLMConfig, assignment: str, topics: str) -> int:
     query_id = record_query(assignment, topics)

-    responses, texts = asyncio.run(run_query_prompts(llm_dict, assignment, topics))
+    responses, texts = asyncio.run(run_query_prompts(llm, assignment, topics))

     record_response(query_id, responses, texts)

@@ -110,11 +110,11 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[
 @login_required
 @class_enabled_required
 @with_llm(use_system_key=True)
-def help_request(llm_dict: LLMDict) -> Response:
+def help_request(llm: LLMConfig) -> Response:
     assignment = request.form["assignment"]
     topics = request.form["topics"]

-    query_id = run_query(llm_dict, assignment, topics)
+    query_id = run_query(llm, assignment, topics)

     return redirect(url_for(".help_view", query_id=query_id))
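Note (not part of the patch): below is a minimal sketch of how a view function consumes the new LLMConfig after this change, assembled from the decorator and get_completion calls shown in the diff above. The blueprint, route name, and view function here are hypothetical placeholders for illustration, not code from the repository.

    import asyncio

    from flask import Blueprint

    from gened.openai import LLMConfig, get_completion, with_llm

    # Hypothetical blueprint, used only to make the example self-contained.
    bp = Blueprint('example', __name__)


    @bp.route("/ping_llm")  # hypothetical route
    @with_llm()  # injects a frozen LLMConfig as the 'llm' keyword argument
    def ping_llm(llm: LLMConfig) -> str:
        # Attribute access (llm.client / llm.model) replaces the old
        # llm_dict['client'] / llm_dict['model'] lookups removed by this patch.
        _response, response_txt = asyncio.run(get_completion(
            client=llm.client,
            model=llm.model,
            prompt="Please write 'OK'",
        ))
        return response_txt

Because LLMConfig is declared frozen, view code receives an immutable client/model pair; accidental reassignment of either field now raises an error instead of silently mutating a shared dict.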