
Commit

Several updates to dev tools, including move to litellm.
liffiton committed Jun 23, 2024
1 parent 8738e36 commit 3a195d7
Showing 5 changed files with 195 additions and 76 deletions.
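
The heart of the change is swapping the provider-specific OpenAI client for litellm's provider-agnostic completion() call, which takes OpenAI-style message lists and returns an OpenAI-style response object; the provider prefix in the model string selects the backend. A minimal sketch of the pattern the diff adopts (the model name and prompt here are illustrative):

import litellm

# Same call shape for any provider; litellm routes on the model string.
response = litellm.completion(
    model="anthropic/claude-3-haiku-20240307",  # or "gpt-3.5-turbo", "gpt-4o", ...
    messages=[{"role": "user", "content": "Write \"OK.\""}],
    max_tokens=3,
)
print(response.choices[0].message.content)
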
28 changes: 12 additions & 16 deletions dev/loaders.py
@@ -5,17 +5,26 @@
import csv
import importlib
import inspect
import os
import sys
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
from openai import OpenAI
import litellm
from python_calamine import CalamineWorkbook


def test_and_report_model(model: str) -> None:
# Check for valid model
response = litellm.completion(
model=model,
messages=[{"role": "user", "content": "Write \"OK.\""}],
max_tokens=3,
)
assert response.choices[0].message.content.strip() == "OK."
print(f"Using model: \x1B[32m{model}\x1B[m")


def load_queries(file_path: Path) -> tuple[list[dict[str, Any]], Sequence[str]]:
""" Load query data from a spreadsheet (.csv, .ods, or .xlsx).
Assumes the first row contains column headers.
@@ -78,16 +87,3 @@ def make_prompt(prompt_func, item):
messages = [{"role": "user", "content": prompt_gen}]

return messages


def setup_openai() -> OpenAI:
# load config values from .env file
load_dotenv()
try:
openai_key = os.environ["OPENAI_API_KEY"]
except KeyError:
print("Error: OPENAI_API_KEY environment variable not set.", file=sys.stderr)
sys.exit(1)

client = OpenAI(api_key=openai_key)
return client
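
With setup_openai() gone there is no client object to pass around: litellm resolves credentials from provider-specific environment variables (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY), so loading the .env file is still enough, and callers can sanity-check a model name up front with the new helper. A hypothetical usage sketch, not part of the diff:

from dotenv import load_dotenv
from loaders import test_and_report_model

load_dotenv()  # e.g. ANTHROPIC_API_KEY=... in .env; litellm reads it from the environment
test_and_report_model("anthropic/claude-3-haiku-20240307")  # fails fast on a bad model name, key, or reply
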
194 changes: 155 additions & 39 deletions dev/model_eval.py
@@ -7,10 +7,16 @@
import argparse
import json
import sqlite3
import sys
import time
from itertools import chain
from pathlib import Path

from loaders import load_prompt, load_queries, make_prompt, setup_openai
import litellm
from loaders import load_prompt, load_queries, make_prompt, test_and_report_model
from tqdm.auto import tqdm

DEFAULT_MODEL = "anthropic/claude-3-haiku-20240307"
TEMPERATURE = 0.25
MAX_TOKENS = 1000

@@ -68,131 +74,241 @@ def gen_responses(args):

prompt_set_id = choose_prompt_set(db)

client = setup_openai()

cur = db.execute("INSERT INTO response_set(model, prompt_set_id) VALUES (?, ?)", [args.model, prompt_set_id])
db.commit()
response_set_id = cur.lastrowid

prompts = db.execute("SELECT * FROM prompt WHERE set_id=?", [prompt_set_id]).fetchall()

for i, prompt in enumerate(prompts):
for prompt in tqdm(prompts, ncols=60):
msgs = json.loads(prompt['msgs_json'])
try:
response = client.chat.completions.create(
response = litellm.completion(
model=args.model,
messages=msgs,
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
n=1,
)
response_json = json.dumps(response.model_dump())
text = response.choices[0].message.content
except Exception as e: # noqa
response = f"[An error occurred in the openai completion.]\n{e}"
text = response
text = f"[An error occurred in the completion.]\n{e}"
tqdm.write(f"\x1B31m{text}\x1B[m")
response_json = json.dumps(text)

db.execute(
"INSERT INTO response(set_id, prompt_id, response, text) VALUES(?, ?, ?, ?)",
[response_set_id, prompt['id'], json.dumps(response.model_dump()), text]
[response_set_id, prompt['id'], response_json, text]
)
print(f"{i+1}/{len(prompts)}")
db.commit()

# hack for now for Claude rate limits
if "sonnet" in args.model:
time.sleep(0.25)


def choose_response_set(db) -> int:
response_sets = db.execute("SELECT response_set.id, response_set.created, response_set.model, prompt_set.query_src_file, prompt_set.prompt_func FROM response_set JOIN prompt_set ON response_set.prompt_set_id=prompt_set.id").fetchall()
def choose_response_set(db, eval_model) -> tuple[int, str]:
response_sets = db.execute("""
SELECT response_set.id, response_set.created, response_set.model, prompt_set.query_src_file, prompt_set.prompt_func, eval_set.model AS eval_model
FROM response_set
JOIN prompt_set ON response_set.prompt_set_id=prompt_set.id
LEFT JOIN eval_set ON eval_set.response_set_id=response_set.id
ORDER BY response_set.created
""").fetchall()

funcs = {}
allowed_ids = [] # only allow running an eval that hasn't already been done with this model

print("Response sets:")
for response_set in response_sets:
already_evaled = response_set['eval_model'] == eval_model
if not already_evaled:
allowed_ids.append(response_set['id'])
else:
print("\x1B[30;1m", end='') # grayed out
print(f"{response_set['id']}: {response_set['created']} - {response_set['query_src_file']} {response_set['prompt_func']} -> {response_set['model']}")
print("\x1B[m", end='')
funcs[response_set['id']] = response_set['prompt_func']

response_set_id = int(input("Select a response set (by ID): "))

if response_set_id not in allowed_ids:
print("That response set has already been evaluated with {eval_model}!")
sys.exit(1)

return response_set_id, funcs[response_set_id]


def eval_sufficient(client, model, row):
_SUFFICIENT_SYS_PROMPT = """\
You are grading responses given to a student who requested help in a CS class.
Evaluate the given response (in <response> delimiters) by comparing it to the given model (in <model> delimiters).
An ideal response will request or mention every individual point in the model.
For each specific point in the model, evaluate whether it is covered in the response.
Output a JSON object with a key for each point, mapping each to true if the point is covered and false otherwise. Output nothing after the JSON.
"""


def eval_sufficient(model, row):
response = row['text']
model_response = row['model_response']
if model_response == "OK.":
# special case; can check with simple text processing
return {"OK.": "OK." in response}

msgs = [
{"role": "system", "content": """\
Evaluate the given text (in <text> delimiters) by comparing it to the given model (in <model> delimiters).
For each specific point in the model, evaluate whether it is addressed or mentioned in the text. Output a JSON object with a key for each point, mapping each to true if the point is included and false otherwise."""},
{"role": "user", "content": f"<text>\n{response}\n</text>\n\n<model>\n{model_response}\n</model>"},
{"role": "system", "content": _SUFFICIENT_SYS_PROMPT},
{"role": "user", "content": f"<response>\n{response}\n</response>\n\n<model>\n{model_response}\n</model>"},
]
response = client.chat.completions.create(
litellm.drop_params = True # still run if 'response_format' not accepted by the current model
response = litellm.completion(
model=model,
response_format={ "type": "json_object" },
messages=msgs,
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
n=1,
)
litellm.drop_params = False # reset to default
text = response.choices[0].message.content
return json.loads(text)

try:
return json.loads(text)
except json.decoder.JSONDecodeError:
print(f"\x1B[31;1mInvalid:\x1B[m\n\x1B[33m{text}\x1B[m")
raise
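
For reference, the evaluator model is expected to reply with a bare JSON object mapping each point of the model answer to a boolean; a hypothetical example of what eval_sufficient() returns after json.loads() (the point texts are made up):

# model_response (hypothetical): "Ask which line raises the error. Explain what a KeyError means."
evaluation = {
    "Ask which line raises the error.": True,
    "Explain what a KeyError means.": False,
}
# Special case: when the stored model_response is exactly "OK.", no LLM call is made;
# eval_sufficient() just returns {"OK.": "OK." in row['text']}.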

def gen_evals(args):
client = setup_openai()

def gen_evals(args):
db = get_db(args.db_path)

response_set_id, prompt_func = choose_response_set(db)
response_set_id, prompt_func = choose_response_set(db, args.model)

rows = db.execute("SELECT response.text, prompt.model_response FROM response JOIN prompt ON response.prompt_id=prompt.id WHERE response.set_id=?", [response_set_id]).fetchall()
rows = db.execute("SELECT response.id, response.text, prompt.model_response FROM response JOIN prompt ON response.prompt_id=prompt.id WHERE response.set_id=?", [response_set_id]).fetchall()

match prompt_func:
case "make_sufficient_prompt":
sys_prompt = _SUFFICIENT_SYS_PROMPT
eval_func = eval_sufficient
summarize_func = summarize_eval_insufficient

evals = []
# Add system prompt if not used previously, get its ID
cur = db.execute("INSERT OR IGNORE INTO eval_prompt (sys_prompt) VALUES (?)", [sys_prompt])
eval_prompt_id = cur.lastrowid
db.commit()

# Create an eval set
cur = db.execute("INSERT INTO eval_set (response_set_id, eval_prompt_id, model) VALUES (?, ?, ?)", [response_set_id, eval_prompt_id, args.model])
eval_set_id = cur.lastrowid

# Generate and add the evaluations
for row in tqdm(rows, ncols=60):
evaluation = eval_func(args.model, row)
db.execute("INSERT INTO eval (set_id, response_id, evaluation) VALUES (?, ?, ?)", [eval_set_id, row['id'], json.dumps(evaluation)])

for i, row in enumerate(rows):
evaluation = eval_func(client, args.model, row)
evals.append(evaluation)
print(f"{i+1}/{len(rows)}")
if False in evaluation.values():
print(row['text'])
print(evaluation)
tqdm.write(row['text'])
tqdm.write(str(evaluation))

db.commit() # only commit if we've generated all rows

summarize_func(db, eval_set_id)


def summarize_eval_insufficient(db, eval_set_id):
eval_rows = db.execute("SELECT * FROM eval WHERE set_id=?", [eval_set_id]).fetchall()

evals = [json.loads(row['evaluation']) for row in eval_rows]

all_points = list(chain.from_iterable(d.keys() for d in evals))
ok_total = sum(x == "OK." for x in all_points)
other_total = sum(x != "OK." for x in all_points)
print(f"{len(evals)} evaluations. {len(all_points)} points. {ok_total} OK. {other_total} Other.")
ok_true = sum(eval_dict.get("OK.") == True for eval_dict in evals)
ok_false = sum(eval_dict.get("OK.") == False for eval_dict in evals)
print(f"OK.: {ok_true} true, {ok_false} false")
print(f" OK.: {ok_true} true, {ok_false} false")
other_true = sum(sum(eval_dict.get(key) == True for key in eval_dict if key != "OK.") for eval_dict in evals)
other_false = sum(sum(eval_dict.get(key) == False for key in eval_dict if key != "OK.") for eval_dict in evals)
print(f"Other: {other_true} true, {other_false} false")
print(f" Other: {other_true} true, {other_false} false")


def show_evals(args):
if args.eval_set is None:
show_all_evals(args)
else:
show_one_eval(args)


def show_one_eval(args):
db = get_db(args.db_path)

eval_rows = db.execute("SELECT * FROM eval JOIN response ON response.id=eval.response_id JOIN prompt ON prompt.id=response.prompt_id WHERE eval.set_id = ?", [args.eval_set]).fetchall()
for row in eval_rows:
if False in json.loads(row['evaluation']).values(): # check if points evaluated as False
print(f"\x1B[33m{row['text']}\x1B[m")
print(f"\x1B[36m{row['evaluation']}\x1B[m")

summarize_eval_insufficient(db, args.eval_set)

def show_all_evals(args):
db = get_db(args.db_path)

eval_set_rows = db.execute("""
SELECT eval_set.*, response_set.model AS response_model, prompt_set.prompt_func, prompt_set.created AS prompt_created
FROM eval_set
JOIN response_set ON response_set.id=eval_set.response_set_id
JOIN prompt_set ON prompt_set.id=response_set.prompt_set_id
"""
+ "ORDER BY prompt_set.prompt_func, prompt_set.created, response_set.model" if args.by_prompt
else "ORDER BY prompt_set.prompt_func, response_set.model, prompt_set.created"
).fetchall()

for row in eval_set_rows:
print(f"{row['id']}: \x1B[36m{row['prompt_func']}+{row['prompt_created']}\x1B[m (response: \x1B[32m{row['response_model']}\x1B[m) \x1B[30;1m(eval: {row['model']})\x1B[m")

eval_set_id = row['id']

summarize_eval_insufficient(db, eval_set_id)


def main() -> None:
parser = argparse.ArgumentParser(description='A tool for running queries against data from a CSV/ODS/XLSX file and evaluating a model\'s responses.')
parser.add_argument('app', type=str, help='The name of the application module from which to load prompts (e.g., codehelp or starburst).')
parser.add_argument('db_path', type=Path, help='Path to the database file storing prompts and evaluations.')
subparsers = parser.add_subparsers(required=True)

parser_load = subparsers.add_parser('load', help='Load a file of queries and model responses; store a generated set of prompts in the database.')
parser_load.add_argument('file_path', type=Path, help='Path to the file to be read.')
parser_load.set_defaults(command_func=load_data)
parser_load.add_argument('file_path', type=Path, help='Path to the file to be read.')

parser_response = subparsers.add_parser('response', help='Generate a response set for a given prompt set.')
parser_response.set_defaults(command_func=gen_responses)
parser_response.add_argument(
'model', type=str, nargs='?', default='gpt-3.5-turbo',
help='(Optional. Default="gpt-3.5-turbo") The LLM to use (gpt-{3.5-turbo, 4o, etc.}).'
'model', type=str, nargs='?', default=DEFAULT_MODEL,
help=f"(Optional. Default='{DEFAULT_MODEL}') The LLM to use."
)
parser_response.set_defaults(command_func=gen_responses)

parser_eval = subparsers.add_parser('eval', help='Evaluate a given response set.')
parser_eval.set_defaults(command_func=gen_evals)
parser_eval.add_argument(
'model', type=str, nargs='?', default='gpt-3.5-turbo',
help='(Optional. Default="gpt-3.5-turbo") The LLM to use (gpt-{3.5-turbo, 4o, etc.}).'
'model', type=str, nargs='?', default=DEFAULT_MODEL,
help=f"(Optional. Default='{DEFAULT_MODEL}') The LLM to use."
)
parser_eval.set_defaults(command_func=gen_evals)

parser_show_evals = subparsers.add_parser('show_evals', help="Display the results of past evals.")
parser_show_evals.set_defaults(command_func=show_evals)
parser_show_evals.add_argument('--by-prompt', action='store_true', help="Order the evaluations by the prompt used (default: order by model, then prompt)")
parser_show_evals.add_argument('eval_set', nargs='?', type=int, help="Show details of the given evaluation set.")

args = parser.parse_args()

if 'model' in args:
test_and_report_model(args.model)

# run the function associated with the chosen command
args.command_func(args)
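
Putting it together, typical invocations of the updated tool might look like the following (the database and spreadsheet names are illustrative):

# Load queries/model responses from a spreadsheet and store a generated prompt set:
python dev/model_eval.py codehelp evals.db load queries.ods

# Generate a response set with the default model (anthropic/claude-3-haiku-20240307):
python dev/model_eval.py codehelp evals.db response

# Evaluate a chosen response set with a specific eval model:
python dev/model_eval.py codehelp evals.db eval gpt-4o

# Summarize past evals, ordered by prompt:
python dev/model_eval.py codehelp evals.db show_evals --by-prompt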

