-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add router at /courses/[canvas_id]/chat that can start new chats and …
…accept new messages from students
- Loading branch information
Showing
7 changed files
with
223 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
def find_chat_by_id(chat_id: str): | ||
from db.models.chat import Chat | ||
return Chat.select().filter(Chat.public_id == chat_id).first() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
def find_course_by_canvas_id(canvas_id: str): | ||
from db.models.course import Course | ||
return Course.select().filter(Course.canvas_id == canvas_id).first() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
def all_messages_in_chat(chat_id: int): | ||
from db.models.message import Message | ||
return Message.select().where(Message.chat == chat_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import List | ||
|
||
from fastapi import APIRouter, Depends, HTTPException, status | ||
from pydantic import BaseModel, constr | ||
|
||
from db.actions.course import find_course_by_canvas_id | ||
from db.actions.message import all_messages_in_chat | ||
from http_api.auth import get_current_session | ||
from db.actions.chat import find_chat_by_id | ||
from db.models import Chat, Session, Message | ||
|
||
|
||
router = APIRouter() | ||
|
||
|
||
class ChatResponse(BaseModel): | ||
public_id: str | ||
|
||
|
||
class MessageResponse(BaseModel): | ||
message_id: str | ||
content: str | ||
sender: str | ||
created_at: str | ||
|
||
|
||
class MessagesResponse(BaseModel): | ||
messages: List[MessageResponse] | ||
|
||
|
||
class MessageRequestBody(BaseModel): | ||
content: constr(min_length=1, max_length=2048) | ||
|
||
|
||
@router.post('/course/{course_canvas_id}/chat', dependencies=[Depends(get_current_session)], response_model=ChatResponse) | ||
async def start_new_chat(course_canvas_id: str, session: Session = Depends(get_current_session)) -> ChatResponse: | ||
course = find_course_by_canvas_id(course_canvas_id) | ||
if course is None: | ||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Course not found") | ||
|
||
chat = Chat(course=course, session=session) | ||
chat.save() | ||
|
||
return ChatResponse(public_id=chat.public_id) | ||
|
||
|
||
@router.post( | ||
'/course/{course_canvas_id}/chat/{chat_id}/messages', | ||
dependencies=[Depends(get_current_session)], | ||
status_code=status.HTTP_201_CREATED, | ||
response_model=MessageResponse | ||
) | ||
async def send_message( | ||
course_canvas_id: str, | ||
chat_id: str, | ||
body: MessageRequestBody, | ||
) -> MessageResponse: | ||
course = find_course_by_canvas_id(course_canvas_id) | ||
if course is None: | ||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Course not found") | ||
|
||
chat = find_chat_by_id(chat_id) | ||
if chat is None: | ||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found") | ||
|
||
msg = Message(chat=chat, content=body.content, sender=Message.Sender.STUDENT) | ||
msg.save() | ||
|
||
return MessageResponse( | ||
message_id=msg.message_id, | ||
content=msg.content, | ||
sender=msg.sender, | ||
created_at=str(msg.created_at), | ||
) | ||
|
||
|
||
@router.get( | ||
'/course/{course_canvas_id}/chat/{chat_id}/messages', | ||
dependencies=[Depends(get_current_session)], | ||
status_code=status.HTTP_200_OK, | ||
response_model=MessagesResponse | ||
) | ||
async def get_messages( | ||
course_canvas_id: str, | ||
chat_id: str, | ||
) -> MessagesResponse: | ||
course = find_course_by_canvas_id(course_canvas_id) | ||
if course is None: | ||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Course not found") | ||
|
||
chat = find_chat_by_id(chat_id) | ||
if chat is None: | ||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found") | ||
|
||
messages = all_messages_in_chat(chat.id) | ||
out = [] | ||
for msg in messages: | ||
out.append(MessageResponse( | ||
message_id=msg.message_id, | ||
content=msg.content, | ||
sender=msg.sender, | ||
created_at=str(msg.created_at), | ||
)) | ||
|
||
return MessagesResponse(messages=out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from db.models import Course, Chat, Message | ||
|
||
|
||
def test_chats_are_tied_to_course_room(api_client, authenticated_session, valid_course): | ||
response = api_client.post(f'/course/{valid_course.canvas_id}/chat', headers=authenticated_session.headers) | ||
chat = Chat.filter(Chat.public_id == response.json()['public_id']).first() | ||
|
||
assert response.status_code == 200 | ||
assert chat.course.canvas_id == valid_course.canvas_id | ||
|
||
|
||
def test_chats_are_tied_to_session(api_client, authenticated_session, valid_course): | ||
response = api_client.post(f'/course/{valid_course.canvas_id}/chat', headers=authenticated_session.headers) | ||
chat = Chat.filter(Chat.public_id == response.json()['public_id']).first() | ||
|
||
assert response.status_code == 200 | ||
assert chat.session.public_id == authenticated_session.session.public_id | ||
|
||
|
||
def test_start_chat_fails_with_invalid_course_id(api_client, authenticated_session): | ||
invalid_course = 'something_bogus' | ||
|
||
response = api_client.post(f'/course/{invalid_course}/chat', headers=authenticated_session.headers) | ||
|
||
assert response.status_code == 404 | ||
|
||
|
||
def test_start_chat_fails_with_invalid_session(api_client, authenticated_session, valid_course): | ||
authenticated_session.session.is_valid = False | ||
authenticated_session.session.save() | ||
|
||
response = api_client.post(f'/course/{valid_course.canvas_id}/chat', headers=authenticated_session.headers) | ||
|
||
assert response.status_code == 401 | ||
|
||
|
||
def test_user_can_send_message_to_chat(api_client, authenticated_session, new_chat): | ||
url = f'/course/{new_chat.course.canvas_id}/chat/{new_chat.chat.public_id}/messages' | ||
response = api_client.post(url, json={'content': 'foo'}, headers=authenticated_session.headers) | ||
|
||
new_chat.chat.refresh() | ||
assert response.status_code == 201 | ||
assert new_chat.chat.messages[0].content == 'foo' | ||
assert new_chat.chat.messages[0].sender == Message.Sender.STUDENT | ||
|
||
|
||
def test_user_get_messages_in_chat(api_client, authenticated_session, new_chat): | ||
new_chat.add_some_messages() | ||
|
||
url = f'/course/{new_chat.course.canvas_id}/chat/{new_chat.chat.public_id}/messages' | ||
response = api_client.get(url, headers=authenticated_session.headers) | ||
|
||
messages = response.json()['messages'] | ||
|
||
assert response.status_code == 200 | ||
assert len(messages) == len(new_chat.chat.messages) | ||
for idx, message in enumerate(messages): | ||
assert new_chat.chat.messages[idx].content == message['content'] | ||
assert new_chat.chat.messages[idx].sender == message['sender'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters