Enable vision arena across all tabs (#3483)
BabyChouSr authored Aug 30, 2024
1 parent 68023e1 commit 1ccbe8b
Showing 7 changed files with 385 additions and 154 deletions.
181 changes: 124 additions & 57 deletions fastchat/serve/gradio_block_arena_vision.py
@@ -10,6 +10,7 @@
import json
import os
import time
from typing import List, Union

import gradio as gr
from gradio.data_classes import FileData
@@ -26,6 +27,7 @@
from fastchat.model.model_adapter import (
get_conversation_template,
)
from fastchat.serve.gradio_global_state import Context
from fastchat.serve.gradio_web_server import (
get_model_description_md,
acknowledgment_md,
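
The handlers below read context.text_models and context.vision_models, but the Context class itself is not part of this file. A minimal sketch of what fastchat.serve.gradio_global_state.Context might look like, assuming it is a plain dataclass holding the two model lists (field names match the attributes used in this diff; everything else is an assumption):

# Hypothetical sketch of fastchat/serve/gradio_global_state.py (not shown in this diff).
from dataclasses import dataclass, field
from typing import List

@dataclass
class Context:
    text_models: List[str] = field(default_factory=list)
    vision_models: List[str] = field(default_factory=list)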
@@ -144,14 +146,18 @@ def clear_history(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history. ip: {ip}")
state = None
return (state, [], None) + (disable_btn,) * 5
return (state, [], enable_multimodal, invisible_text, invisible_btn) + (
disable_btn,
) * 5


def clear_history_example(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history_example. ip: {ip}")
state = None
return (state, [], enable_multimodal) + (disable_btn,) * 5
return (state, [], enable_multimodal, invisible_text, invisible_btn) + (
disable_btn,
) * 5


# TODO(Chris): At some point, we would like this to be a live-reporting feature.
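
enable_multimodal, invisible_text, invisible_btn, and the various *_btn constants returned by these handlers are not defined in this hunk; in FastChat's Gradio servers they are gr.update(...) payloads that toggle component state. A hedged sketch of what such constants typically look like (the names mirror the constants used here, but the exact keyword arguments are an illustration, not the repository's definitions):

import gradio as gr

# Assumed gr.update payloads; toggling interactivity and visibility of
# buttons and textboxes. Exact kwargs are an assumption.
no_change_btn = gr.update()
enable_btn = gr.update(interactive=True)
disable_btn = gr.update(interactive=False)
invisible_btn = gr.update(interactive=False, visible=False)
visible_text = gr.update(visible=True)
invisible_text = gr.update(visible=False, value="")
enable_multimodal = gr.update(visible=True, interactive=True)
disable_multimodal = gr.update(visible=False, value=None)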
@@ -209,17 +215,40 @@ def moderate_input(state, text, all_conv_text, model_list, images, ip):
return text, image_flagged, csam_flagged


def add_text(state, model_selector, chat_input, request: gr.Request):
text, images = chat_input["text"], chat_input["files"]
def add_text(
state,
model_selector,
chat_input: Union[str, dict],
context: Context,
request: gr.Request,
):
if isinstance(chat_input, dict):
text, images = chat_input["text"], chat_input["files"]
else:
text, images = chat_input, []

if (
len(images) > 0
and model_selector in context.text_models
and model_selector not in context.vision_models
):
gr.Warning(f"{model_selector} is a text-only model. Image is ignored.")
images = []

ip = get_ip(request)
logger.info(f"add_text. ip: {ip}. len: {len(text)}")

if state is None:
state = State(model_selector, is_vision=True)
if len(images) == 0:
state = State(model_selector, is_vision=False)
else:
state = State(model_selector, is_vision=True)

if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5
return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
no_change_btn,
) * 5

all_conv_text = state.conv.get_prompt()
all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
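
With this change, add_text accepts either the plain string from a gr.Textbox or the dict emitted by a gr.MultimodalTextbox. A minimal sketch of the two input shapes the handler now normalizes (values are illustrative only):

# Illustrative payloads; these are the shapes Gradio delivers to the handler.
plain_input = "Describe this image."        # from gr.Textbox
multimodal_input = {                        # from gr.MultimodalTextbox
    "text": "Describe this image.",
    "files": ["/tmp/example.png"],          # hypothetical uploaded file path
}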
@@ -233,26 +262,40 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
if image_flagged:
logger.info(f"image flagged. ip: {ip}. text: {text}")
state.skip_next = True
return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + (
return (
state,
state.to_gradio_chatbot(),
{"text": IMAGE_MODERATION_MSG},
"",
no_change_btn,
) * 5
) + (no_change_btn,) * 5

if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
state.skip_next = True
return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + (
return (
state,
state.to_gradio_chatbot(),
{"text": CONVERSATION_LIMIT_MSG},
"",
no_change_btn,
) * 5
) + (no_change_btn,) * 5

text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag)
state.conv.append_message(state.conv.roles[0], text)
state.conv.append_message(state.conv.roles[1], None)
return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
return (
state,
state.to_gradio_chatbot(),
disable_multimodal,
visible_text,
enable_btn,
) + (disable_btn,) * 5


def build_single_vision_language_model_ui(
models, add_promotion_links=False, random_questions=None
context: Context, add_promotion_links=False, random_questions=None
):
promotion = (
"""
@@ -272,33 +315,29 @@

state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
text_and_vision_models = list(set(context.text_models + context.vision_models))
context_state = gr.State(context)

with gr.Group():
with gr.Row(elem_id="model_selector_row"):
model_selector = gr.Dropdown(
choices=models,
value=models[0] if len(models) > 0 else "",
choices=text_and_vision_models,
value=text_and_vision_models[0]
if len(text_and_vision_models) > 0
else "",
interactive=True,
show_label=False,
container=False,
)

with gr.Accordion(
f"🔍 Expand to see the descriptions of {len(models)} models", open=False
f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
open=False,
):
model_description_md = get_model_description_md(models)
model_description_md = get_model_description_md(text_and_vision_models)
gr.Markdown(model_description_md, elem_id="model_description_markdown")

with gr.Row():
textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
placeholder="Enter your prompt or add image here",
container=True,
render=False,
elem_id="input_box",
)

with gr.Column(scale=2, visible=False) as image_column:
imagebox = gr.Image(
type="pil",
@@ -311,9 +350,24 @@
)

with gr.Row():
textbox.render()
# with gr.Column(scale=1, min_width=50):
# send_btn = gr.Button(value="Send", variant="primary")
textbox = gr.Textbox(
show_label=False,
placeholder="👉 Enter your prompt and press ENTER",
elem_id="input_box",
visible=False,
)

send_btn = gr.Button(
value="Send", variant="primary", scale=0, visible=False, interactive=False
)

multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
placeholder="Enter your prompt or add image here",
container=True,
elem_id="input_box",
)

with gr.Row(elem_id="buttons"):
if random_questions:
@@ -327,22 +381,6 @@
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

cur_dir = os.path.dirname(os.path.abspath(__file__))

examples = gr.Examples(
examples=[
{
"text": "How can I prepare a delicious meal using these ingredients?",
"files": [f"{cur_dir}/example_images/fridge.jpg"],
},
{
"text": "What might the woman on the right be thinking about?",
"files": [f"{cur_dir}/example_images/distracted.jpg"],
},
],
inputs=[textbox],
)

with gr.Accordion("Parameters", open=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
@@ -394,23 +432,50 @@ def build_single_vision_language_model_ui(
[state, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
clear_btn.click(
clear_history,
None,
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
)

model_selector.change(
clear_history, None, [state, chatbot, textbox] + btn_list
).then(set_visible_image, [textbox], [image_column])
examples.dataset.click(
clear_history_example, None, [state, chatbot, textbox] + btn_list
clear_history,
None,
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
).then(set_visible_image, [multimodal_textbox], [image_column])

multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
set_visible_image, [multimodal_textbox], [image_column]
).then(
clear_history_example,
None,
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
)

textbox.input(add_image, [textbox], [imagebox]).then(
set_visible_image, [textbox], [image_column]
).then(clear_history_example, None, [state, chatbot, textbox] + btn_list)
multimodal_textbox.submit(
add_text,
[state, model_selector, multimodal_textbox, context_state],
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)

textbox.submit(
add_text,
[state, model_selector, textbox],
[state, chatbot, textbox] + btn_list,
[state, model_selector, textbox, context_state],
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)

send_btn.click(
add_text,
[state, model_selector, textbox, context_state],
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
@@ -421,9 +486,11 @@
random_btn.click(
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
[textbox, imagebox], # Outputs are textbox and imagebox
).then(set_visible_image, [textbox], [image_column]).then(
clear_history_example, None, [state, chatbot, textbox] + btn_list
[multimodal_textbox, imagebox], # Outputs go to the multimodal textbox and imagebox
).then(set_visible_image, [multimodal_textbox], [image_column]).then(
clear_history_example,
None,
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
)

return [state, model_selector]
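
The event wiring above also leans on set_visible_image, set_invisible_image, and add_image, which this diff does not define (they live elsewhere in the vision arena code). A hedged sketch of the behavior they would need, assuming the multimodal textbox value is the dict shown earlier; the real implementations may differ in detail:

import gradio as gr

def set_visible_image(textbox_value):
    # Show the side image column only when the user attached files.
    files = textbox_value.get("files", []) if isinstance(textbox_value, dict) else []
    return gr.update(visible=len(files) > 0)

def set_invisible_image():
    # Hide the image column once the message has been submitted.
    return gr.update(visible=False)

def add_image(textbox_value):
    # Mirror the first attached image into the preview gr.Image component.
    files = textbox_value.get("files", []) if isinstance(textbox_value, dict) else []
    return files[0] if files else None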