fix autoformat issues
cthorrez committed Sep 13, 2024
1 parent ae98043 commit 2838751
Showing 1 changed file with 46 additions and 25 deletions.
fastchat/serve/monitor/elo_analysis.py (46 additions & 25 deletions)
@@ -69,45 +69,56 @@ def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
    df = pd.DataFrame(rows)
    return df[df.median().sort_values(ascending=False).index]


def preprocess_battles_to_arrays(df):
    """convert the battles df into numpy arrays optimized for BT likelihood calculation"""

-    models = pd.unique(df[['model_a', 'model_b']].values.ravel()).tolist()
-    model_to_idx = {model:idx for idx,model in enumerate(models)}
+    models = pd.unique(df[["model_a", "model_b"]].values.ravel()).tolist()
+    model_to_idx = {model: idx for idx, model in enumerate(models)}
    # the 3 columns of schedule represent: model_a id, model_b id, outcome_id
    schedule = np.empty((len(df), 3), dtype=np.int32)
    # set the two model cols by mapping the model names to their int ids
-    schedule[:,[0,1]] = df[['model_a', 'model_b']].map(lambda x: model_to_idx[x]).values
+    schedule[:, [0, 1]] = (
+        df[["model_a", "model_b"]].map(lambda x: model_to_idx[x]).values
+    )
    # map outcomes to integers (must be same dtype as model ids so it can be in the same array)
    # model_a win -> 2, tie -> 1, model_b win -> 0
-    schedule[:,2] = np.select(
-        condlist=[df['winner'] == 'model_a', df['winner'] == 'model_b'],
+    schedule[:, 2] = np.select(
+        condlist=[df["winner"] == "model_a", df["winner"] == "model_b"],
        choicelist=[2, 0],
-        default=1
+        default=1,
    )
    # count the number of occurrences of each observed result
    matchups_outcomes, weights = np.unique(schedule, return_counts=True, axis=0)
-    matchups = matchups_outcomes[:,[0,1]]
+    matchups = matchups_outcomes[:, [0, 1]]
    # map 2 -> 1.0, 1 -> 0.5, 0 -> 0.0 which will be used as labels during optimization
-    outcomes = matchups_outcomes[:,2].astype(np.float64) / 2.0
+    outcomes = matchups_outcomes[:, 2].astype(np.float64) / 2.0
    weights = weights.astype(np.float64)
    # each possible result is weighted according to number of times it occurred in the dataset
    weights = weights / weights.sum()
    return matchups, outcomes, weights, models
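
A toy illustration of what this preprocessing yields; the model names and battle rows here are made up, and preprocess_battles_to_arrays is assumed to be in scope:

import pandas as pd

toy = pd.DataFrame(
    {
        "model_a": ["gpt-x", "gpt-x", "llama-y"],
        "model_b": ["llama-y", "llama-y", "gpt-x"],
        "winner": ["model_a", "tie", "model_b"],
    }
)
matchups, outcomes, weights, models = preprocess_battles_to_arrays(toy)
# models   -> ["gpt-x", "llama-y"]     (first-seen order: gpt-x=0, llama-y=1)
# matchups -> [[0, 1], [0, 1], [1, 0]] (one row per unique matchup+outcome)
# outcomes -> [0.5, 1.0, 0.0]          (tie, model_a win, model_b win)
# weights  -> [1/3, 1/3, 1/3]          (each unique result occurred once out of 3)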


def bt_loss_and_grad(ratings, matchups, outcomes, weights, alpha=1.0):
    """negative log likelihood and gradient for BT model with numpy array inputs"""
    matchup_ratings = ratings[matchups]
-    logits = alpha * (matchup_ratings[:,0] - matchup_ratings[:,1])
+    logits = alpha * (matchup_ratings[:, 0] - matchup_ratings[:, 1])
    probs = sigmoid(logits)
    # this form naturally counts a draw as half a win and half a loss
-    loss = -((np.log(probs) * outcomes + np.log(1.0 - probs) * (1.0 - outcomes)) * weights).sum()
+    loss = -(
+        (np.log(probs) * outcomes + np.log(1.0 - probs) * (1.0 - outcomes)) * weights
+    ).sum()
    matchups_grads = -alpha * (outcomes - probs) * weights
    model_grad = np.zeros_like(ratings)
    # aggregate gradients at the model level using the indices in matchups
-    np.add.at(model_grad, matchups[:, [0, 1]], matchups_grads[:, None] * np.array([1.0, -1.0], dtype=np.float64))
+    np.add.at(
+        model_grad,
+        matchups[:, [0, 1]],
+        matchups_grads[:, None] * np.array([1.0, -1.0], dtype=np.float64),
+    )
    return loss, model_grad
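
Because the gradient is derived by hand here, a central finite-difference comparison is a cheap sanity check; a minimal sketch, assuming bt_loss_and_grad and the preprocessed arrays are in scope:

import numpy as np

def grad_check(ratings, matchups, outcomes, weights, eps=1e-6):
    # analytic gradient from the closed form above
    _, grad = bt_loss_and_grad(ratings, matchups, outcomes, weights)
    approx = np.zeros_like(ratings)
    for i in range(len(ratings)):
        hi, lo = ratings.copy(), ratings.copy()
        hi[i] += eps
        lo[i] -= eps
        # central difference of the scalar loss w.r.t. rating i
        approx[i] = (
            bt_loss_and_grad(hi, matchups, outcomes, weights)[0]
            - bt_loss_and_grad(lo, matchups, outcomes, weights)[0]
        ) / (2.0 * eps)
    # the max absolute gap should be tiny (roughly 1e-8 or below)
    return np.max(np.abs(grad - approx))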


def fit_bt(matchups, outcomes, weights, n_models, alpha, tol=1e-6):
    """perform the BT likelihood optimization"""
    initial_ratings = np.zeros(n_models, dtype=np.float64)
@@ -116,20 +127,28 @@ def fit_bt(matchups, outcomes, weights, n_models, alpha, tol=1e-6):
        x0=initial_ratings,
        args=(matchups, outcomes, weights, alpha),
        jac=True,
-        method='L-BFGS-B',
-        options={'disp' : False, 'maxiter': 100, 'gtol': tol},
+        method="L-BFGS-B",
+        options={"disp": False, "maxiter": 100, "gtol": tol},
    )
    return result["x"]
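
With jac=True, scipy.optimize.minimize expects the objective to return a (loss, gradient) pair, which bt_loss_and_grad does, so each L-BFGS-B step gets both from a single evaluation. An end-to-end sketch of the fit, assuming a battles DataFrame named battles_df:

import numpy as np

matchups, outcomes, weights, models = preprocess_battles_to_arrays(battles_df)
# alpha = log(10) so that, after scaling by 400, rating gaps follow base-10 Elo odds
ratings = fit_bt(matchups, outcomes, weights, n_models=len(models), alpha=np.log(10.0))
# ratings come back on the natural (unscaled) axis; see scale_and_offset below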


-def scale_and_offset(ratings, models, scale=400, init_rating=1000, baseline_model="mixtral-8x7b-instruct-v0.1", baseline_rating=1114):
+def scale_and_offset(
+    ratings,
+    models,
+    scale=400,
+    init_rating=1000,
+    baseline_model="mixtral-8x7b-instruct-v0.1",
+    baseline_rating=1114,
+):
    """convert ratings from the natural scale to the Elo rating scale with an anchored baseline"""
    scaled_ratings = (ratings * scale) + init_rating
    if baseline_model in models:
        baseline_idx = models.index(baseline_model)
-        scaled_ratings += (baseline_rating - scaled_ratings[..., [baseline_idx]])
+        scaled_ratings += baseline_rating - scaled_ratings[..., [baseline_idx]]
    return scaled_ratings
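
The anchoring works in two steps: map to the Elo axis, then shift everything so the baseline model lands exactly on its anchored rating. The [..., [baseline_idx]] indexing keeps a trailing axis, so the same line also broadcasts over 2-D (num_round, n_models) bootstrap ratings. A worked example with made-up numbers:

import numpy as np

ratings = np.array([0.5, 0.0, -0.25])
models = ["model-x", "mixtral-8x7b-instruct-v0.1", "model-z"]
scaled = scale_and_offset(ratings, models)
# step 1: ratings * 400 + 1000 -> [1200.0, 1000.0, 900.0]
# step 2: baseline sits at 1000.0, so add 1114 - 1000 = 114 to every entry
# result: [1314.0, 1114.0, 1014.0]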


def compute_elo_mle_with_tie(
    df,
    SCALE=400,
@@ -140,22 +159,25 @@ def compute_elo_mle_with_tie(
):
    matchups, outcomes, weights, models = preprocess_battles_to_arrays(df)
    ratings = fit_bt(matchups, outcomes, weights, len(models), np.log(BASE))
-    scaled_ratings = scale_and_offset(ratings, models, SCALE, INIT_RATING, baseline_model, baseline_rating)
+    scaled_ratings = scale_and_offset(
+        ratings, models, SCALE, INIT_RATING, baseline_model, baseline_rating
+    )
    return pd.Series(scaled_ratings, index=models).sort_values(ascending=False)
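
Fitting with alpha = log(BASE) is what makes the scaled numbers behave like familiar Elo ratings: a gap of d points implies a win expectancy of 1 / (1 + BASE ** (-d / SCALE)). A quick arithmetic check with the defaults:

d = 400.0  # hypothetical rating gap
p = 1.0 / (1.0 + 10.0 ** (-d / 400.0))
print(round(p, 3))  # 0.909 -- a 400-point favorite is expected to win ~10 times in 11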

-def get_bootstrap_result_elo_mle_with_tie(df, num_round, BASE=10.0, SCALE=400.0, INIT_RATING=1000.0):
+
+def get_bootstrap_result_elo_mle_with_tie(
+    df, num_round, BASE=10.0, SCALE=400.0, INIT_RATING=1000.0
+):
    matchups, outcomes, weights, models = preprocess_battles_to_arrays(df)
    # bootstrap sample the unique outcomes and their counts directly using the multinomial distribution
-    idxs = np.random.multinomial(
-        n=len(df),
-        pvals=weights,
-        size=(num_round)
-    )
+    idxs = np.random.multinomial(n=len(df), pvals=weights, size=(num_round))
    # only the distribution over their occurrence counts changes between samples (and it can be 0)
    boot_weights = idxs.astype(np.float64) / len(df)

    # the only thing different across samples is the distribution of weights
-    bt_fn = partial(fit_bt, matchups, outcomes, n_models=len(models), alpha=np.log(BASE))
+    bt_fn = partial(
+        fit_bt, matchups, outcomes, n_models=len(models), alpha=np.log(BASE)
+    )

    with mp.Pool(os.cpu_count()) as pool:
        results = pool.map(bt_fn, boot_weights)
@@ -164,7 +186,6 @@ def get_bootstrap_result_elo_mle_with_tie(df, num_round, BASE=10.0, SCALE=400.0,
    scaled_ratings = scale_and_offset(ratings, models, SCALE, INIT_RATING)
    df = pd.DataFrame(scaled_ratings, columns=models)
    return df[df.median().sort_values(ascending=False).index]
-


def get_median_elo_from_bootstrap(bootstrap_df):
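
The bootstrap above never re-materializes battle rows: resampling n battles with replacement is equivalent to one multinomial draw over the empirical frequencies of the unique results, so each replicate only swaps in a new weight vector for the same matchups and outcomes arrays. A standalone sketch of that idea with made-up numbers:

import numpy as np

rng = np.random.default_rng(0)
n_battles = 1000
# hypothetical empirical frequencies of three unique (matchup, outcome) rows
weights = np.array([0.5, 0.3, 0.2])

# 25 bootstrap replicates in one call; each row is a resampled set of counts
counts = rng.multinomial(n=n_battles, pvals=weights, size=25)
boot_weights = counts / n_battles  # shape (25, 3); rows feed fit_bt one at a time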