Skip to content

Commit

Permalink
Expand Style Control to all category! Add UI support for style contro…
Browse files Browse the repository at this point in the history
…l and deprecated models. (#3517)
  • Loading branch information
CodingWithTim authored Sep 9, 2024
1 parent 853168f commit ef16c16
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 24 deletions.
97 changes: 73 additions & 24 deletions fastchat/serve/monitor/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from fastchat.serve.monitor.monitor_md import (
cat_name_to_baseline,
key_to_category_name,
cat_name_to_explanation,
deprecated_model_name,
arena_hard_title,
make_default_md_1,
make_default_md_2,
Expand Down Expand Up @@ -258,10 +260,14 @@ def create_ranking_str(ranking, ranking_difference):
return f"{int(ranking)}"


def get_arena_table(arena_df, model_table_df, arena_subset_df=None):
def get_arena_table(arena_df, model_table_df, arena_subset_df=None, hidden_models=None):
arena_df = arena_df.sort_values(
by=["final_ranking", "rating"], ascending=[True, False]
)

if hidden_models:
arena_df = arena_df[~arena_df.index.isin(hidden_models)].copy()

arena_df["final_ranking"] = recompute_final_ranking(arena_df)

if arena_subset_df is not None:
Expand Down Expand Up @@ -317,9 +323,11 @@ def process_row(row):
round(row["num_battles"]),
model_info.get("Organization", "Unknown"),
model_info.get("License", "Unknown"),
"Unknown"
if model_info.get("Knowledge cutoff date", "-") == "-"
else model_info.get("Knowledge cutoff date", "Unknown"),
(
"Unknown"
if model_info.get("Knowledge cutoff date", "-") == "-"
else model_info.get("Knowledge cutoff date", "Unknown")
),
]
)

Expand Down Expand Up @@ -350,21 +358,25 @@ def update_leaderboard_df(arena_table_vals):

def highlight_max(s):
return [
"color: green; font-weight: bold"
if "\u2191" in str(v)
else "color: red; font-weight: bold"
if "\u2193" in str(v)
else ""
(
"color: green; font-weight: bold"
if "\u2191" in str(v)
else "color: red; font-weight: bold"
if "\u2193" in str(v)
else ""
)
for v in s
]

def highlight_rank_max(s):
return [
"color: green; font-weight: bold"
if v > 0
else "color: red; font-weight: bold"
if v < 0
else ""
(
"color: green; font-weight: bold"
if v > 0
else "color: red; font-weight: bold"
if v < 0
else ""
)
for v in s
]

Expand Down Expand Up @@ -398,7 +410,13 @@ def build_arena_tab(

arena_df = arena_dfs["Overall"]

def update_leaderboard_and_plots(category):
def update_leaderboard_and_plots(category, filters):
if len(filters) > 0 and "Style Control" in filters:
if f"{category} (Style Control)" in arena_dfs:
category = f"{category} (Style Control)"
else:
gr.Warning("This category does not support style control.")

arena_subset_df = arena_dfs[category]
arena_subset_df = arena_subset_df[arena_subset_df["num_battles"] > 300]
elo_subset_results = category_elo_results[category]
Expand All @@ -409,6 +427,11 @@ def update_leaderboard_and_plots(category):
arena_df,
model_table_df,
arena_subset_df=arena_subset_df if category != "Overall" else None,
hidden_models=(
None
if len(filters) > 0 and "Show Deprecate" in filters
else deprecated_model_name
),
)
if category != "Overall":
arena_values = update_leaderboard_df(arena_values)
Expand Down Expand Up @@ -490,7 +513,9 @@ def update_leaderboard_and_plots(category):
p4 = category_elo_results["Overall"]["average_win_rate_bar"]

# arena table
arena_table_vals = get_arena_table(arena_df, model_table_df)
arena_table_vals = get_arena_table(
arena_df, model_table_df, hidden_models=deprecated_model_name
)

md = make_arena_leaderboard_md(arena_df, last_updated_time, vision=vision)
gr.Markdown(md, elem_id="leaderboard_markdown")
Expand All @@ -501,6 +526,10 @@ def update_leaderboard_and_plots(category):
label="Category",
value="Overall",
)
with gr.Column(scale=2):
category_checkbox = gr.CheckboxGroup(
["Style Control", "Show Deprecate"], label="Apply filter", info=""
)
default_category_details = make_category_arena_leaderboard_md(
arena_df, arena_df, name="Overall"
)
Expand Down Expand Up @@ -599,7 +628,21 @@ def update_leaderboard_and_plots(category):
plot_2 = gr.Plot(p2, show_label=False)
category_dropdown.change(
update_leaderboard_and_plots,
inputs=[category_dropdown],
inputs=[category_dropdown, category_checkbox],
outputs=[
elo_display_df,
plot_1,
plot_2,
plot_3,
plot_4,
more_stats_md,
category_deets,
],
)

category_checkbox.change(
update_leaderboard_and_plots,
inputs=[category_dropdown, category_checkbox],
outputs=[
elo_display_df,
plot_1,
Expand Down Expand Up @@ -659,13 +702,19 @@ def get_arena_category_table(results_df, categories, metric="ranking"):

def highlight_top_3(s):
return [
"background-color: rgba(255, 215, 0, 0.5); text-align: center; font-size: 110%"
if v == 1 and v != 0
else "background-color: rgba(192, 192, 192, 0.5); text-align: center; font-size: 110%"
if v == 2 and v != 0
else "background-color: rgba(255, 165, 0, 0.5); text-align: center; font-size: 110%"
if v == 3 and v != 0
else "text-align: center; font-size: 110%"
(
"background-color: rgba(255, 215, 0, 0.5); text-align: center; font-size: 110%"
if v == 1 and v != 0
else (
"background-color: rgba(192, 192, 192, 0.5); text-align: center; font-size: 110%"
if v == 2 and v != 0
else (
"background-color: rgba(255, 165, 0, 0.5); text-align: center; font-size: 110%"
if v == 3 and v != 0
else "text-align: center; font-size: 110%"
)
)
)
for v in s
]

Expand Down
9 changes: 9 additions & 0 deletions fastchat/serve/monitor/monitor_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

from fastchat.constants import SURVEY_LINK

deprecated_model_name = [
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-api-0409-preview",
]

key_to_category_name = {
"full": "Overall",
"full_style_control": "Overall w/ Style Control",
Expand All @@ -29,6 +34,8 @@
"no_refusal": "Exclude Refusal",
"overall_limit_5_user_vote": "overall_limit_5_user_vote",
"full_old": "Overall (Deprecated)",
"full_style_control": "Overall (Style Control)",
"hard_6_style_control": "Hard Prompts (Overall) (Style Control)",
}
cat_name_to_explanation = {
"Overall": "Overall Questions",
Expand All @@ -55,6 +62,8 @@
"Exclude Refusal": 'Exclude model responses with refusal (e.g., "I cannot answer")',
"overall_limit_5_user_vote": "overall_limit_5_user_vote",
"Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).",
"Overall (Style Control)": "Overall Leaderboard with Style Control. See details in [blog post](https://lmsys.org/blog/2024-08-28-style-control/).",
"Hard Prompts (Overall) (Style Control)": "Hard Prompts (Overall) Leaderboard with Style Control. See details in [blog post](https://lmsys.org/blog/2024-08-28-style-control/).",
}
cat_name_to_baseline = {
"Hard Prompts (English)": "English",
Expand Down

0 comments on commit ef16c16

Please sign in to comment.