Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get reformulator working #7

Open
wants to merge 10 commits into
base: public
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,19 +391,15 @@ First, ensure that you API keys are set in you env variables.

Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it.

#### Reformulator

Next... create a database? it expects a sqlite db in databases/reformulator/reformulator?
* Can handle it in code

Next... something about adding a field called `AnkiReformulator` to notes you want to change?
* Do you have to create a special note type for this to work?
#### Reformulator

The reformulator expects the notes you modify to have a specific field present so that it can save the old versions and add logging. Modify the note type you want to reformulate by adding a `AnkiReformulator` field to it.
The Reformulator can be run from the command line:

```bash
python reformulator.py \
--query "note:Basic (rated:2:1 OR rated:2:2) -is:suspended" \
--query "note:Cloze (rated:2:1 OR rated:2:2) -is:suspended" \
--dataset_path "examples/reformulator_dataset.txt" \
--string_formatting "examples/string_formatting.py" \
--ntfy_url "ntfy.sh/YOUR_TOPIC" \
Expand Down
54 changes: 32 additions & 22 deletions reformulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
print(json.dumps(db_content, ensure_ascii=False, indent=4))
return
else:
# sync_anki()
sync_anki()
assert query is not None, "Must specify --query"
assert dataset_path is not None, "Must specify --dataset_path"
litellm.set_verbose = verbose
Expand All @@ -222,7 +222,12 @@ def handle_exception(exc_type, exc_value, exc_traceback):
parallel = int(parallel)
main_field_index = int(main_field_index)
assert main_field_index >= 0, "invalid field_index"
self.base_query = query
self.dataset_path = dataset_path
self.mode = mode
self.exclude_done = exclude_done
self.exclude_version = exclude_version

if string_formatting:
red(f"Loading specific string formatting from {string_formatting}")
cloze_input_parser, cloze_output_parser = load_formatting_funcs(
Expand Down Expand Up @@ -254,11 +259,14 @@ def handle_exception(exc_type, exc_value, exc_traceback):
else:
raise Exception(f"{llm} not found in llm_price")
self.verbose = verbose
if mode == "reformulate":
if exclude_done:

def reformulate(self):
query = self.base_query
if self.mode == "reformulate":
if self.exclude_done:
query += " -AnkiReformulator::Done::*"

if exclude_version:
if self.exclude_version:
query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\""

# load db just in case
Expand All @@ -271,13 +279,12 @@ def handle_exception(exc_type, exc_value, exc_traceback):
self.db_content = self.load_db()
assert self.db_content, "Could not create database"

# TODO: What should be in the database normally? This fails with an empty database
whi("Computing estimated costs")
# self.compute_cost(self.db_content)
self.compute_cost(self.db_content)

# load dataset
whi("Loading dataset")
dataset = load_dataset(dataset_path)
dataset = load_dataset(self.dataset_path)
# check that each note is valid but exclude the system prompt, which is
# the first entry
for id, d in enumerate(dataset[1:]):
Expand All @@ -295,17 +302,19 @@ def handle_exception(exc_type, exc_value, exc_traceback):
red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}")

# find notes ids for the specific note type
nids = anki(action="findNotes", query="note:AnkiAITest")
# nids = anki(action="findNotes", query=query)
nids = anki(action="findNotes", query=query)
assert nids, f"No notes found for the query '{query}'"

# find the field names for this note type
fields = anki(action="notesInfo",
notes=[int(nids[0])])[0]["fields"]
assert "AnkiReformulator" in fields.keys(), \
"The notetype to edit must have a field called 'AnkiReformulator'"
# NOTE: This gets the first field. Is that what we want? Or do we specifically want the AnkiReformulator field?
self.field_name = list(fields.keys())[0]
try:
self.field_name = list(fields.keys())[self.field_index]
except IndexError:
raise AssertionError(f"main_field_index {self.field_index} is invalid. "
f"Note only has {len(fields.keys())} fields!")

if self.exclude_media:
# now find notes ids after excluding the img in the important field
Expand All @@ -316,7 +325,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
query += f' -{self.field_name}:"*http://*"'
query += f' -{self.field_name}:"*https://*"'

whi(f"Query to find note: {query}")
whi(f"Query to find note: '{query}'")
nids = anki(action="findNotes", query=query)
assert nids, f"No notes found for the query '{query}'"
whi(f"Found {len(nids)} notes")
Expand Down Expand Up @@ -357,7 +366,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
else:
assert not tag.lower().startswith("ankireformulator")

# check if too many tokens
# check if required tokens are higher than our limits
tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset)
tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"],
media=None,
Expand All @@ -370,9 +379,9 @@ def handle_exception(exc_type, exc_value, exc_traceback):
f"which is higher than the limit of {n_note_limit}")

if self.mode == "reformulate":
func = self.reformulate
func = self.reformulate_note
elif self.mode == "reset":
func = self.reset
func = self.reset_note
else:
raise ValueError(f"Unknown mode {self.mode}")

Expand Down Expand Up @@ -442,12 +451,11 @@ def compute_cost(self, db_content: List[Dict]) -> None:
This is used to know if something went wrong.
"""
n_db = len(db_content)
red(f"Number of entries in databases/reforumulator/reformulator.db: {n_db}")
red(f"Number of entries in databases/reformulator/reformulator.db: {n_db}")
dol_costs = []
dol_missing = 0
for dic in db_content:
# TODO: Mode isn't a field in the reformulator database dictionaries table
if dic["mode"] != "reformulate":
if self.mode != "reformulate":
continue
try:
dol = float(dic["dollar_price"])
Expand All @@ -469,7 +477,7 @@ def compute_cost(self, db_content: List[Dict]) -> None:
elif dol_costs:
self._cost_so_far = dol_total

def reformulate(self, nid: int, note: pd.Series) -> Dict:
def reformulate_note(self, nid: int, note: pd.Series) -> Dict:
"""Generate a reformulated version of a note's content using an LLM.

Parameters
Expand Down Expand Up @@ -640,7 +648,7 @@ def apply_reformulate(self, log: Dict) -> None:

new_minilog = rtoml.dumps(minilog, pretty=True)
new_minilog = new_minilog.strip().replace("\n", "<br>")
previous_minilog = note["fields"]["AnkiReformulator"]["value"].strip()
previous_minilog = note["fields"].get("AnkiReformulator", {}).get("value", "").strip()
if previous_minilog:
new_minilog += "<!--SEPARATOR-->"
new_minilog += "<br><br><details><summary>Older minilog</summary>"
Expand Down Expand Up @@ -670,6 +678,7 @@ def apply_reformulate(self, log: Dict) -> None:
nid,
fields={
self.field_name: log["note_field_formattednewcontent"],
# TODO: Might be nice to not require this

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iirc it was necessary to be rock solid sure that we can rollback easily. But yeah maybe we could just store the previous version and use only the db to handle rollbacks?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify: here I was refering to the fact that where have many strings lile "note_field_*" and would be nice if we didn't

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not entirely sure of the purpose of the db. Does it reflect changes that haven't been committed, or the latest version? having the db is nice for persistence, but it also opens up to some weird state where the notes don't match the database, so generally I prefer having a single source of truth.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of the database is only to act as a kind of very reliable and easy to parse logfile. It does not matter if there are inconsistencies because, for example if the user modified itself since the last time they ran the script. But it ensures complete reliability because it allows to roll back. To me, working with LLM's, it's very important to be able to rollback. Say in six months there is a shiny new LLM that is very cheap and possibly very good. Well, there is only one way to find out if it's good enough to handle real world notes. And then it's only after a few hundred reviews that you can actually judge if it does not make some weird edge case mistakes.

I don't know, just to give an example, at some point I realized that some of my cards related to hours were wrongly parsed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, yeah I see that it's only a log and is outside of anything with the anki reformulator field.

"AnkiReformulator": new_minilog,
},
)
Expand All @@ -685,7 +694,7 @@ def apply_reformulate(self, log: Dict) -> None:
# remove DOING tag
removetags(nid, "AnkiReformulator::DOING")

def reset(self, nid: int, note: pd.Series) -> Dict:
def reset_note(self, nid: int, note: pd.Series) -> Dict:
"""Reset a note back to its state before reformulation.

Parameters
Expand Down Expand Up @@ -992,7 +1001,8 @@ def load_db(self) -> Dict:
print(help(AnkiReformulator), file=sys.stderr)
else:
whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'")
AnkiReformulator(*args, **kwargs)
r = AnkiReformulator(*args, **kwargs)
r.reformulate()
sync_anki()
except AssertionError as e:
red(e)
Expand Down
8 changes: 5 additions & 3 deletions utils/cloze_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def cloze_input_parser(cloze: str) -> str:
if you use weird formatting that mess with LLMs"""
assert iscloze(cloze), f"Invalid cloze: {cloze}"

# TODO: What is this?
cloze = cloze.replace("\xa0", " ")
Grazfather marked this conversation as resolved.
Show resolved Hide resolved

# make newlines consistent
Expand All @@ -37,7 +38,6 @@ def cloze_input_parser(cloze: str) -> str:
# make spaces consitent
cloze = cloze.replace("&nbsp;", " ")


# misc
cloze = cloze.replace("&gt;", ">")
cloze = cloze.replace("&ge;", ">=")
Expand All @@ -57,9 +57,12 @@ def cloze_output_parser(cloze: str) -> str:
cloze = cloze.strip()

# make sure all newlines are consistent for now
# TODO: You mean <br/>?
thiswillbeyourgithub marked this conversation as resolved.
Show resolved Hide resolved
cloze = cloze.replace("</br>", "<br>")
cloze = cloze.replace("<br/>", "<br>")
cloze = cloze.replace("\r", "<br>")
cloze = cloze.replace("<br>", "\n")
# TODO: Not needed
# cloze = cloze.replace("<br>", "\n")

# make sure all spaces are consistent
cloze = cloze.replace("&nbsp;", " ")
Expand All @@ -68,4 +71,3 @@ def cloze_output_parser(cloze: str) -> str:
cloze = cloze.replace("\n", "<br>")

return cloze

6 changes: 2 additions & 4 deletions utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
"30": 0.002,
"50": 0.004,
"100": 0.007,
# NOTE: Why is this one a string?
"150": "0.01"}
Grazfather marked this conversation as resolved.
Show resolved Hide resolved
"150": 0.01}

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
llm_cache = Memory(".cache", verbose=0)
Expand Down Expand Up @@ -122,8 +121,7 @@ def wrapped_model_name_matcher(model: str) -> str:
return match[0]
else:
print(f"Couldn't match the modelname {model} to any known model. "
"Continuing but this will probably crash DocToolsLLM further "
"down the code.")
"Continuing but this will probably crash further down the code.")
return model


Expand Down