Skip to content

Commit

Permalink
Better storage of trie_store
Browse files Browse the repository at this point in the history
  • Loading branch information
AG committed Mar 4, 2024
1 parent d05e162 commit f6c97ef
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ def save_trie_store(trie_store):
pickle.dump(trie_store, f, protocol=pickle.HIGHEST_PROTOCOL)
print("trie_store saved due to interruption.")

DEFAULT_TRIE_STORE ={'tries': {'3_words': {}, '2_words': {}, '1_word': {}}, 'scores': {}}

def load_trie_store():
try:
with open('training/trie_store.pkl', 'rb') as f:
return pickle.load(f)
except FileNotFoundError:
return {'tries': {'3_words': {}, '2_words': {}, '1_word': {}}, 'scores': {}}
return DEFAULT_TRIE_STORE

# Define a function to slugify context words into a filename-safe string
def _slugify(text):
Expand Down Expand Up @@ -112,6 +114,7 @@ def main():
print("Retained data loaded.")
else:
# If not retaining data, clear existing training directory and start fresh
trie_store = DEFAULT_TRIE_STORE
if os.path.exists('training'):
shutil.rmtree('training')
print("Previous training data cleared.")
Expand Down Expand Up @@ -213,7 +216,7 @@ def main():
## Three word alternative
context_words_1 = words[i+2:i+3]
finish_filing(trie_store, context_words_1, predictive_words_2, "1_word")
flatten_to_dictionary(trie_store)
flatten_to_dictionary(trie_store, TARGET_DICTIONARY_COUNT)

def finish_filing(trie_store, context_words, predictive_words, dictionary_subpath):
# Slugify the context words
Expand Down

0 comments on commit f6c97ef

Please sign in to comment.