Skip to content

Commit

Permalink
fix(ui): gradio bug fixes (#2021)
Browse files Browse the repository at this point in the history
* fix: when two user messages were sent

* fix: add source divider

* fix: add favicon

* fix: add zylon link

* refactor: update label
  • Loading branch information
jaluma authored Jul 29, 2024
1 parent 20bad17 commit d4375d0
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions private_gpt/ui/ui.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""This file should be imported if and only if you want to run the UI locally."""

import itertools
import base64
import logging
import time
from collections.abc import Iterable
Expand Down Expand Up @@ -31,7 +30,7 @@

UI_TAB_TITLE = "My Private GPT"

SOURCES_SEPARATOR = "\n\n Sources: \n"
SOURCES_SEPARATOR = "<hr>Sources: \n"

MODES = ["Query Files", "Search Files", "LLM Chat (no context from files)"]

Expand Down Expand Up @@ -109,25 +108,25 @@ def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
+ f"{index}. {source.file} (page {source.page}) \n\n"
)
used_files.add(f"{source.file}-{source.page}")
sources_text += "<hr>\n\n"
full_response += sources_text
yield full_response

def build_history() -> list[ChatMessage]:
history_messages: list[ChatMessage] = list(
itertools.chain(
*[
[
ChatMessage(content=interaction[0], role=MessageRole.USER),
ChatMessage(
# Remove from history content the Sources information
content=interaction[1].split(SOURCES_SEPARATOR)[0],
role=MessageRole.ASSISTANT,
),
]
for interaction in history
]
history_messages: list[ChatMessage] = []

for interaction in history:
history_messages.append(
ChatMessage(content=interaction[0], role=MessageRole.USER)
)
)
if len(interaction) > 1 and interaction[1] is not None:
history_messages.append(
ChatMessage(
# Remove from history content the Sources information
content=interaction[1].split(SOURCES_SEPARATOR)[0],
role=MessageRole.ASSISTANT,
)
)

# max 20 messages to try to avoid context overflow
return history_messages[:20]
Expand Down Expand Up @@ -314,7 +313,13 @@ def _build_ui_blocks(self) -> gr.Blocks:
".contain { display: flex !important; flex-direction: column !important; }"
"#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
"#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
"#col { height: calc(100vh - 112px - 16px) !important; }",
"#col { height: calc(100vh - 112px - 16px) !important; }"
"hr { margin-top: 1em; margin-bottom: 1em; border: 0; border-top: 1px solid #FFF; }"
".avatar-image { background-color: antiquewhite; border-radius: 2px; }"
".footer { text-align: center; margin-top: 20px; font-size: 14px; display: flex; align-items: center; justify-content: center; }"
".footer-zylon-link { display:flex; margin-left: 5px; text-decoration: auto; color: #fff; }"
".footer-zylon-link:hover { color: #C7BAFF; }"
".footer-zylon-ico { height: 20px; margin-left: 5px; background-color: antiquewhite; border-radius: 2px; }",
) as blocks:
with gr.Row():
gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
Expand Down Expand Up @@ -477,6 +482,14 @@ def get_model_label() -> str | None:
),
additional_inputs=[mode, upload_button, system_prompt_input],
)

with gr.Row():
avatar_byte = AVATAR_BOT.read_bytes()
f_base64 = f"data:image/png;base64,{base64.b64encode(avatar_byte).decode('utf-8')}"
gr.HTML(
f"<div class='footer'><a class='footer-zylon-link' href='https://zylon.ai/'>Maintained by Zylon <img class='footer-zylon-ico' src='{f_base64}' alt=Zylon></a></div>"
)

return blocks

def get_ui_blocks(self) -> gr.Blocks:
Expand All @@ -488,7 +501,7 @@ def mount_in_app(self, app: FastAPI, path: str) -> None:
blocks = self.get_ui_blocks()
blocks.queue()
logger.info("Mounting the gradio UI, at path=%s", path)
gr.mount_gradio_app(app, blocks, path=path)
gr.mount_gradio_app(app, blocks, path=path, favicon_path=AVATAR_BOT)


if __name__ == "__main__":
Expand Down

0 comments on commit d4375d0

Please sign in to comment.