Skip to content

Commit

Permalink
Simplify interface for getting contexts.
Browse files Browse the repository at this point in the history
  • Loading branch information
liffiton committed Jul 20, 2024
1 parent 4595be2 commit b548eb9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 31 deletions.
25 changes: 11 additions & 14 deletions src/codehelp/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)
from gened.classes import switch_class
from gened.contexts import (
ContextNotFoundError,
get_available_contexts,
get_context_by_name,
)
Expand Down Expand Up @@ -55,21 +54,20 @@ def help_form(query_id: int | None = None, class_id: int | None = None, ctx_name
success = switch_class(class_id)
if not success:
# Can't access the specified context
flash(f"Cannot access class and context. Make sure you are logged in correctly before using this link.", "danger")
flash("Cannot access class and context. Make sure you are logged in correctly before using this link.", "danger")
return make_response(render_template("error.html"), 400)

# we may select a context from a given ctx_name, from a given query_id, or from the user's most recently-used context
selected_context_name = None
query_row = None
if ctx_name is not None:
# see if the given context is part of the current class (whether available or not)
try:
context = get_context_by_name(CodeHelpContext, ctx_name)
contexts_list = [context] # this will be the only context in this page -- no other options
selected_context_name = ctx_name
except ContextNotFoundError:
context = get_context_by_name(CodeHelpContext, ctx_name)
if context is None:
flash(f"Context not found: {ctx_name}", "danger")
return make_response(render_template("error.html"), 404)
contexts_list = [context] # this will be the only context in this page -- no other options
selected_context_name = ctx_name
else:
contexts_list = get_available_contexts(CodeHelpContext) # all *available* contexts will be shown
if query_id is not None:
Expand All @@ -86,11 +84,11 @@ def help_form(query_id: int | None = None, class_id: int | None = None, ctx_name

# verify the context is real and part of the current class
if selected_context_name is not None:
try:
context = get_context_by_name(CodeHelpContext, selected_context_name)
contexts_list.append(context) # add this context to the list - may be hidden - if duplicate, dict comprehension will automatically filter
except ContextNotFoundError:
context = get_context_by_name(CodeHelpContext, selected_context_name)
if context is None:
selected_context_name = None
else:
contexts_list.append(context) # add this context to the list - may be hidden - if duplicate, dict comprehension will automatically filter

# turn contexts into format we can pass to js via JSON
contexts = {ctx.name: ctx.desc_html() for ctx in contexts_list}
Expand Down Expand Up @@ -229,9 +227,8 @@ def record_response(query_id: int, responses: list[dict[str, str]], texts: dict[
@with_llm()
def help_request(llm_dict: LLMDict) -> Response:
if 'context' in request.form:
try:
context = get_context_by_name(CodeHelpContext, request.form['context'])
except ContextNotFoundError:
context = get_context_by_name(CodeHelpContext, request.form['context'])
if context is None:
flash(f"Context not found: {request.form['context']}")
return make_response(render_template("error.html"), 400)
else:
Expand Down
26 changes: 17 additions & 9 deletions src/codehelp/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from gened.admin import register_admin_link
from gened.auth import get_auth, login_required
from gened.contexts import (
ContextNotFoundError,
get_available_contexts,
get_context_by_name,
)
Expand Down Expand Up @@ -72,11 +71,14 @@ def tutor_form() -> str:
@with_llm()
def start_chat(llm_dict: LLMDict) -> Response:
topic = request.form['topic']
try:

if 'context' in request.form:
context = get_context_by_name(CodeHelpContext, request.form['context'])
except ContextNotFoundError:
flash(f"Context not found: {request.form['context']}")
return make_response(render_template("error.html"), 400)
if context is None:
flash(f"Context not found: {request.form['context']}")
return make_response(render_template("error.html"), 400)
else:
context = None

chat_id = create_chat(topic, context)

Expand Down Expand Up @@ -116,16 +118,22 @@ def chat_interface(chat_id: int) -> str | Response:
return render_template("tutor_view.html", chat_id=chat_id, topic=topic, context_name=context_name, chat=chat, chat_history=chat_history)


def create_chat(topic: str, context: CodeHelpContext) -> int:
def create_chat(topic: str, context: CodeHelpContext | None) -> int:
db = get_db()
auth = get_auth()
user_id = auth['user_id']
role_id = auth['role_id']
context_string_id = record_context_string(context.prompt_str())

db = get_db()
if context is not None:
context_name = context.name
context_string_id = record_context_string(context.prompt_str())
else:
context_name = None
context_string_id = None

cur = db.execute(
"INSERT INTO chats (user_id, role_id, topic, context_name, context_string_id, chat_json) VALUES (?, ?, ?, ?, ?, ?)",
[user_id, role_id, topic, context.name, context_string_id, json.dumps([])]
[user_id, role_id, topic, context_name, context_string_id, json.dumps([])]
)
new_row_id = cur.lastrowid

Expand Down
18 changes: 10 additions & 8 deletions src/gened/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,10 @@ def get_available_contexts(ctx_class: type[T]) -> list[T]:
return [ctx_class.from_row(row) for row in context_rows]


class ContextNotFoundError(Exception):
pass


def get_context_config_by_id(ctx_class: type[T], ctx_id: int) -> T:
def get_context_config_by_id(ctx_class: type[T], ctx_id: int) -> T | None:
""" Return a context object of the given class based on the specified id
or return None if no context exists with that name.
"""
assert _context_class is not None

db = get_db()
Expand All @@ -324,12 +323,15 @@ def get_context_config_by_id(ctx_class: type[T], ctx_id: int) -> T:
context_row = db.execute("SELECT * FROM contexts WHERE class_id=? AND id=?", [class_id, ctx_id]).fetchone()

if not context_row:
raise ContextNotFoundError
return None

return ctx_class.from_row(context_row)


def get_context_by_name(ctx_class: type[T], ctx_name: str) -> T:
def get_context_by_name(ctx_class: type[T], ctx_name: str) -> T | None:
""" Return a context object of the given class based on the specified name
or return None if no context exists with that name.
"""
assert _context_class is not None

db = get_db()
Expand All @@ -340,6 +342,6 @@ def get_context_by_name(ctx_class: type[T], ctx_name: str) -> T:
context_row = db.execute("SELECT * FROM contexts WHERE class_id=? AND name=?", [class_id, ctx_name]).fetchone()

if not context_row:
raise ContextNotFoundError
return None

return ctx_class.from_row(context_row)

0 comments on commit b548eb9

Please sign in to comment.