Use Reka Python SDK and add script for benchmarking and add send_btn (#3413)

Co-authored-by: Wei-Lin Chiang <[email protected]>
Co-authored-by: simon-mo <[email protected]>
4 people authored Jul 5, 2024
1 parent a71e3c6 commit 68023e1
Showing 6 changed files with 244 additions and 59 deletions.
49 changes: 37 additions & 12 deletions fastchat/conversation.py
@@ -524,30 +524,55 @@ def to_anthropic_vision_api_messages(self):
 
     def to_reka_api_messages(self):
         from fastchat.serve.vision.image import ImageFormat
+        from reka import ChatMessage, TypedMediaContent, TypedText
 
         ret = []
         for i, (_, msg) in enumerate(self.messages[self.offset :]):
             if i % 2 == 0:
                 if type(msg) == tuple:
                     text, images = msg
                     for image in images:
-                        if image.image_format == ImageFormat.URL:
-                            ret.append(
-                                {"type": "human", "text": text, "media_url": image.url}
-                            )
-                        elif image.image_format == ImageFormat.BYTES:
+                        if image.image_format == ImageFormat.BYTES:
                             ret.append(
-                                {
-                                    "type": "human",
-                                    "text": text,
-                                    "media_url": f"data:image/{image.filetype};base64,{image.base64_str}",
-                                }
+                                ChatMessage(
+                                    content=[
+                                        TypedText(
+                                            type="text",
+                                            text=text,
+                                        ),
+                                        TypedMediaContent(
+                                            type="image_url",
+                                            image_url=f"data:image/{image.filetype};base64,{image.base64_str}",
+                                        ),
+                                    ],
+                                    role="user",
+                                )
                             )
                 else:
-                    ret.append({"type": "human", "text": msg})
+                    ret.append(
+                        ChatMessage(
+                            content=[
+                                TypedText(
+                                    type="text",
+                                    text=msg,
+                                )
+                            ],
+                            role="user",
+                        )
+                    )
             else:
                 if msg is not None:
-                    ret.append({"type": "model", "text": msg})
+                    ret.append(
+                        ChatMessage(
+                            content=[
+                                TypedText(
+                                    type="text",
+                                    text=msg,
+                                )
+                            ],
+                            role="assistant",
+                        )
+                    )
 
         return ret
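Not part of the diff: a minimal sketch of the message objects the rewritten to_reka_api_messages() now emits, built only from the SDK types imported above (ChatMessage, TypedText, TypedMediaContent). The prompt text and the base64 payload are placeholders.

from reka import ChatMessage, TypedMediaContent, TypedText

# One user turn carrying text plus an inline image, followed by one assistant
# turn: the same shapes to_reka_api_messages() appends to its return list.
messages = [
    ChatMessage(
        content=[
            TypedText(type="text", text="What is shown in this image?"),
            TypedMediaContent(
                type="image_url",
                image_url="data:image/png;base64,<BASE64_PLACEHOLDER>",
            ),
        ],
        role="user",
    ),
    ChatMessage(
        content=[TypedText(type="text", text="A plot of model ratings.")],
        role="assistant",
    ),
]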
46 changes: 22 additions & 24 deletions fastchat/serve/api_provider.py
@@ -1076,8 +1076,13 @@ def reka_api_stream_iter(
     api_key: Optional[str] = None,  # default is env var CO_API_KEY
     api_base: Optional[str] = None,
 ):
+    from reka.client import Reka
+    from reka import TypedText
+
     api_key = api_key or os.environ["REKA_API_KEY"]
+
+    client = Reka(api_key=api_key)
 
     use_search_engine = False
     if "-online" in model_name:
         model_name = model_name.replace("-online", "")
@@ -1094,34 +1099,27 @@ def reka_api_stream_iter(
 
     # Make requests for logging
     text_messages = []
-    for message in messages:
-        text_messages.append({"type": message["type"], "text": message["text"]})
+    for turn in messages:
+        for message in turn.content:
+            if isinstance(message, TypedText):
+                text_messages.append({"type": message.type, "text": message.text})
     logged_request = dict(request)
     logged_request["conversation_history"] = text_messages
 
     logger.info(f"==== request ====\n{logged_request}")
 
-    response = requests.post(
-        api_base,
-        stream=True,
-        json=request,
-        headers={
-            "X-Api-Key": api_key,
-        },
+    response = client.chat.create_stream(
+        messages=messages,
+        max_tokens=max_new_tokens,
+        top_p=top_p,
+        model=model_name,
     )
 
-    if response.status_code != 200:
-        error_message = response.text
-        logger.error(f"==== error from reka api: {error_message} ====")
-        yield {
-            "text": f"**API REQUEST ERROR** Reason: {error_message}",
-            "error_code": 1,
-        }
-        return
-
-    for line in response.iter_lines():
-        line = line.decode("utf8")
-        if not line.startswith("data: "):
-            continue
-        gen = json.loads(line[6:])
-        yield {"text": gen["text"], "error_code": 0}
+    for chunk in response:
+        try:
+            yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
+        except:
+            yield {
+                "text": f"**API REQUEST ERROR** ",
+                "error_code": 1,
+            }
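Not part of the diff: a minimal sketch, under the same assumptions as the new code above (the reka-api package is installed and REKA_API_KEY is set), of calling client.chat.create_stream() directly and reading the same field that reka_api_stream_iter() forwards for each chunk. The model name is a placeholder.

import os

from reka import ChatMessage, TypedText
from reka.client import Reka

client = Reka(api_key=os.environ["REKA_API_KEY"])

stream = client.chat.create_stream(
    messages=[
        ChatMessage(content=[TypedText(type="text", text="Hello!")], role="user")
    ],
    max_tokens=256,
    top_p=1.0,
    model="<reka-model-name>",  # placeholder model id
)

for chunk in stream:
    # Same access path the new streaming loop above uses.
    print(chunk.responses[0].chunk.content)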
48 changes: 37 additions & 11 deletions fastchat/serve/gradio_block_arena_vision_anony.py
@@ -134,7 +134,7 @@ def clear_history_example(request: gr.Request):
         [None] * num_sides
         + [None] * num_sides
         + anony_names
-        + [enable_multimodal, invisible_text]
+        + [enable_multimodal, invisible_text, invisible_btn]
         + [invisible_btn] * 4
         + [disable_btn] * 2
         + [enable_btn]
@@ -239,7 +239,7 @@ def clear_history(request: gr.Request):
         [None] * num_sides
         + [None] * num_sides
         + anony_names
-        + [enable_multimodal, invisible_text]
+        + [enable_multimodal, invisible_text, invisible_btn]
         + [invisible_btn] * 4
         + [disable_btn] * 2
         + [enable_btn]
@@ -297,7 +297,7 @@ def add_text(
         return (
             states
             + [x.to_gradio_chatbot() for x in states]
-            + [None, ""]
+            + [None, "", no_change_btn]
             + [
                 no_change_btn,
             ]
@@ -321,7 +321,7 @@ def add_text(
         return (
             states
             + [x.to_gradio_chatbot() for x in states]
-            + [{"text": CONVERSATION_LIMIT_MSG}, ""]
+            + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
             + [
                 no_change_btn,
             ]
@@ -342,6 +342,7 @@ def add_text(
                     + " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION."
                 },
                 "",
+                no_change_btn,
             ]
             + [no_change_btn] * 7
             + [""]
@@ -363,7 +364,7 @@ def add_text(
     return (
         states
         + [x.to_gradio_chatbot() for x in states]
-        + [disable_multimodal, visible_text]
+        + [disable_multimodal, visible_text, enable_btn]
         + [
             disable_btn,
         ]
@@ -464,7 +465,9 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
             placeholder="Enter your prompt or add image here",
             elem_id="input_box",
         )
-        # send_btn = gr.Button(value="Send", variant="primary", scale=0)
+        send_btn = gr.Button(
+            value="Send", variant="primary", scale=0, visible=False, interactive=False
+        )
 
     with gr.Row() as button_row:
         if random_questions:
@@ -548,7 +551,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
         states
         + chatbots
         + model_selectors
-        + [multimodal_textbox, textbox]
+        + [multimodal_textbox, textbox, send_btn]
         + btn_list
         + [random_btn]
         + [slow_warning],
@@ -581,15 +584,19 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
     ).then(
         clear_history_example,
         None,
-        states + chatbots + model_selectors + [multimodal_textbox, textbox] + btn_list,
+        states
+        + chatbots
+        + model_selectors
+        + [multimodal_textbox, textbox, send_btn]
+        + btn_list,
     )
 
     multimodal_textbox.submit(
         add_text,
         states + model_selectors + [multimodal_textbox],
         states
         + chatbots
-        + [multimodal_textbox, textbox]
+        + [multimodal_textbox, textbox, send_btn]
         + btn_list
         + [random_btn]
         + [slow_warning],
@@ -608,7 +615,26 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
         states + model_selectors + [textbox],
         states
         + chatbots
-        + [multimodal_textbox, textbox]
+        + [multimodal_textbox, textbox, send_btn]
         + btn_list
         + [random_btn]
         + [slow_warning],
+    ).then(
+        bot_response_multi,
+        states + [temperature, top_p, max_output_tokens],
+        states + chatbots + btn_list,
+    ).then(
+        flash_buttons,
+        [],
+        btn_list,
+    )
+
+    send_btn.click(
+        add_text,
+        states + model_selectors + [textbox],
+        states
+        + chatbots
+        + [multimodal_textbox, textbox, send_btn]
+        + btn_list
+        + [random_btn]
+        + [slow_warning],
@@ -633,7 +659,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
         states
         + chatbots
         + model_selectors
-        + [multimodal_textbox, textbox]
+        + [multimodal_textbox, textbox, send_btn]
         + btn_list
         + [random_btn],
     )
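Not part of the diff: a self-contained sketch of the send_btn pattern introduced here, using only standard Gradio APIs. The button is created hidden and disabled with the same arguments as above, and an event handler later reveals and enables it by returning an update for it (FastChat does this with its enable_btn constant; plain gr.update() is used below instead).

import gradio as gr

def toggle_send(text):
    # Reveal and enable the button only when there is something to send,
    # analogous to add_text() returning enable_btn for send_btn.
    return gr.update(visible=True, interactive=bool(text.strip()))

def send(text):
    return f"Sent: {text}"

with gr.Blocks() as demo:
    textbox = gr.Textbox(placeholder="Enter your prompt here")
    send_btn = gr.Button(
        value="Send", variant="primary", scale=0, visible=False, interactive=False
    )
    output = gr.Textbox(label="Response")

    textbox.change(toggle_send, [textbox], [send_btn])
    send_btn.click(send, [textbox], [output])

demo.launch()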
25 changes: 13 additions & 12 deletions fastchat/serve/monitor/monitor.py
@@ -490,6 +490,7 @@ def build_arena_tab(
         )
         return
 
+    round_digit = None if vision else None
     arena_dfs = {}
     category_elo_results = {}
     last_updated_time = elo_results["full"]["last_updated_datetime"].split(" ")[0]
@@ -512,6 +513,7 @@ def update_leaderboard_and_plots(category):
             arena_df,
             model_table_df,
             arena_subset_df=arena_subset_df if category != "Overall" else None,
+            round_digit=round_digit,
         )
         if category != "Overall":
             arena_values = update_leaderboard_df(arena_values)
@@ -665,9 +667,7 @@ def update_leaderboard_and_plots(category):
         elem_id="leaderboard_markdown",
     )
 
-    if not vision:
-        # only live update the text tab
-        leader_component_values[:] = [default_md, p1, p2, p3, p4]
+    leader_component_values[:] = [default_md, p1, p2, p3, p4]
 
     if show_plot:
         more_stats_md = gr.Markdown(
@@ -740,7 +740,7 @@ def build_full_leaderboard_tab(elo_results, model_table_df):
 
 
 def build_leaderboard_tab(
-    elo_results_file, leaderboard_table_file, show_plot=False, mirror=False
+    elo_results_file, leaderboard_table_file, vision=True, show_plot=False, mirror=False
 ):
     if elo_results_file is None:  # Do live update
         default_md = "Loading ..."
@@ -776,14 +776,15 @@ def build_leaderboard_tab(
                 default_md,
                 show_plot=show_plot,
             )
-        with gr.Tab("📣 NEW: Arena (Vision)", id=1):
-            build_arena_tab(
-                elo_results_vision,
-                model_table_df,
-                default_md,
-                vision=True,
-                show_plot=show_plot,
-            )
+        if vision:
+            with gr.Tab("📣 NEW: Arena (Vision)", id=1):
+                build_arena_tab(
+                    elo_results_vision,
+                    model_table_df,
+                    default_md,
+                    vision=True,
+                    show_plot=show_plot,
+                )
         with gr.Tab("Full Leaderboard", id=2):
             build_full_leaderboard_tab(elo_results_text, model_table_df)
 
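Not part of the diff: a hypothetical call site showing the effect of the new vision flag on build_leaderboard_tab(). With vision=False the Arena (Vision) tab is simply not built, which is what the if vision: guard above adds; the file paths are placeholders.

import gradio as gr

from fastchat.serve.monitor.monitor import build_leaderboard_tab

with gr.Blocks() as demo:
    build_leaderboard_tab(
        elo_results_file="elo_results.pkl",  # placeholder path
        leaderboard_table_file="leaderboard.csv",  # placeholder path
        vision=False,  # text-only deployment: skip the "Arena (Vision)" tab
        show_plot=True,
    )

demo.launch()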
Empty file added playground/__init__.py
