Skip to content

Commit

Permalink
add enable_short_term to __main__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Jan 26, 2025
1 parent be0c7da commit f4d518c
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/fsrs_optimizer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ def process(filepath, filter_out_flags: list[int]):
"revlog_start_date": "2006-10-05",
"preview": "y",
"filter_out_suspended_cards": "n",
"enable_short_term": "y",
}

# Prompts the user with the key and then falls back on the last answer given.
def remembered_fallback_prompt(key: str, pretty: str = None):
if pretty is None:
pretty = key
remembered_fallbacks[key] = prompt(f"input {pretty}", remembered_fallbacks[key])
remembered_fallbacks[key] = prompt(
f"input {pretty}", remembered_fallbacks.get(key, None)
)

print("The defaults will switch to whatever you entered last.\n")

Expand All @@ -66,6 +69,9 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
remembered_fallback_prompt(
"filter_out_suspended_cards", "filter out suspended cards? (y/n)"
)
remembered_fallback_prompt(
"enable_short_term", "enable short-term component in FSRS model? (y/n)"
)

graphs_input = prompt("Save graphs? (y/n)", remembered_fallbacks["preview"])
else:
Expand All @@ -82,8 +88,9 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
json.dump(remembered_fallbacks, f)

save_graphs = graphs_input != "n"
enable_short_term = remembered_fallbacks["enable_short_term"] == "y"

optimizer = fsrs_optimizer.Optimizer()
optimizer = fsrs_optimizer.Optimizer(enable_short_term=enable_short_term)
if filepath.endswith(".apkg") or filepath.endswith(".colpkg"):
optimizer.anki_extract(
f"{filepath}",
Expand All @@ -108,7 +115,7 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
for i, f in enumerate(figures):
f.savefig(f"pretrain_{i}.png")
plt.close(f)
figures = optimizer.train(verbose=save_graphs)
figures = optimizer.train(verbose=save_graphs, recency_weight=True)
for i, f in enumerate(figures):
f.savefig(f"train_{i}.png")
plt.close(f)
Expand Down

0 comments on commit f4d518c

Please sign in to comment.