Switch LLM config objects from TypedDict to dataclass.
liffiton committed Jul 30, 2024
1 parent 0e51551 commit 8edbb8f
Showing 5 changed files with 60 additions and 58 deletions.
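
In short, the LLM client/model pair passed into view functions is no longer a TypedDict read by subscripting (llm_dict['client']) but a frozen dataclass read by attribute (llm.client). A minimal sketch of the before/after shape, assuming the openai package is installed; the key and model strings are placeholders, not values from this repository:

from dataclasses import dataclass
from typing import TypedDict

from openai import AsyncOpenAI


class LLMDict(TypedDict):
    # Old shape: values read by subscripting; instances are plain dicts and stay mutable.
    client: AsyncOpenAI
    model: str


@dataclass(frozen=True)
class LLMConfig:
    # New shape: values read as attributes; frozen=True makes instances immutable.
    client: AsyncOpenAI
    model: str


old = LLMDict(client=AsyncOpenAI(api_key="placeholder-key"), model="placeholder-model")
new = LLMConfig(client=AsyncOpenAI(api_key="placeholder-key"), model="placeholder-model")

print(old['model'])  # dict subscripting (the pattern being removed)
print(new.model)     # attribute access (the pattern being adopted)

A side effect of frozen=True is that accidental in-place edits to a shared config become a hard error instead of a silent state change.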
36 changes: 18 additions & 18 deletions src/codehelp/helper.py
@@ -25,7 +25,7 @@
)
from gened.classes import switch_class
from gened.db import get_db
from gened.openai import LLMDict, get_completion, with_llm
from gened.openai import LLMConfig, get_completion, with_llm
from gened.queries import get_history, get_query
from gened.testing.mocks import mock_async_completion
from werkzeug.wrappers.response import Response
@@ -119,15 +119,15 @@ def help_view(query_id: int) -> str | Response:
return render_template("help_view.html", query=query_row, responses=responses, history=history, topics=topics)


async def run_query_prompts(llm_dict: LLMDict, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]:
async def run_query_prompts(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> tuple[list[dict[str, str]], dict[str, str]]:
''' Run the given query against the coding help system of prompts.
Returns a tuple containing:
1) A list of response objects from the OpenAI completion (to be stored in the database)
2) A dictionary of response text, potentially including keys 'insufficient' and 'main'.
'''
client = llm_dict['client']
model = llm_dict['model']
client = llm.client
model = llm.model

context_str = context.prompt_str() if context is not None else None

@@ -176,10 +176,10 @@ async def run_query_prompts(llm_dict: LLMDict, context: ContextConfig | None, co
return responses, {'insufficient': response_sufficient_txt, 'main': response_txt}


def run_query(llm_dict: LLMDict, context: ContextConfig | None, code: str, error: str, issue: str) -> int:
def run_query(llm: LLMConfig, context: ContextConfig | None, code: str, error: str, issue: str) -> int:
query_id = record_query(context, code, error, issue)

responses, texts = asyncio.run(run_query_prompts(llm_dict, context, code, error, issue))
responses, texts = asyncio.run(run_query_prompts(llm, context, code, error, issue))

record_response(query_id, responses, texts)

@@ -225,7 +225,7 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[
@login_required
@class_enabled_required
@with_llm()
def help_request(llm_dict: LLMDict) -> Response:
def help_request(llm: LLMConfig) -> Response:
if 'context' in request.form:
context = get_context_by_name(request.form['context'])
if context is None:
@@ -239,15 +239,15 @@ def help_request(llm_dict: LLMDict) -> Response:

# TODO: limit length of code/error/issue

query_id = run_query(llm_dict, context, code, error, issue)
query_id = run_query(llm, context, code, error, issue)

return redirect(url_for(".help_view", query_id=query_id))


@bp.route("/load_test", methods=["POST"])
@admin_required
@with_llm(use_system_key=True) # get a populated LLMDict
def load_test(llm_dict: LLMDict) -> Response:
@with_llm(use_system_key=True) # get a populated LLMConfig
def load_test(llm: LLMConfig) -> Response:
# Require that we're logged in as the load_test admin user
auth = get_auth()
if auth['display_name'] != 'load_test':
@@ -263,7 +263,7 @@ def load_test(llm_dict: LLMDict) -> Response:
# simulate a 2 second delay for a network request
mocked.side_effect = mock_async_completion(delay=2.0)

query_id = run_query(llm_dict, context, code, error, issue)
query_id = run_query(llm, context, code, error, issue)

return redirect(url_for(".help_view", query_id=query_id))

@@ -285,8 +285,8 @@ def post_helpful() -> str:
@login_required
@tester_required
@with_llm()
def get_topics_html(llm_dict: LLMDict, query_id: int) -> str:
topics = get_topics(llm_dict, query_id)
def get_topics_html(llm: LLMConfig, query_id: int) -> str:
topics = get_topics(llm, query_id)
if not topics:
return render_template("topics_fragment.html", error=True)
else:
@@ -297,12 +297,12 @@ def get_topics_html(llm_dict: LLMDict, query_id: int) -> str:
@login_required
@tester_required
@with_llm()
def get_topics_raw(llm_dict: LLMDict, query_id: int) -> list[str]:
topics = get_topics(llm_dict, query_id)
def get_topics_raw(llm: LLMConfig, query_id: int) -> list[str]:
topics = get_topics(llm, query_id)
return topics


def get_topics(llm_dict: LLMDict, query_id: int) -> list[str]:
def get_topics(llm: LLMConfig, query_id: int) -> list[str]:
query_row, responses = get_query(query_id)

if not query_row or not responses or 'main' not in responses:
@@ -317,8 +317,8 @@ def get_topics(llm_dict: LLMDict, query_id: int) -> list[str]:
)

response, response_txt = asyncio.run(get_completion(
client=llm_dict['client'],
model=llm_dict['model'],
client=llm.client,
model=llm.model,
messages=messages,
))

24 changes: 12 additions & 12 deletions src/codehelp/tutor.py
@@ -21,7 +21,7 @@
from gened.classes import switch_class
from gened.db import get_db
from gened.experiments import experiment_required
from gened.openai import LLMDict, get_completion, with_llm
from gened.openai import LLMConfig, get_completion, with_llm
from gened.queries import get_query
from openai.types.chat import ChatCompletionMessageParam
from werkzeug.wrappers.response import Response
@@ -92,7 +92,7 @@ def tutor_form(class_id: int | None = None, ctx_name: str | None = None) -> str

@bp.route("/chat/create", methods=["POST"])
@with_llm()
def start_chat(llm_dict: LLMDict) -> Response:
def start_chat(llm: LLMConfig) -> Response:
topic = request.form['topic']

if 'context' in request.form:
@@ -105,14 +105,14 @@ def start_chat(llm_dict: LLMDict) -> Response:

chat_id = create_chat(topic, context)

run_chat_round(llm_dict, chat_id)
run_chat_round(llm, chat_id)

return redirect(url_for("tutor.chat_interface", chat_id=chat_id))


@bp.route("/chat/create_from_query", methods=["POST"])
@with_llm()
def start_chat_from_query(llm_dict: LLMDict) -> Response:
def start_chat_from_query(llm: LLMConfig) -> Response:
topic = request.form['topic']

# build context from the specified query
@@ -123,7 +123,7 @@ def start_chat_from_query(llm_dict: LLMDict) -> Response:

chat_id = create_chat(topic, context)

run_chat_round(llm_dict, chat_id)
run_chat_round(llm, chat_id)

return redirect(url_for("tutor.chat_interface", chat_id=chat_id))

@@ -209,7 +209,7 @@ def get_chat(chat_id: int) -> tuple[list[ChatCompletionMessageParam], str, str,
return chat, topic, context_name, context_string


def get_response(llm_dict: LLMDict, chat: list[ChatCompletionMessageParam]) -> tuple[dict[str, str], str]:
def get_response(llm: LLMConfig, chat: list[ChatCompletionMessageParam]) -> tuple[dict[str, str], str]:
''' Get a new 'assistant' completion for the specified chat.
Parameters:
@@ -221,8 +221,8 @@ def get_response(llm_dict: LLMDict, chat: list[ChatCompletionMessageParam]) -> t
2) The response text.
'''
response, text = asyncio.run(get_completion(
client=llm_dict['client'],
model=llm_dict['model'],
client=llm.client,
model=llm.model,
messages=chat,
))

@@ -238,7 +238,7 @@ def save_chat(chat_id: int, chat: list[ChatCompletionMessageParam]) -> None:
db.commit()


def run_chat_round(llm_dict: LLMDict, chat_id: int, message: str|None = None) -> None:
def run_chat_round(llm: LLMConfig, chat_id: int, message: str|None = None) -> None:
# Get the specified chat
try:
chat, topic, context_name, context_string = get_chat(chat_id)
@@ -262,7 +262,7 @@ def run_chat_round(llm_dict: LLMDict, chat_id: int, message: str|None = None) ->
{'role': 'assistant', 'content': prompts.tutor_monologue},
]

response_obj, response_txt = get_response(llm_dict, expanded_chat)
response_obj, response_txt = get_response(llm, expanded_chat)

# Update the chat w/ the response
chat.append({
@@ -274,14 +274,14 @@ def run_chat_round(llm_dict: LLMDict, chat_id: int, message: str|None = None) ->

@bp.route("/message", methods=["POST"])
@with_llm()
def new_message(llm_dict: LLMDict) -> Response:
def new_message(llm: LLMConfig) -> Response:
chat_id = int(request.form["id"])
new_msg = request.form["message"]

# TODO: limit length

# Run a round of the chat with the given message.
run_chat_round(llm_dict, chat_id, new_msg)
run_chat_round(llm, chat_id, new_msg)

# Send the user back to the now-updated chat view
return redirect(url_for("tutor.chat_interface", chat_id=chat_id))
8 changes: 4 additions & 4 deletions src/gened/class_config.py
@@ -14,7 +14,7 @@

from .auth import get_auth, instructor_required
from .db import get_db
from .openai import LLMDict, get_completion, get_models, with_llm
from .openai import LLMConfig, get_completion, get_models, with_llm
from .tz import date_is_past

bp = Blueprint('class_config', __name__, url_prefix="/instructor/config", template_folder='templates')
@@ -73,10 +73,10 @@ def config_form() -> str:
@bp.route("/test_llm")
@instructor_required
@with_llm()
def test_llm(llm_dict: LLMDict) -> str:
def test_llm(llm: LLMConfig) -> str:
response, response_txt = asyncio.run(get_completion(
client=llm_dict['client'],
model=llm_dict['model'],
client=llm.client,
model=llm.model,
prompt="Please write 'OK'"
))

34 changes: 18 additions & 16 deletions src/gened/openai.py
@@ -3,9 +3,10 @@
# SPDX-License-Identifier: AGPL-3.0-only

from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
from sqlite3 import Row
from typing import ParamSpec, TypedDict, TypeVar
from typing import ParamSpec, TypeVar

import openai
from flask import current_app, flash, render_template
@@ -28,12 +29,13 @@ class NoTokensError(Exception):
pass


class LLMDict(TypedDict):
@dataclass(frozen=True)
class LLMConfig:
client: AsyncOpenAI
model: str


def _get_llm(*, use_system_key: bool) -> LLMDict:
def _get_llm(*, use_system_key: bool) -> LLMConfig:
''' Get model details and an initialized OpenAI client based on the
arguments and the current user and class.
@@ -51,24 +53,24 @@ def _get_llm(*, use_system_key: bool) -> LLMDict:
key is used with GPT-3.5.
Returns:
LLMDict dictionary with an OpenAI client and model name.
LLMConfig with an OpenAI client and model name.
Raises various exceptions in cases where a key and model are not available.
'''
db = get_db()

def make_system_client() -> LLMDict:
def make_system_client() -> LLMConfig:
""" Factory function to initialize a default client (using the system key)
only if/when needed.
"""
# Get the systemwide default model (TODO: better control than just id=2)
model_row = db.execute("SELECT models.model FROM models WHERE models.id=2").fetchone()
system_model = model_row['model']
system_key = current_app.config["OPENAI_API_KEY"]
return {
'client': AsyncOpenAI(api_key=system_key),
'model': system_model,
}
return LLMConfig(
client=AsyncOpenAI(api_key=system_key),
model=system_model,
)

if use_system_key:
return make_system_client()
@@ -102,10 +104,10 @@ def make_system_client() -> LLMDict:
raise NoKeyFoundError

api_key = class_row['openai_key']
return {
'client': AsyncOpenAI(api_key=api_key),
'model': class_row['model'],
}
return LLMConfig(
client=AsyncOpenAI(api_key=api_key),
model=class_row['model'],
)

# Get user data for tokens, auth_provider
user_row = db.execute("""
@@ -139,7 +141,7 @@ def make_system_client() -> LLMDict:
def with_llm(*, use_system_key: bool = False) -> Callable[[Callable[P, R]], Callable[P, str | R]]:
'''Decorate a view function that requires an LLM and API key.
Assigns an 'llm_dict' named argument.
Assigns an 'llm' named argument.
Checks that the current user has access to an LLM and API key (configured
in an LTI consumer or user-created class), then passes the appropriate
@@ -153,7 +155,7 @@ def decorator(f: Callable[P, R]) -> Callable[P, str | R]:
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:
try:
llm_dict = _get_llm(use_system_key=use_system_key)
llm = _get_llm(use_system_key=use_system_key)
except ClassDisabledError:
flash("Error: The current class is archived or disabled. Request cannot be submitted.")
return render_template("error.html")
@@ -164,7 +166,7 @@ def decorated_function(*args: P.args, **kwargs: P.kwargs) -> str | R:
flash("You have used all of your free tokens. If you are using this application in a class, please connect using the link from your class. Otherwise, you can create a class and add an OpenAI API key or contact us if you want to continue using this application.", "warning")
return render_template("error.html")

kwargs['llm_dict'] = llm_dict
kwargs['llm'] = llm
return f(*args, **kwargs)
return decorated_function
return decorator
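
On the consumer side, the main migration point in openai.py above is that the with_llm() decorator now injects a keyword argument named llm (an LLMConfig) rather than llm_dict. A sketch of a decorated view under that contract, patterned on the test_llm view in class_config.py; the blueprint, route name, and prompt here are illustrative only, not part of this commit:

import asyncio

from flask import Blueprint
from gened.openai import LLMConfig, get_completion, with_llm

bp = Blueprint('llm_example', __name__)  # illustrative blueprint, not in the repository


@bp.route("/llm_check")
@with_llm()  # injects kwargs['llm'] as an LLMConfig (previously kwargs['llm_dict'])
def llm_check(llm: LLMConfig) -> str:
    # Attribute access replaces the old llm_dict['client'] / llm_dict['model'] lookups.
    response, response_txt = asyncio.run(get_completion(
        client=llm.client,
        model=llm.model,
        prompt="Please write 'OK'",
    ))
    return response_txt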
16 changes: 8 additions & 8 deletions src/starburst/helper.py
@@ -8,7 +8,7 @@
from flask import Blueprint, redirect, render_template, request, url_for
from gened.auth import class_enabled_required, get_auth, login_required
from gened.db import get_db
from gened.openai import LLMDict, get_completion, with_llm
from gened.openai import LLMConfig, get_completion, with_llm
from gened.queries import get_history, get_query
from werkzeug.wrappers.response import Response

@@ -42,7 +42,7 @@ def help_view(query_id: int) -> str:
return render_template("help_view.html", query=query_row, responses=responses, history=history)


async def run_query_prompts(llm_dict: LLMDict, assignment: str, topics: str) -> tuple[list[dict[str, str]], dict[str, str]]:
async def run_query_prompts(llm: LLMConfig, assignment: str, topics: str) -> tuple[list[dict[str, str]], dict[str, str]]:
''' Run the given query against the coding help system of prompts.
Returns a tuple containing:
@@ -51,8 +51,8 @@ async def run_query_prompts(llm_dict: LLMDict, assignment: str, topics: str) ->
'''
task_main = asyncio.create_task(
get_completion(
client=llm_dict['client'],
model=llm_dict['model'],
client=llm.client,
model=llm.model,
prompt=prompts.make_main_prompt(assignment, topics),
)
)
@@ -70,10 +70,10 @@ async def run_query_prompts(llm_dict: LLMDict, assignment: str, topics: str) ->
return responses, {'main': response_txt}


def run_query(llm_dict: LLMDict, assignment: str, topics: str) -> int:
def run_query(llm: LLMConfig, assignment: str, topics: str) -> int:
query_id = record_query(assignment, topics)

responses, texts = asyncio.run(run_query_prompts(llm_dict, assignment, topics))
responses, texts = asyncio.run(run_query_prompts(llm, assignment, topics))

record_response(query_id, responses, texts)

@@ -110,11 +110,11 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[
@login_required
@class_enabled_required
@with_llm(use_system_key=True)
def help_request(llm_dict: LLMDict) -> Response:
def help_request(llm: LLMConfig) -> Response:
assignment = request.form["assignment"]
topics = request.form["topics"]

query_id = run_query(llm_dict, assignment, topics)
query_id = run_query(llm, assignment, topics)

return redirect(url_for(".help_view", query_id=query_id))

