
Commit

Fix minor bug and improve evaluation and display for model evaluator.
liffiton committed Jul 20, 2024
1 parent 16b0510 commit b586f24
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions dev/model_eval.py
@@ -110,19 +110,19 @@ def gen_responses(args):

def choose_response_set(db, eval_model) -> int:
response_sets = db.execute("""
- SELECT response_set.id, response_set.created, response_set.model, prompt_set.query_src_file, prompt_set.prompt_func, eval_set.model AS eval_model
+ SELECT response_set.id, response_set.created, response_set.model, prompt_set.query_src_file, prompt_set.prompt_func, eval_set.model==? AS eval_with_this_model
FROM response_set
JOIN prompt_set ON response_set.prompt_set_id=prompt_set.id
- LEFT JOIN eval_set ON eval_set.response_set_id=response_set.id
+ LEFT JOIN eval_set ON eval_set.response_set_id=response_set.id AND eval_set.model=?
ORDER BY response_set.created
- """).fetchall()
+ """, [eval_model, eval_model]).fetchall()

funcs = {}
allowed_ids = [] # only allow running an eval that hasn't already been done with this model
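
That allowed-ids filter is what the new eval_with_this_model flag feeds. Keeping the model comparison inside the LEFT JOIN's ON clause (rather than in a WHERE clause) means response sets that have never been evaluated with this model still show up, just with a NULL flag. A standalone sketch of the pattern, not from this commit, on a deliberately simplified schema:

    import sqlite3

    # Toy two-table schema (assumed for illustration only, not the project's real tables).
    db = sqlite3.connect(":memory:")
    db.row_factory = sqlite3.Row
    db.executescript("""
        CREATE TABLE response_set (id INTEGER PRIMARY KEY);
        CREATE TABLE eval_set (id INTEGER PRIMARY KEY, response_set_id INTEGER, model TEXT);
        INSERT INTO response_set (id) VALUES (1), (2);
        INSERT INTO eval_set (response_set_id, model) VALUES (1, 'gpt-4o');
    """)
    rows = db.execute("""
        SELECT response_set.id, eval_set.model==? AS eval_with_this_model
        FROM response_set
        LEFT JOIN eval_set ON eval_set.response_set_id=response_set.id AND eval_set.model=?
    """, ["gpt-4o", "gpt-4o"]).fetchall()
    # Set 1 was already evaluated with this model; set 2 was not, so its flag is NULL (None).
    print([(row["id"], row["eval_with_this_model"]) for row in rows])  # [(1, 1), (2, None)]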

print("Response sets:")
for response_set in response_sets:
- already_evaled = response_set['eval_model'] == eval_model
+ already_evaled = response_set['eval_with_this_model']
if not already_evaled:
allowed_ids.append(response_set['id'])
else:
@@ -141,15 +141,17 @@ def choose_response_set(db, eval_model) -> int:


_SUFFICIENT_SYS_PROMPT = """\
- You are grading responses given to a student who requested help in a CS class.
+ You are an automated system grading responses given to a student who requested help in a CS class.
Evaluate the given response (in <response> delimiters) by comparing it to the given model (in <model> delimiters).
An ideal response will request or mention every individual point in the model.
For each specific point in the model, evaluate whether it is covered in the response.
- Output a JSON object with a key for each point, mapping each to true if the point is covered and false otherwise. Output nothing after the JSON.
+ Output a JSON object with a key for each point, mapping each to true if the point is covered and false otherwise.
+ Output nothing after the JSON.
"""


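
The prompt asks the grader for a bare JSON object mapping each point of the model answer to true or false. The code that consumes that output is outside this diff; a rough sketch of one way it could be parsed (the function name and the tolerance for stray leading text are assumptions, not the project's actual parser):

    import json

    def parse_point_grades(raw: str) -> dict[str, bool]:
        # Assumed helper: the prompt forbids text after the JSON, but tolerate text before it.
        start = raw.index("{")
        end = raw.rindex("}") + 1
        return {point: bool(val) for point, val in json.loads(raw[start:end]).items()}

    print(parse_point_grades('{"Ask for the error message.": true, "OK.": false}'))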
@@ -159,6 +161,10 @@ def eval_sufficient(model, row):
if model_response == "OK.":
# special case; can check with simple text processing
return {"OK.": "OK." in response}
+ elif "OK." in response:
+ # And if the model response is *not* "OK." but the real response includes it,
+ # we immediately know that's incorrect.
+ return {x: False for x in model_response.splitlines()}
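
With made-up strings (the unpacking of model_response and response from row is collapsed above), the two shortcut paths behave like this:

    # Illustrative values only, not project data.
    # Model answer is exactly "OK.": grading reduces to a substring check.
    model_response, response = "OK.", "Your code looks correct. OK."
    print({"OK.": "OK." in response})  # {'OK.': True}

    # Model answer lists points but the response just says "OK.": every point fails.
    model_response = "Ask for the error message.\nAsk for the relevant code."
    response = "OK."
    print({x: False for x in model_response.splitlines()})
    # {'Ask for the error message.': False, 'Ask for the relevant code.': False}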

msgs = [
{"role": "system", "content": _SUFFICIENT_SYS_PROMPT},
@@ -197,8 +203,9 @@ def gen_evals(args):
summarize_func = summarize_eval_insufficient

# Add system prompt if not used previously, get its ID
- cur = db.execute("INSERT OR IGNORE INTO eval_prompt (sys_prompt) VALUES (?)", [sys_prompt])
- eval_prompt_id = cur.lastrowid
+ # SET id=id is no-op, but we need to do an update so we can get the id using RETURNING
+ cur = db.execute("INSERT INTO eval_prompt (sys_prompt) VALUES (?) ON CONFLICT DO UPDATE SET id=id RETURNING id", [sys_prompt])
+ eval_prompt_id = cur.fetchone()['id']
db.commit()
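
This replaces INSERT OR IGNORE plus cursor.lastrowid, which does not hand back the existing row's id when the prompt is already stored. A self-contained sketch of the upsert-with-RETURNING idiom, not from this commit, on an assumed toy table (needs SQLite 3.35+ for RETURNING):

    import sqlite3

    db = sqlite3.connect(":memory:")
    db.row_factory = sqlite3.Row
    # Toy stand-in for eval_prompt; assumes sys_prompt carries a UNIQUE constraint.
    db.execute("CREATE TABLE eval_prompt (id INTEGER PRIMARY KEY, sys_prompt TEXT UNIQUE)")

    def prompt_id(sys_prompt: str) -> int:
        # The no-op UPDATE makes RETURNING fire even when the row already exists.
        cur = db.execute(
            "INSERT INTO eval_prompt (sys_prompt) VALUES (?) "
            "ON CONFLICT DO UPDATE SET id=id RETURNING id",
            [sys_prompt],
        )
        return cur.fetchone()["id"]

    print(prompt_id("grade strictly"), prompt_id("grade strictly"))  # same id both times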

# Create an eval set
@@ -227,13 +234,13 @@ def summarize_eval_insufficient(db, eval_set_id):
all_points = list(chain.from_iterable(d.keys() for d in evals))
ok_total = sum(x == "OK." for x in all_points)
other_total = sum(x != "OK." for x in all_points)
- print(f"{len(evals)} evaluations. {len(all_points)} points. {ok_total} OK. {other_total} Other.")
+ #print(f"{len(evals)} evaluations. {len(all_points)} points. {ok_total} OK. {other_total} Other.")
ok_true = sum(eval_dict.get("OK.") == True for eval_dict in evals)
ok_false = sum(eval_dict.get("OK.") == False for eval_dict in evals)
- print(f" OK.: {ok_true} true, {ok_false} false")
+ print(f" OK.: {'\x1B[32m-\x1B[m' * ok_true}{'\x1B[31m-\x1B[m' * ok_false} {ok_true}/{ok_false}")
other_true = sum(sum(eval_dict.get(key) == True for key in eval_dict if key != "OK.") for eval_dict in evals)
other_false = sum(sum(eval_dict.get(key) == False for key in eval_dict if key != "OK.") for eval_dict in evals)
- print(f" Other: {other_true} true, {other_false} false")
+ print(f" Other: {'\x1B[32m-\x1B[m' * other_true}{'\x1B[31m-\x1B[m' * other_false} {other_true}/{other_false}")
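
The reworked summary draws one tick per evaluation point: \x1B[32m switches the terminal to green, \x1B[31m to red, and \x1B[m resets. A tiny illustration of the same escape sequences (the helper name is invented for illustration, not the script's):

    def tick_bar(n_true: int, n_false: int) -> str:
        # Green ticks for covered points, red ticks for missed ones, then reset the color.
        return "\x1B[32m" + "-" * n_true + "\x1B[31m" + "-" * n_false + "\x1B[m"

    print(f"  OK.:    {tick_bar(5, 2)}  5/2")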


def show_evals(args):
@@ -268,7 +275,7 @@ def show_all_evals(args):
).fetchall()

for row in eval_set_rows:
- print(f"{row['id']}: \x1B[36m{row['prompt_func']}+{row['prompt_created']}\x1B[m (response: \x1B[32m{row['response_model']}\x1B[m) \x1B[30;1m(eval: {row['model']})\x1B[m")
+ print(f"{row['id']}: \x1B[36m{row['prompt_func']}+{row['prompt_created']}\x1B[m (response: \x1B[33m{row['response_model']}\x1B[m) \x1B[30;1m(eval: {row['model']})\x1B[m")

eval_set_id = row['id']

