Skip to content

Commit

Permalink
add router at /courses/[canvas_id]/chat that can start new chats and …
Browse files Browse the repository at this point in the history
…accept new messages from students
  • Loading branch information
nattvara committed Mar 16, 2024
1 parent 20f97cd commit 688edef
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 1 deletion.
4 changes: 4 additions & 0 deletions db/actions/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

def find_chat_by_id(chat_id: str):
from db.models.chat import Chat
return Chat.select().filter(Chat.public_id == chat_id).first()
4 changes: 4 additions & 0 deletions db/actions/course.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

def find_course_by_canvas_id(canvas_id: str):
from db.models.course import Course
return Course.select().filter(Course.canvas_id == canvas_id).first()
4 changes: 4 additions & 0 deletions db/actions/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

def all_messages_in_chat(chat_id: int):
from db.models.message import Message
return Message.select().where(Message.chat == chat_id)
2 changes: 2 additions & 0 deletions http_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
websocket,
sessions,
index,
chat,
)


Expand All @@ -22,6 +23,7 @@ def get_app():
app.include_router(index.router)
app.include_router(websocket.router)
app.include_router(sessions.router)
app.include_router(chat.router)

return app

Expand Down
105 changes: 105 additions & 0 deletions http_api/routers/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import List

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, constr

from db.actions.course import find_course_by_canvas_id
from db.actions.message import all_messages_in_chat
from http_api.auth import get_current_session
from db.actions.chat import find_chat_by_id
from db.models import Chat, Session, Message


router = APIRouter()


class ChatResponse(BaseModel):
public_id: str


class MessageResponse(BaseModel):
message_id: str
content: str
sender: str
created_at: str


class MessagesResponse(BaseModel):
messages: List[MessageResponse]


class MessageRequestBody(BaseModel):
content: constr(min_length=1, max_length=2048)


@router.post('/course/{course_canvas_id}/chat', dependencies=[Depends(get_current_session)], response_model=ChatResponse)
async def start_new_chat(course_canvas_id: str, session: Session = Depends(get_current_session)) -> ChatResponse:
course = find_course_by_canvas_id(course_canvas_id)
if course is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Course not found")

chat = Chat(course=course, session=session)
chat.save()

return ChatResponse(public_id=chat.public_id)


@router.post(
'/course/{course_canvas_id}/chat/{chat_id}/messages',
dependencies=[Depends(get_current_session)],
status_code=status.HTTP_201_CREATED,
response_model=MessageResponse
)
async def send_message(
course_canvas_id: str,
chat_id: str,
body: MessageRequestBody,
) -> MessageResponse:
course = find_course_by_canvas_id(course_canvas_id)
if course is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Course not found")

chat = find_chat_by_id(chat_id)
if chat is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")

msg = Message(chat=chat, content=body.content, sender=Message.Sender.STUDENT)
msg.save()

return MessageResponse(
message_id=msg.message_id,
content=msg.content,
sender=msg.sender,
created_at=str(msg.created_at),
)


@router.get(
'/course/{course_canvas_id}/chat/{chat_id}/messages',
dependencies=[Depends(get_current_session)],
status_code=status.HTTP_200_OK,
response_model=MessagesResponse
)
async def get_messages(
course_canvas_id: str,
chat_id: str,
) -> MessagesResponse:
course = find_course_by_canvas_id(course_canvas_id)
if course is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Course not found")

chat = find_chat_by_id(chat_id)
if chat is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")

messages = all_messages_in_chat(chat.id)
out = []
for msg in messages:
out.append(MessageResponse(
message_id=msg.message_id,
content=msg.content,
sender=msg.sender,
created_at=str(msg.created_at),
))

return MessagesResponse(messages=out)
59 changes: 59 additions & 0 deletions tests/api/chat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from db.models import Course, Chat, Message


def test_chats_are_tied_to_course_room(api_client, authenticated_session, valid_course):
response = api_client.post(f'/course/{valid_course.canvas_id}/chat', headers=authenticated_session.headers)
chat = Chat.filter(Chat.public_id == response.json()['public_id']).first()

assert response.status_code == 200
assert chat.course.canvas_id == valid_course.canvas_id


def test_chats_are_tied_to_session(api_client, authenticated_session, valid_course):
response = api_client.post(f'/course/{valid_course.canvas_id}/chat', headers=authenticated_session.headers)
chat = Chat.filter(Chat.public_id == response.json()['public_id']).first()

assert response.status_code == 200
assert chat.session.public_id == authenticated_session.session.public_id


def test_start_chat_fails_with_invalid_course_id(api_client, authenticated_session):
invalid_course = 'something_bogus'

response = api_client.post(f'/course/{invalid_course}/chat', headers=authenticated_session.headers)

assert response.status_code == 404


def test_start_chat_fails_with_invalid_session(api_client, authenticated_session, valid_course):
authenticated_session.session.is_valid = False
authenticated_session.session.save()

response = api_client.post(f'/course/{valid_course.canvas_id}/chat', headers=authenticated_session.headers)

assert response.status_code == 401


def test_user_can_send_message_to_chat(api_client, authenticated_session, new_chat):
url = f'/course/{new_chat.course.canvas_id}/chat/{new_chat.chat.public_id}/messages'
response = api_client.post(url, json={'content': 'foo'}, headers=authenticated_session.headers)

new_chat.chat.refresh()
assert response.status_code == 201
assert new_chat.chat.messages[0].content == 'foo'
assert new_chat.chat.messages[0].sender == Message.Sender.STUDENT


def test_user_get_messages_in_chat(api_client, authenticated_session, new_chat):
new_chat.add_some_messages()

url = f'/course/{new_chat.course.canvas_id}/chat/{new_chat.chat.public_id}/messages'
response = api_client.get(url, headers=authenticated_session.headers)

messages = response.json()['messages']

assert response.status_code == 200
assert len(messages) == len(new_chat.chat.messages)
for idx, message in enumerate(messages):
assert new_chat.chat.messages[idx].content == message['content']
assert new_chat.chat.messages[idx].sender == message['sender']
46 changes: 45 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from websockets import WebSocketClientProtocol
from fastapi.testclient import TestClient
from numpy.random import rand, randint
import pytest

from services.llm.supported_models import LLMModel
Expand All @@ -13,8 +14,8 @@
# in the main project when running inside pytest
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from db.models import all_models, Session, Course, Chat, Message # noqa
from config.settings import Settings # noqa
from db.models import all_models # noqa
from db.connection import db # noqa
import http_api # noqa

Expand Down Expand Up @@ -107,3 +108,46 @@ def llm_model_name():
@pytest.fixture
def llm_prompt():
return "tell me a fact"


@pytest.fixture
def authenticated_session():
class AuthenticatedSession:
def __init__(self, session: Session):
self.session = session
self.headers = {'X-Session-ID': valid_session.public_id}

valid_session = Session()
valid_session.save()

return AuthenticatedSession(valid_session)


@pytest.fixture
def valid_course():
course = Course(canvas_id="41428")
course.save()
return course


@pytest.fixture
def new_chat(authenticated_session, valid_course):
class NewChat:
def __init__(self, chat: Chat, course: Course):
self.course = course
self.chat = chat

def add_some_messages(self):
for _ in range(randint(5, 10)):
if rand() > 0.5:
sender = Message.Sender.STUDENT
else:
sender = Message.Sender.ASSISTANT

msg = Message(sender=sender, content=f'Hello from {sender}!', chat=self.chat)
msg.save()

c = Chat(course=valid_course, session=authenticated_session.session)
c.save()

return NewChat(c, valid_course)

0 comments on commit 688edef

Please sign in to comment.