Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-text document upload directly on chat window #467

Merged
merged 13 commits into from
Jul 26, 2024
18 changes: 15 additions & 3 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import re
from pathlib import Path
from typing import TypedDict, no_type_check

Expand Down Expand Up @@ -94,6 +95,17 @@ def _get_converse_supported_format(ext: str) -> str:
return supported_formats.get(ext, "txt")


def _convert_to_valid_file_name(file_name: str) -> str:
# Note: The document file name can only contain alphanumeric characters,
# whitespace characters, hyphens, parentheses, and square brackets.
# The name can't contain more than one consecutive whitespace character.
file_name = re.sub(r"[^a-zA-Z0-9\s\-\(\)\[\]]", "", file_name)
file_name = re.sub(r"\s+", " ", file_name)
file_name = file_name.strip()

return file_name


@no_type_check
def compose_args_for_converse_api(
messages: list[MessageModel],
Expand Down Expand Up @@ -124,7 +136,7 @@ def compose_args_for_converse_api(
}
}
)
elif c.content_type == "textAttachment":
elif c.content_type == "attachment":
content_blocks.append(
{
"document": {
Expand All @@ -134,10 +146,10 @@ def compose_args_for_converse_api(
], # e.g. "document.txt" -> "txt"
),
"name": Path(
c.file_name
_convert_to_valid_file_name(c.file_name)
).stem, # e.g. "document.txt" -> "document"
# encode text attachment body
"source": {"bytes": c.body.encode("utf-8")},
"source": {"bytes": base64.b64decode(c.body)},
}
}
)
Expand Down
2 changes: 1 addition & 1 deletion backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class ContentModel(BaseModel):
content_type: Literal["text", "image", "textAttachment"]
content_type: Literal["text", "image", "attachment"]
media_type: str | None
body: str = Field(
...,
Expand Down
22 changes: 11 additions & 11 deletions backend/app/routes/schemas/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class Content(BaseSchema):
content_type: Literal["text", "image", "textAttachment"] = Field(
content_type: Literal["text", "image", "attachment"] = Field(
..., description="Content type. Note that image is only available for claude 3."
)
media_type: str | None = Field(
Expand All @@ -27,7 +27,7 @@ class Content(BaseSchema):
)
file_name: str | None = Field(
None,
description="File name of the attachment. Must be specified if `content_type` is `textAttachment`.",
description="File name of the attachment. Must be specified if `content_type` is `attachment`.",
)
body: str = Field(..., description="Content body.")

Expand All @@ -42,18 +42,18 @@ def check_media_type(cls, v, values):
def check_body(cls, v, values):
content_type = values.get("content_type")

# if content_type in ["image", "textAttachment"]:
# try:
# # Check if the body is a valid base64 string
# base64.b64decode(v, validate=True)
# except Exception:
# raise ValueError(
# "body must be a valid base64 string if `content_type` is `image` or `textAttachment`"
# )

if content_type == "text" and not isinstance(v, str):
raise ValueError("body must be str if `content_type` is `text`")

if content_type in ["image", "attachment"]:
try:
# Check if the body is a valid base64 string
base64.b64decode(v, validate=True)
except Exception:
raise ValueError(
"body must be a valid base64 string if `content_type` is `image` or `attachment`"
)

return v


Expand Down
9 changes: 5 additions & 4 deletions backend/tests/test_stream/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import sys

sys.path.append(".")
Expand Down Expand Up @@ -151,16 +152,16 @@ def test_run_with_image(self):
self._run(message)

def test_run_with_attachment(self):
# _, aws_pdf_body = get_aws_overview()
# aws_pdf_filename = "aws_arch_overview.pdf"
body = get_test_markdown()
file_name, body = get_aws_overview()
body = base64.b64encode(body).decode("utf-8")
# body = get_test_markdown()
file_name = "test.md"

message = MessageModel(
role="user",
content=[
ContentModel(
content_type="textAttachment",
content_type="attachment",
media_type=None,
body=body,
file_name=file_name,
Expand Down
38 changes: 38 additions & 0 deletions backend/tests/test_usecases/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import sys

sys.path.insert(0, ".")
Expand Down Expand Up @@ -39,6 +40,7 @@
)
from app.vector_search import SearchResult
from tests.test_stream.get_aws_logo import get_aws_logo
from tests.test_stream.get_pdf import get_aws_overview
from tests.test_usecases.utils.bot_factory import (
create_test_instruction_template,
create_test_private_bot,
Expand Down Expand Up @@ -271,6 +273,42 @@ def tearDown(self) -> None:
delete_conversation_by_id("user1", self.output.conversation_id)


class TestAttachmentChat(unittest.TestCase):
def tearDown(self) -> None:
delete_conversation_by_id("user1", self.output.conversation_id)

def test_chat(self):
file_name, body = get_aws_overview()
chat_input = ChatInput(
conversation_id="test_conversation_id",
message=MessageInput(
role="user",
content=[
Content(
content_type="attachment",
body=base64.b64encode(body).decode("utf-8"),
media_type=None,
file_name=file_name,
),
Content(
content_type="text",
body="Summarize the document.",
media_type=None,
file_name=None,
),
],
model=MODEL,
parent_message_id=None,
message_id=None,
),
bot_id=None,
continue_generate=False,
)
output: ChatOutput = chat(user_id="user1", chat_input=chat_input)
pprint(output.model_dump())
self.output = output


class TestMultimodalChat(unittest.TestCase):
def tearDown(self) -> None:
delete_conversation_by_id("user1", self.output.conversation_id)
Expand Down
1 change: 1 addition & 0 deletions cdk/lib/bedrock-chat-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ export class BedrockChatStack extends cdk.Stack {
bedrockRegion: props.bedrockRegion,
largeMessageBucket,
documentBucket,
enableMistral: props.enableMistral,
});
frontend.buildViteApp({
backendApiEndpoint: backendApi.api.apiEndpoint,
Expand Down
2 changes: 2 additions & 0 deletions cdk/lib/constructs/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export interface WebSocketProps {
readonly websocketSessionTable: ITable;
readonly largeMessageBucket: s3.IBucket;
readonly accessLogBucket?: s3.Bucket;
readonly enableMistral: boolean;
}

export class WebSocket extends Construct {
Expand Down Expand Up @@ -110,6 +111,7 @@ export class WebSocket extends Construct {
DB_SECRETS_ARN: props.dbSecrets.secretArn,
LARGE_PAYLOAD_SUPPORT_BUCKET: largePayloadSupportBucket.bucketName,
WEBSOCKET_SESSION_TABLE_NAME: props.websocketSessionTable.tableName,
ENABLE_MISTRAL: props.enableMistral.toString(),
},
role: handlerRole,
});
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/@types/conversation.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export type Model =
| 'mixtral-8x7b-instruct'
| 'mistral-large';
export type Content = {
contentType: 'text' | 'image' | 'textAttachment';
contentType: 'text' | 'image' | 'attachment';
mediaType?: string;
fileName?: string;
body: string;
Expand Down
34 changes: 25 additions & 9 deletions frontend/src/components/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import ModalDialog from './ModalDialog';
import { useTranslation } from 'react-i18next';
import useChat from '../hooks/useChat';
import DialogFeedback from './DialogFeedback';
import UploadedFileText from './UploadedFileText';
import UploadedAttachedFile from './UploadedAttachedFile';
import { TEXT_FILE_EXTENSIONS } from '../constants/supportedAttachedFiles';

type Props = BaseProps & {
chatContent?: DisplayMessageContent;
Expand Down Expand Up @@ -177,20 +178,35 @@ const ChatMessage: React.FC<Props> = (props) => {
</div>
)}
{chatContent.content.some(
(content) => content.contentType === 'textAttachment'
(content) => content.contentType === 'attachment'
) && (
<div key="files" className="my-2 flex">
{chatContent.content.map((content, idx) => {
if (content.contentType === 'textAttachment') {
if (content.contentType === 'attachment') {
const isTextFile = TEXT_FILE_EXTENSIONS.some(
(ext) => content.fileName?.toLowerCase().endsWith(ext)
);
return (
<UploadedFileText
<UploadedAttachedFile
key={idx}
fileName={content.fileName ?? ''}
onClick={() => {
setDialogFileName(content.fileName ?? '');
setDialogFileContent(content.body);
setIsFileModalOpen(true);
}}
onClick={
// Only text file can be previewed
isTextFile
? () => {
const textContent = new TextDecoder(
'utf-8'
).decode(
Uint8Array.from(atob(content.body), (c) =>
c.charCodeAt(0)
)
); // base64 encoded text to be decoded string
setDialogFileName(content.fileName ?? '');
setDialogFileContent(textContent);
setIsFileModalOpen(true);
}
: undefined
}
/>
);
}
Expand Down
Loading
Loading